diff --git a/core/math/geometry.cpp b/core/math/geometry.cpp index 2425249dfceb..096bee73ceba 100644 --- a/core/math/geometry.cpp +++ b/core/math/geometry.cpp @@ -31,8 +31,11 @@ #include "geometry.h" #include "core/print_string.h" + #include "thirdparty/misc/clipper.hpp" #include "thirdparty/misc/triangulator.h" +#define STB_RECT_PACK_IMPLEMENTATION +#include "thirdparty/stb_rect_pack/stb_rect_pack.h" #define SCALE_FACTOR 100000.0 // Based on CMP_EPSILON. @@ -1224,3 +1227,36 @@ Vector Geometry::compute_convex_mesh_points(const Plane *p_planes, int return points; } + +Vector Geometry::partial_pack_rects(const Vector &p_sizes, const Size2i &p_atlas_size) { + + Vector nodes; + nodes.resize(p_atlas_size.width); + zeromem(nodes.ptrw(), sizeof(stbrp_node) * nodes.size()); + + stbrp_context context; + stbrp_init_target(&context, p_atlas_size.width, p_atlas_size.height, nodes.ptrw(), p_atlas_size.width); + + Vector rects; + rects.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + rects.write[i].id = i; + rects.write[i].w = p_sizes[i].width; + rects.write[i].h = p_sizes[i].height; + rects.write[i].x = 0; + rects.write[i].y = 0; + rects.write[i].was_packed = 0; + } + + stbrp_pack_rects(&context, rects.ptrw(), rects.size()); + + Vector ret; + ret.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + ret.write[rects[i].id] = { rects[i].x, rects[i].y, static_cast(rects[i].was_packed) }; + } + + return ret; +} diff --git a/core/math/geometry.h b/core/math/geometry.h index 3ab108e8717b..23f4f4fce08e 100644 --- a/core/math/geometry.h +++ b/core/math/geometry.h @@ -502,6 +502,27 @@ class Geometry { return (cn.cross(an) > 0) == orientation; } + static Vector3 barycentric_coordinates_2d(const Vector2 &s, const Vector2 &a, const Vector2 &b, const Vector2 &c) { + // http://www.blackpawn.com/texts/pointinpoly/ + Vector2 v0 = c - a; + Vector2 v1 = b - a; + Vector2 v2 = s - a; + + // Compute dot products + double dot00 = v0.dot(v0); + double dot01 = v0.dot(v1); + double dot02 = v0.dot(v2); + double dot11 = v1.dot(v1); + double dot12 = v1.dot(v2); + + // Compute barycentric coordinates + double invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01); + double b2 = (dot11 * dot02 - dot01 * dot12) * invDenom; + double b1 = (dot00 * dot12 - dot01 * dot02) * invDenom; + double b0 = 1.0f - b2 - b1; + return Vector3(b0, b1, b2); + } + static Vector2 get_closest_point_to_segment_uncapped_2d(const Vector2 &p_point, const Vector2 *p_segment) { Vector2 p = p_point - p_segment[0]; @@ -1014,6 +1035,13 @@ class Geometry { static void make_atlas(const Vector &p_rects, Vector &r_result, Size2i &r_size); + struct PackRectsResult { + int x; + int y; + bool packed; + }; + static Vector partial_pack_rects(const Vector &p_sizes, const Size2i &p_atlas_size); + static Vector compute_convex_mesh_points(const Plane *p_planes, int p_plane_count); private: diff --git a/doc/classes/BakedLightmap.xml b/doc/classes/BakedLightmap.xml index 71f22d186440..c93902da98d9 100644 --- a/doc/classes/BakedLightmap.xml +++ b/doc/classes/BakedLightmap.xml @@ -16,54 +16,71 @@ - + - Bakes the lightmaps within the currently edited scene. Returns a [enum BakeError] to signify if the bake was successful, or if unsuccessful, how the bake failed. - - - - - - - Executes a dry run bake of lightmaps within the currently edited scene. + Bakes the lightmap, scanning from the given [code]from_node[/code] root and saves the resulting [BakedLightmapData] in [code]data_save_path[/code]. 
If no save path is provided, it will try to match the path from the current [member light_data]. - - Grid subdivision size for lightmapper calculation. The default value will work for most cases. Increase for better lighting on small details or if your scene is very large. + + When enabled, the lightmapper will merge the textures for all meshes into a single large layered texture. Not supported in GLES2. - - If a [member Mesh.lightmap_size_hint] isn't specified, the lightmap baker will dynamically set the lightmap size using this value. This value is measured in texels per world unit. The maximum lightmap texture size is 4096x4096. + + Maximum size of each lightmap layer, only used when [member atlas_generate] is enabled. - - Multiplies the light sources' intensity by this value. For instance, if the value is set to 2, lights will be twice as bright. If the value is set to 0.5, lights will be half as bright. + + Raycasting bias used during baking to avoid floating point precision issues. - - The size of the affected area. + + Number of light bounces that are taken into account during baking. - - If [code]true[/code], the lightmap can capture light values greater than [code]1.0[/code]. Turning this off will result in a smaller file size. + + Grid size used for real-time capture information on dynamic objects. - - Lightmapping mode. See [enum BakeMode]. + + When enabled, an octree containing the scene's lighting information will be computed. This octree will then be used to light dynamic objects in the scene. - - Defines how far the light will travel before it is no longer effective. The higher the number, the farther the light will travel. For instance, if the value is set to 2, the light will go twice as far. If the value is set to 0.5, the light will only go half as far. + + Bias value to reduce the amount of light propagation in the captured octree. - - Three quality modes are available. Higher quality requires more rendering time. See [enum BakeQuality]. + + Bake quality of the capture data. - - Grid size used for real-time capture information on dynamic objects. Cannot be larger than [member bake_cell_size]. + + If a baked mesh doesn't have a UV2 size hint, this value will be used to roughly compute a suitable lightmap size. + + + The environment color when [member environment_mode] is set to [constant ENVIRONMENT_MODE_CUSTOM_COLOR]. + + + The energy scaling factor when [member environment_mode] is set to [constant ENVIRONMENT_MODE_CUSTOM_COLOR] or [constant ENVIRONMENT_MODE_CUSTOM_SKY]. + + + The [Sky] resource to use when [member environment_mode] is set to [constant ENVIRONMENT_MODE_CUSTOM_SKY]. + + + The rotation of the baked custom sky. + + + Decides which environment to use during baking. - - The location where lightmaps will be saved. + + Size of the baked lightmap. Only meshes inside this region will be included in the baked lightmap, also used as the bounds of the captured region for dynamic lighting. + + + Deprecated. In previous versions it determined the location where lightmaps were saved. The calculated light data. + + Determines the amount of samples per texel used in indirect light baking. The amount of samples for each quality level can be configured in the project settings. + + + When enabled, a lightmap denoiser will be used to reduce the noise inherent to Monte Carlo based global illumination. + @@ -73,13 +90,10 @@ The default bake quality mode. - The highest bake quality mode. Takes longer to calculate. + A higher bake quality mode. Takes longer to calculate. 
- - Less precise but faster bake mode. - - - More precise bake mode but can take considerably longer to bake. + + The highest bake quality mode. Takes the longest to calculate. Baking was successful. @@ -93,8 +107,28 @@ Returns when the baker cannot save per-mesh textures to file. - + + The size of the generated lightmaps is too large. + + + Some mesh contains UV2 values outside the [code][0,1][/code] range. + + Returns if user cancels baking. + + + + No environment is used during baking. + + + The baked environment is automatically picked from the current scene. + + + A custom sky is used as environment during baking. + + + A custom solid color is used as environment during baking. + diff --git a/doc/classes/BakedLightmapData.xml b/doc/classes/BakedLightmapData.xml index bc44646018ec..72b4d68b7e05 100644 --- a/doc/classes/BakedLightmapData.xml +++ b/doc/classes/BakedLightmapData.xml @@ -12,9 +12,13 @@ - + - + + + + + @@ -32,7 +36,7 @@ - + diff --git a/doc/classes/GeometryInstance.xml b/doc/classes/GeometryInstance.xml index 4545d88121e6..0d5a85ea8a72 100644 --- a/doc/classes/GeometryInstance.xml +++ b/doc/classes/GeometryInstance.xml @@ -46,6 +46,12 @@ The extra distance added to the GeometryInstance's bounding box ([AABB]) to increase its cull box. + + When disabled, the mesh will be taken into account when computing indirect lighting, but the resulting lightmap will not be saved. Useful for emissive only materials or shadow casters. + + + Scale factor for the generated baked lightmap. Useful for adding detail to certain mesh instances. + The GeometryInstance's max LOD distance. [b]Note:[/b] This property currently has no effect. @@ -71,6 +77,20 @@ + + The generated lightmap texture will have the original size. + + + The generated lightmap texture will be twice as large, on each axis. + + + The generated lightmap texture will be 4 times as large, on each axis. + + + The generated lightmap texture will be 8 times as large, on each axis. + + + Will not cast any shadows. diff --git a/doc/classes/Mesh.xml b/doc/classes/Mesh.xml index d843b6192c62..54ccb0a82871 100644 --- a/doc/classes/Mesh.xml +++ b/doc/classes/Mesh.xml @@ -107,7 +107,7 @@ - Sets a hint to be used for lightmap resolution in [BakedLightmap]. Overrides [member BakedLightmap.bake_default_texels_per_unit]. + Sets a hint to be used for lightmap resolution in [BakedLightmap]. Overrides [member BakedLightmap.default_texels_per_unit]. diff --git a/doc/classes/ProjectSettings.xml b/doc/classes/ProjectSettings.xml index 32060afb1dee..2ef5b17d5606 100644 --- a/doc/classes/ProjectSettings.xml +++ b/doc/classes/ProjectSettings.xml @@ -1057,6 +1057,18 @@ The amount of UV contraction. This figure is divided by 1000000, and is a proportion of the total texture dimensions, where the width and height are both ranged from 0.0 to 1.0. Use the default unless correcting for a problem on particular hardware. + + Amount of light samples taken when using [constant BakedLightmap.BAKE_QUALITY_HIGH]. + + + Amount of light samples taken when using [constant BakedLightmap.BAKE_QUALITY_LOW]. + + + Amount of light samples taken when using [constant BakedLightmap.BAKE_QUALITY_MEDIUM]. + + + Amount of light samples taken when using [constant BakedLightmap.BAKE_QUALITY_ULTRA]. + Default background clear color. Overridable per [Viewport] using its [Environment]. See [member Environment.background_mode] and [member Environment.background_color] in particular. To change this default color programmatically, use [method VisualServer.set_default_clear_color]. 
@@ -1169,6 +1181,12 @@ Lower-end override for [member rendering/quality/intended_usage/framebuffer_allocation] on mobile devices, due to performance concerns or driver support. + + Enable usage of bicubic sampling in baked lightmaps. This results in smoother looking lighting at the expense of more bandwidth usage. On GLES2, changes to this setting will only be applied upon restarting the application. + + + Lower-end override for [member rendering/quality/lightmapping/use_bicubic_sampling] on mobile devices, in order to reduce bandwidth usage. + Size of the atlas used by reflection probes. A larger size can result in higher visual quality, while a smaller size will be faster and take up less memory. diff --git a/doc/classes/VisualServer.xml b/doc/classes/VisualServer.xml index ead815e13535..8aae7d98a4bc 100644 --- a/doc/classes/VisualServer.xml +++ b/doc/classes/VisualServer.xml @@ -2045,6 +2045,10 @@ + + + + Sets the lightmap to use with this instance. diff --git a/drivers/gles2/rasterizer_scene_gles2.cpp b/drivers/gles2/rasterizer_scene_gles2.cpp index ed936fefbc9f..ff27f113338d 100644 --- a/drivers/gles2/rasterizer_scene_gles2.cpp +++ b/drivers/gles2/rasterizer_scene_gles2.cpp @@ -2572,6 +2572,9 @@ void RasterizerSceneGLES2::_render_render_list(RenderList::Element **p_elements, if (rebind_lightmap && lightmap) { state.scene_shader.set_uniform(SceneShaderGLES2::LIGHTMAP_ENERGY, lightmap_energy); + if (storage->config.use_lightmap_filter_bicubic) { + state.scene_shader.set_uniform(SceneShaderGLES2::LIGHTMAP_TEXTURE_SIZE, Vector2(lightmap->width, lightmap->height)); + } } state.scene_shader.set_uniform(SceneShaderGLES2::WORLD_TRANSFORM, e->instance->transform); @@ -4047,6 +4050,10 @@ void RasterizerSceneGLES2::initialize() { } } + if (storage->config.use_lightmap_filter_bicubic) { + state.scene_shader.add_custom_define("#define USE_LIGHTMAP_FILTER_BICUBIC\n"); + } + shadow_filter_mode = SHADOW_FILTER_NEAREST; glFrontFace(GL_CW); diff --git a/drivers/gles2/rasterizer_storage_gles2.cpp b/drivers/gles2/rasterizer_storage_gles2.cpp index 1c94196055d8..03ff9b560102 100644 --- a/drivers/gles2/rasterizer_storage_gles2.cpp +++ b/drivers/gles2/rasterizer_storage_gles2.cpp @@ -6299,6 +6299,9 @@ void RasterizerStorageGLES2::initialize() { config.force_vertex_shading = GLOBAL_GET("rendering/quality/shading/force_vertex_shading"); config.use_fast_texture_filter = GLOBAL_GET("rendering/quality/filters/use_nearest_mipmap_filter"); + GLOBAL_DEF_RST("rendering/quality/lightmapping/use_bicubic_sampling", true); + GLOBAL_DEF_RST("rendering/quality/lightmapping/use_bicubic_sampling.mobile", false); + config.use_lightmap_filter_bicubic = GLOBAL_GET("rendering/quality/lightmapping/use_bicubic_sampling"); } void RasterizerStorageGLES2::finalize() { diff --git a/drivers/gles2/rasterizer_storage_gles2.h b/drivers/gles2/rasterizer_storage_gles2.h index 675f19c622ed..7cff4d59353c 100644 --- a/drivers/gles2/rasterizer_storage_gles2.h +++ b/drivers/gles2/rasterizer_storage_gles2.h @@ -57,6 +57,7 @@ class RasterizerStorageGLES2 : public RasterizerStorage { bool shrink_textures_x2; bool use_fast_texture_filter; bool use_skeleton_software; + bool use_lightmap_filter_bicubic; int max_vertex_texture_image_units; int max_texture_image_units; diff --git a/drivers/gles2/shader_compiler_gles2.h b/drivers/gles2/shader_compiler_gles2.h index f302001da593..c4877d23b0e0 100644 --- a/drivers/gles2/shader_compiler_gles2.h +++ b/drivers/gles2/shader_compiler_gles2.h @@ -98,4 +98,4 @@ class ShaderCompilerGLES2 { ShaderCompilerGLES2(); 
}; -#endif // SHADERCOMPILERGLES3_H +#endif // SHADERCOMPILERGLES2_H diff --git a/drivers/gles2/shaders/scene.glsl b/drivers/gles2/shaders/scene.glsl index b55dbec37663..9b2521c2c142 100644 --- a/drivers/gles2/shaders/scene.glsl +++ b/drivers/gles2/shaders/scene.glsl @@ -889,6 +889,69 @@ void reflection_process(samplerCube reflection_map, #ifdef USE_LIGHTMAP uniform mediump sampler2D lightmap; //texunit:-4 uniform mediump float lightmap_energy; + +#if defined(USE_LIGHTMAP_FILTER_BICUBIC) +uniform mediump vec2 lightmap_texture_size; + +// w0, w1, w2, and w3 are the four cubic B-spline basis functions +float w0(float a) { + return (1.0 / 6.0) * (a * (a * (-a + 3.0) - 3.0) + 1.0); +} + +float w1(float a) { + return (1.0 / 6.0) * (a * a * (3.0 * a - 6.0) + 4.0); +} + +float w2(float a) { + return (1.0 / 6.0) * (a * (a * (-3.0 * a + 3.0) + 3.0) + 1.0); +} + +float w3(float a) { + return (1.0 / 6.0) * (a * a * a); +} + +// g0 and g1 are the two amplitude functions +float g0(float a) { + return w0(a) + w1(a); +} + +float g1(float a) { + return w2(a) + w3(a); +} + +// h0 and h1 are the two offset functions +float h0(float a) { + return -1.0 + w1(a) / (w0(a) + w1(a)); +} + +float h1(float a) { + return 1.0 + w3(a) / (w2(a) + w3(a)); +} + +vec4 texture2D_bicubic(sampler2D tex, vec2 uv) { + vec2 texel_size = vec2(1.0) / lightmap_texture_size; + + uv = uv * lightmap_texture_size + vec2(0.5); + + vec2 iuv = floor(uv); + vec2 fuv = fract(uv); + + float g0x = g0(fuv.x); + float g1x = g1(fuv.x); + float h0x = h0(fuv.x); + float h1x = h1(fuv.x); + float h0y = h0(fuv.y); + float h1y = h1(fuv.y); + + vec2 p0 = (vec2(iuv.x + h0x, iuv.y + h0y) - vec2(0.5)) * texel_size; + vec2 p1 = (vec2(iuv.x + h1x, iuv.y + h0y) - vec2(0.5)) * texel_size; + vec2 p2 = (vec2(iuv.x + h0x, iuv.y + h1y) - vec2(0.5)) * texel_size; + vec2 p3 = (vec2(iuv.x + h1x, iuv.y + h1y) - vec2(0.5)) * texel_size; + + return (g0(fuv.y) * (g0x * texture2D(tex, p0) + g1x * texture2D(tex, p1))) + + (g1(fuv.y) * (g0x * texture2D(tex, p2) + g1x * texture2D(tex, p3))); +} +#endif //USE_LIGHTMAP_FILTER_BICUBIC #endif #ifdef USE_LIGHTMAP_CAPTURE @@ -1660,9 +1723,13 @@ FRAGMENT_SHADER_CODE } #ifdef USE_LIGHTMAP - //ambient light will come entirely from lightmap is lightmap is used +//ambient light will come entirely from lightmap is lightmap is used +#if defined(USE_LIGHTMAP_FILTER_BICUBIC) + ambient_light = texture2D_bicubic(lightmap, uv2_interp).rgb * lightmap_energy; +#else ambient_light = texture2D(lightmap, uv2_interp).rgb * lightmap_energy; #endif +#endif #ifdef USE_LIGHTMAP_CAPTURE { diff --git a/drivers/gles3/rasterizer_scene_gles3.cpp b/drivers/gles3/rasterizer_scene_gles3.cpp index 4e36ba94235b..7313b559c9af 100644 --- a/drivers/gles3/rasterizer_scene_gles3.cpp +++ b/drivers/gles3/rasterizer_scene_gles3.cpp @@ -1949,7 +1949,17 @@ void RasterizerSceneGLES3::_setup_light(RenderList::Element *e, const Transform if (lightmap && capture) { glActiveTexture(GL_TEXTURE0 + storage->config.max_texture_image_units - 9); - glBindTexture(GL_TEXTURE_2D, lightmap->tex_id); + if (e->instance->lightmap_slice == -1) { + glBindTexture(GL_TEXTURE_2D, lightmap->tex_id); + } else { + glBindTexture(GL_TEXTURE_2D_ARRAY, lightmap->tex_id); + state.scene_shader.set_uniform(SceneShaderGLES3::LIGHTMAP_LAYER, e->instance->lightmap_slice); + } + const Rect2 &uvr = e->instance->lightmap_uv_rect; + state.scene_shader.set_uniform(SceneShaderGLES3::LIGHTMAP_UV_RECT, Color(uvr.get_position().x, uvr.get_position().y, uvr.get_size().x, uvr.get_size().y)); + if 
(storage->config.use_lightmap_filter_bicubic) { + state.scene_shader.set_uniform(SceneShaderGLES3::LIGHTMAP_TEXTURE_SIZE, Vector2(lightmap->width, lightmap->height)); + } state.scene_shader.set_uniform(SceneShaderGLES3::LIGHTMAP_ENERGY, capture->energy); } } @@ -2080,6 +2090,7 @@ void RasterizerSceneGLES3::_render_list(RenderList::Element **p_elements, int p_ state.scene_shader.set_conditional(SceneShaderGLES3::USE_GI_PROBES, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_CAPTURE, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP, false); + state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_LAYERED, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_RADIANCE_MAP, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_CONTACT_SHADOWS, false); @@ -2088,6 +2099,7 @@ void RasterizerSceneGLES3::_render_list(RenderList::Element **p_elements, int p_ state.scene_shader.set_conditional(SceneShaderGLES3::USE_GI_PROBES, e->instance->gi_probe_instances.size() > 0); state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP, e->instance->lightmap.is_valid() && e->instance->gi_probe_instances.size() == 0); + state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_LAYERED, e->instance->lightmap_slice != -1); state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_CAPTURE, !e->instance->lightmap_capture_data.empty() && !e->instance->lightmap.is_valid() && e->instance->gi_probe_instances.size() == 0); state.scene_shader.set_conditional(SceneShaderGLES3::SHADELESS, false); @@ -2258,6 +2270,7 @@ void RasterizerSceneGLES3::_render_list(RenderList::Element **p_elements, int p_ state.scene_shader.set_conditional(SceneShaderGLES3::SHADOW_MODE_PCF_13, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_GI_PROBES, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP, false); + state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_LAYERED, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_CAPTURE, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_CONTACT_SHADOWS, false); state.scene_shader.set_conditional(SceneShaderGLES3::USE_VERTEX_LIGHTING, false); @@ -2392,6 +2405,9 @@ void RasterizerSceneGLES3::_add_geometry_with_material(RasterizerStorageGLES3::G if (e->instance->lightmap.is_valid()) { e->sort_key |= SORT_KEY_LIGHTMAP_FLAG; + if (e->instance->lightmap_slice != -1) { + e->sort_key |= SORT_KEY_LIGHTMAP_LAYERED_FLAG; + } } if (!e->instance->lightmap_capture_data.empty()) { @@ -5337,6 +5353,8 @@ void RasterizerSceneGLES3::iteration() { subsurface_scatter_quality = SubSurfaceScatterQuality(int(GLOBAL_GET("rendering/quality/subsurface_scattering/quality"))); subsurface_scatter_size = GLOBAL_GET("rendering/quality/subsurface_scattering/scale"); + storage->config.use_lightmap_filter_bicubic = GLOBAL_GET("rendering/quality/lightmapping/use_bicubic_sampling"); + state.scene_shader.set_conditional(SceneShaderGLES3::USE_LIGHTMAP_FILTER_BICUBIC, storage->config.use_lightmap_filter_bicubic); state.scene_shader.set_conditional(SceneShaderGLES3::VCT_QUALITY_HIGH, GLOBAL_GET("rendering/quality/voxel_cone_tracing/high_quality")); } diff --git a/drivers/gles3/rasterizer_scene_gles3.h b/drivers/gles3/rasterizer_scene_gles3.h index 772d8df02417..3d379a3800c0 100644 --- a/drivers/gles3/rasterizer_scene_gles3.h +++ b/drivers/gles3/rasterizer_scene_gles3.h @@ -680,14 +680,15 @@ class RasterizerSceneGLES3 : public RasterizerScene { 
SORT_KEY_OPAQUE_DEPTH_LAYER_SHIFT = 52, SORT_KEY_OPAQUE_DEPTH_LAYER_MASK = 0xF, //64 bits unsupported in MSVC -#define SORT_KEY_UNSHADED_FLAG (uint64_t(1) << 49) -#define SORT_KEY_NO_DIRECTIONAL_FLAG (uint64_t(1) << 48) -#define SORT_KEY_LIGHTMAP_CAPTURE_FLAG (uint64_t(1) << 47) +#define SORT_KEY_UNSHADED_FLAG (uint64_t(1) << 50) +#define SORT_KEY_NO_DIRECTIONAL_FLAG (uint64_t(1) << 49) +#define SORT_KEY_LIGHTMAP_CAPTURE_FLAG (uint64_t(1) << 48) +#define SORT_KEY_LIGHTMAP_LAYERED_FLAG (uint64_t(1) << 47) #define SORT_KEY_LIGHTMAP_FLAG (uint64_t(1) << 46) #define SORT_KEY_GI_PROBES_FLAG (uint64_t(1) << 45) #define SORT_KEY_VERTEX_LIT_FLAG (uint64_t(1) << 44) SORT_KEY_SHADING_SHIFT = 44, - SORT_KEY_SHADING_MASK = 63, + SORT_KEY_SHADING_MASK = 127, //44-28 material index SORT_KEY_MATERIAL_INDEX_SHIFT = 28, //28-8 geometry index diff --git a/drivers/gles3/rasterizer_storage_gles3.cpp b/drivers/gles3/rasterizer_storage_gles3.cpp index 0a72c67497d9..6984e3f2e82c 100644 --- a/drivers/gles3/rasterizer_storage_gles3.cpp +++ b/drivers/gles3/rasterizer_storage_gles3.cpp @@ -8555,6 +8555,10 @@ void RasterizerStorageGLES3::initialize() { String renderer = (const char *)glGetString(GL_RENDERER); + GLOBAL_DEF("rendering/quality/lightmapping/use_bicubic_sampling", true); + GLOBAL_DEF("rendering/quality/lightmapping/use_bicubic_sampling.mobile", false); + config.use_lightmap_filter_bicubic = GLOBAL_GET("rendering/quality/lightmapping/use_bicubic_sampling"); + config.use_depth_prepass = bool(GLOBAL_GET("rendering/quality/depth_prepass/enable")); if (config.use_depth_prepass) { diff --git a/drivers/gles3/rasterizer_storage_gles3.h b/drivers/gles3/rasterizer_storage_gles3.h index c351f265ec7b..1101d25ee0c8 100644 --- a/drivers/gles3/rasterizer_storage_gles3.h +++ b/drivers/gles3/rasterizer_storage_gles3.h @@ -74,6 +74,7 @@ class RasterizerStorageGLES3 : public RasterizerStorage { bool shrink_textures_x2; bool use_fast_texture_filter; bool use_anisotropic_filter; + bool use_lightmap_filter_bicubic; bool s3tc_supported; bool latc_supported; diff --git a/drivers/gles3/shaders/scene.glsl b/drivers/gles3/shaders/scene.glsl index a45ac2eb8a60..78071ea51b04 100644 --- a/drivers/gles3/shaders/scene.glsl +++ b/drivers/gles3/shaders/scene.glsl @@ -109,6 +109,10 @@ layout(std140) uniform SceneData { // ubo:0 uniform highp mat4 world_transform; +#ifdef USE_LIGHTMAP +uniform highp vec4 lightmap_uv_rect; +#endif + #ifdef USE_LIGHT_DIRECTIONAL layout(std140) uniform DirectionalLightData { //ubo:3 @@ -346,7 +350,9 @@ void main() { uv_interp = uv_attrib; #endif -#if defined(ENABLE_UV2_INTERP) || defined(USE_LIGHTMAP) +#if defined(USE_LIGHTMAP) + uv2_interp = lightmap_uv_rect.zw * uv2_attrib + lightmap_uv_rect.xy; +#elif defined(ENABLE_UV2_INTERP) uv2_interp = uv2_attrib; #endif @@ -1435,8 +1441,109 @@ void reflection_process(int idx, vec3 vertex, vec3 normal, vec3 binormal, vec3 t } #ifdef USE_LIGHTMAP +#ifdef USE_LIGHTMAP_LAYERED +uniform mediump sampler2DArray lightmap; //texunit:-9 +uniform int lightmap_layer; +#else uniform mediump sampler2D lightmap; //texunit:-9 +#endif + uniform mediump float lightmap_energy; + +#ifdef USE_LIGHTMAP_FILTER_BICUBIC +uniform vec2 lightmap_texture_size; + +// w0, w1, w2, and w3 are the four cubic B-spline basis functions +float w0(float a) { + return (1.0 / 6.0) * (a * (a * (-a + 3.0) - 3.0) + 1.0); +} + +float w1(float a) { + return (1.0 / 6.0) * (a * a * (3.0 * a - 6.0) + 4.0); +} + +float w2(float a) { + return (1.0 / 6.0) * (a * (a * (-3.0 * a + 3.0) + 3.0) + 1.0); +} + +float w3(float 
a) { + return (1.0 / 6.0) * (a * a * a); +} + +// g0 and g1 are the two amplitude functions +float g0(float a) { + return w0(a) + w1(a); +} + +float g1(float a) { + return w2(a) + w3(a); +} + +// h0 and h1 are the two offset functions +float h0(float a) { + return -1.0 + w1(a) / (w0(a) + w1(a)); +} + +float h1(float a) { + return 1.0 + w3(a) / (w2(a) + w3(a)); +} + +vec4 texture_bicubic(sampler2D tex, vec2 uv) { + vec2 texel_size = vec2(1.0) / lightmap_texture_size; + + uv = uv * lightmap_texture_size + vec2(0.5); + + vec2 iuv = floor(uv); + vec2 fuv = fract(uv); + + float g0x = g0(fuv.x); + float g1x = g1(fuv.x); + float h0x = h0(fuv.x); + float h1x = h1(fuv.x); + float h0y = h0(fuv.y); + float h1y = h1(fuv.y); + + vec2 p0 = (vec2(iuv.x + h0x, iuv.y + h0y) - vec2(0.5)) * texel_size; + vec2 p1 = (vec2(iuv.x + h1x, iuv.y + h0y) - vec2(0.5)) * texel_size; + vec2 p2 = (vec2(iuv.x + h0x, iuv.y + h1y) - vec2(0.5)) * texel_size; + vec2 p3 = (vec2(iuv.x + h1x, iuv.y + h1y) - vec2(0.5)) * texel_size; + + return (g0(fuv.y) * (g0x * texture(tex, p0) + g1x * texture(tex, p1))) + + (g1(fuv.y) * (g0x * texture(tex, p2) + g1x * texture(tex, p3))); +} + +vec4 textureArray_bicubic(sampler2DArray tex, vec3 uv) { + vec2 texel_size = vec2(1.0) / lightmap_texture_size; + + uv.xy = uv.xy * lightmap_texture_size + vec2(0.5); + + vec2 iuv = floor(uv.xy); + vec2 fuv = fract(uv.xy); + + float g0x = g0(fuv.x); + float g1x = g1(fuv.x); + float h0x = h0(fuv.x); + float h1x = h1(fuv.x); + float h0y = h0(fuv.y); + float h1y = h1(fuv.y); + + vec2 p0 = (vec2(iuv.x + h0x, iuv.y + h0y) - vec2(0.5)) * texel_size; + vec2 p1 = (vec2(iuv.x + h1x, iuv.y + h0y) - vec2(0.5)) * texel_size; + vec2 p2 = (vec2(iuv.x + h0x, iuv.y + h1y) - vec2(0.5)) * texel_size; + vec2 p3 = (vec2(iuv.x + h1x, iuv.y + h1y) - vec2(0.5)) * texel_size; + + return (g0(fuv.y) * (g0x * texture(tex, vec3(p0, uv.z)) + g1x * texture(tex, vec3(p1, uv.z)))) + + (g1(fuv.y) * (g0x * texture(tex, vec3(p2, uv.z)) + g1x * texture(tex, vec3(p3, uv.z)))); +} + +#define LIGHTMAP_TEXTURE_SAMPLE(m_tex, m_uv) texture_bicubic(m_tex, m_uv) +#define LIGHTMAP_TEXTURE_LAYERED_SAMPLE(m_tex, m_uv) textureArray_bicubic(m_tex, m_uv) + +#else //!USE_LIGHTMAP_FILTER_BICUBIC +#define LIGHTMAP_TEXTURE_SAMPLE(m_tex, m_uv) texture(m_tex, m_uv) +#define LIGHTMAP_TEXTURE_LAYERED_SAMPLE(m_tex, m_uv) texture(m_tex, m_uv) + +#endif //USE_LIGHTMAP_FILTER_BICUBIC #endif #ifdef USE_LIGHTMAP_CAPTURE @@ -1823,7 +1930,11 @@ FRAGMENT_SHADER_CODE #endif #ifdef USE_LIGHTMAP - ambient_light = texture(lightmap, uv2).rgb * lightmap_energy; +#ifdef USE_LIGHTMAP_LAYERED + ambient_light = LIGHTMAP_TEXTURE_LAYERED_SAMPLE(lightmap, vec3(uv2, float(lightmap_layer))).rgb * lightmap_energy; +#else + ambient_light = LIGHTMAP_TEXTURE_SAMPLE(lightmap, uv2).rgb * lightmap_energy; +#endif #endif #ifdef USE_LIGHTMAP_CAPTURE diff --git a/editor/import/resource_importer_layered_texture.cpp b/editor/import/resource_importer_layered_texture.cpp index cdbc5b264c67..1b0b4028bac9 100644 --- a/editor/import/resource_importer_layered_texture.cpp +++ b/editor/import/resource_importer_layered_texture.cpp @@ -107,16 +107,17 @@ void ResourceImporterLayeredTexture::_save_tex(const Vector > &p_imag f->store_32(p_images[0]->get_height()); f->store_32(p_images.size()); //depth f->store_32(p_texture_flags); + + if ((p_compress_mode == COMPRESS_LOSSLESS) && p_images[0]->get_format() > Image::FORMAT_RGBA8) { + p_compress_mode = COMPRESS_UNCOMPRESSED; //these can't go as lossy + } + if (p_compress_mode != COMPRESS_VIDEO_RAM) { //vram 
needs to do a first compression to tell what the format is, for the rest its ok f->store_32(p_images[0]->get_format()); f->store_32(p_compress_mode); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed } - if ((p_compress_mode == COMPRESS_LOSSLESS) && p_images[0]->get_format() > Image::FORMAT_RGBA8) { - p_compress_mode = COMPRESS_UNCOMPRESSED; //these can't go as lossy - } - for (int i = 0; i < p_images.size(); i++) { switch (p_compress_mode) { diff --git a/editor/import/resource_importer_scene.cpp b/editor/import/resource_importer_scene.cpp index d2012508a51c..589a916e9547 100644 --- a/editor/import/resource_importer_scene.cpp +++ b/editor/import/resource_importer_scene.cpp @@ -931,9 +931,6 @@ static String _make_extname(const String &p_str) { void ResourceImporterScene::_find_meshes(Node *p_node, Map, Transform> &meshes) { - List pi; - p_node->get_property_list(&pi); - MeshInstance *mi = Object::cast_to(p_node); if (mi) { @@ -941,11 +938,11 @@ void ResourceImporterScene::_find_meshes(Node *p_node, Map, Trans Ref mesh = mi->get_mesh(); if (mesh.is_valid() && !meshes.has(mesh)) { - Spatial *s = mi; + Spatial *s = Object::cast_to(mi); Transform transform; while (s) { transform = transform * s->get_transform(); - s = s->get_parent_spatial(); + s = Object::cast_to(s->get_parent()); } meshes[mesh] = transform; @@ -1427,29 +1424,117 @@ Error ResourceImporterScene::import(const String &p_source_file, const String &p Map, Transform> meshes; _find_meshes(scene, meshes); - if (light_bake_mode == 2) { + String file_id = src_path.get_file(); + String cache_file_path = base_path.plus_file(file_id + ".unwrap_cache"); + + int *cache_data = nullptr; + unsigned int cache_size = 0; + + if (FileAccess::exists(cache_file_path)) { + Error err2; + FileAccess *file = FileAccess::open(cache_file_path, FileAccess::READ, &err2); + + if (!err2) { + cache_size = file->get_len(); + cache_data = (int *)memalloc(cache_size); + file->get_buffer((unsigned char *)cache_data, cache_size); + } + + if (file) + memdelete(file); + } + + float texel_size = p_options["meshes/lightmap_texel_size"]; + texel_size = MAX(0.001, texel_size); + + Map used_meshes; + + EditorProgress progress2("gen_lightmaps", TTR("Generating Lightmaps"), meshes.size()); + int step = 0; + for (Map, Transform>::Element *E = meshes.front(); E; E = E->next()) { + + Ref mesh = E->key(); + String name = mesh->get_name(); + if (name == "") { //should not happen but.. + name = "Mesh " + itos(step); + } + + progress2.step(TTR("Generating for Mesh: ") + name + " (" + itos(step) + "/" + itos(meshes.size()) + ")", step); + + int *ret_cache_data = cache_data; + unsigned int ret_cache_size = cache_size; + bool ret_used_cache = true; // Tell the unwrapper to use the cache + Error err2 = mesh->lightmap_unwrap_cached(ret_cache_data, ret_cache_size, ret_used_cache, E->get(), texel_size); - float texel_size = p_options["meshes/lightmap_texel_size"]; - texel_size = MAX(0.001, texel_size); + if (err2 != OK) { + EditorNode::add_io_error("Mesh '" + name + "' failed lightmap generation. Please fix geometry."); + } else { + + String hash = String::md5((unsigned char *)ret_cache_data); + used_meshes.insert(hash, ret_cache_size); + + if (!ret_used_cache) { + // Cache was not used, add the generated entry to the current cache + + unsigned int new_cache_size = cache_size + ret_cache_size + (cache_size == 0 ? 
4 : 0); + int *new_cache_data = (int *)memalloc(new_cache_size); + + if (cache_size == 0) { + // Cache was empty + new_cache_data[0] = 0; + cache_size = 4; + } else { + memcpy(new_cache_data, cache_data, cache_size); + memfree(cache_data); + } + + memcpy(&new_cache_data[cache_size / sizeof(int)], ret_cache_data, ret_cache_size); - EditorProgress progress2("gen_lightmaps", TTR("Generating Lightmaps"), meshes.size()); - int step = 0; - for (Map, Transform>::Element *E = meshes.front(); E; E = E->next()) { + cache_data = new_cache_data; + cache_size = new_cache_size; - Ref mesh = E->key(); - String name = mesh->get_name(); - if (name == "") { //should not happen but.. - name = "Mesh " + itos(step); + cache_data[0]++; // Increase entry count } + } + step++; + } - progress2.step(TTR("Generating for Mesh: ") + name + " (" + itos(step) + "/" + itos(meshes.size()) + ")", step); + Error err2; + FileAccess *file = FileAccess::open(cache_file_path, FileAccess::WRITE, &err2); + + if (err2) { + if (file) + memdelete(file); + } else { - Error err2 = mesh->lightmap_unwrap(E->get(), texel_size); - if (err2 != OK) { - EditorNode::add_io_error("Mesh '" + name + "' failed lightmap generation. Please fix geometry."); + // Store number of entries + file->store_32(used_meshes.size()); + + // Store cache entries + unsigned int r_idx = 1; + for (int i = 0; i < cache_data[0]; ++i) { + unsigned char *entry_start = (unsigned char *)&cache_data[r_idx]; + String entry_hash = String::md5(entry_start); + if (used_meshes.has(entry_hash)) { + unsigned int entry_size = used_meshes[entry_hash]; + file->store_buffer(entry_start, entry_size); } - step++; + + r_idx += 4; // hash + r_idx += 2; // size hint + + int vertex_count = cache_data[r_idx]; + r_idx += 1; // vertex count + r_idx += vertex_count; // vertex + r_idx += vertex_count * 2; // uvs + + int index_count = cache_data[r_idx]; + r_idx += 1; // index count + r_idx += index_count; // indices } + + file->close(); + memfree(cache_data); } } diff --git a/editor/plugins/baked_lightmap_editor_plugin.cpp b/editor/plugins/baked_lightmap_editor_plugin.cpp index 059448eb225b..8f10236e8e7f 100644 --- a/editor/plugins/baked_lightmap_editor_plugin.cpp +++ b/editor/plugins/baked_lightmap_editor_plugin.cpp @@ -30,32 +30,58 @@ #include "baked_lightmap_editor_plugin.h" -void BakedLightmapEditorPlugin::_bake() { - +void BakedLightmapEditorPlugin::_bake_select_file(const String &p_file) { if (lightmap) { BakedLightmap::BakeError err; if (get_tree()->get_edited_scene_root() && get_tree()->get_edited_scene_root() == lightmap) { - err = lightmap->bake(lightmap); + err = lightmap->bake(lightmap, p_file); } else { - err = lightmap->bake(lightmap->get_parent()); + err = lightmap->bake(lightmap->get_parent(), p_file); } + bake_func_end(); + switch (err) { - case BakedLightmap::BAKE_ERROR_NO_SAVE_PATH: - EditorNode::get_singleton()->show_warning(TTR("Can't determine a save path for lightmap images.\nSave your scene (for images to be saved in the same dir), or pick a save path from the BakedLightmap properties.")); - break; + case BakedLightmap::BAKE_ERROR_NO_SAVE_PATH: { + String scene_path = lightmap->get_filename(); + if (scene_path == String()) { + scene_path = lightmap->get_owner()->get_filename(); + } + if (scene_path == String()) { + EditorNode::get_singleton()->show_warning(TTR("Can't determine a save path for lightmap images.\nSave your scene and try again.")); + break; + } + scene_path = scene_path.get_basename() + ".lmbake"; + + file_dialog->set_current_path(scene_path); + 
file_dialog->popup_centered_ratio(); + + } break; case BakedLightmap::BAKE_ERROR_NO_MESHES: EditorNode::get_singleton()->show_warning(TTR("No meshes to bake. Make sure they contain an UV2 channel and that the 'Bake Light' flag is on.")); break; case BakedLightmap::BAKE_ERROR_CANT_CREATE_IMAGE: EditorNode::get_singleton()->show_warning(TTR("Failed creating lightmap images, make sure path is writable.")); break; + case BakedLightmap::BAKE_ERROR_LIGHTMAP_SIZE: + EditorNode::get_singleton()->show_warning(TTR("Failed determining lightmap size. Maximum lightmap size too small?")); + break; + case BakedLightmap::BAKE_ERROR_INVALID_MESH: + EditorNode::get_singleton()->show_warning(TTR("Some mesh is invalid. Make sure the UV2 channel values are conatined within the [0.0,1.0] square region.")); + break; + case BakedLightmap::BAKE_ERROR_NO_LIGHTMAPPER: + EditorNode::get_singleton()->show_warning(TTR("Godot editor was built without ray tracing support, lightmaps can't be baked.")); + break; default: { } } } } +void BakedLightmapEditorPlugin::_bake() { + _bake_select_file(""); +} + void BakedLightmapEditorPlugin::edit(Object *p_object) { BakedLightmap *s = Object::cast_to(p_object); @@ -81,29 +107,40 @@ void BakedLightmapEditorPlugin::make_visible(bool p_visible) { } EditorProgress *BakedLightmapEditorPlugin::tmp_progress = NULL; +EditorProgress *BakedLightmapEditorPlugin::tmp_subprogress = NULL; -void BakedLightmapEditorPlugin::bake_func_begin(int p_steps) { - - ERR_FAIL_COND(tmp_progress != NULL); - - tmp_progress = memnew(EditorProgress("bake_lightmaps", TTR("Bake Lightmaps"), p_steps, true)); +bool BakedLightmapEditorPlugin::bake_func_step(float p_progress, const String &p_description, void *, bool p_force_refresh) { + if (!tmp_progress) { + tmp_progress = memnew(EditorProgress("bake_lightmaps", TTR("Bake Lightmaps"), 1000, true)); + ERR_FAIL_COND_V(tmp_progress == nullptr, false); + } + return tmp_progress->step(p_description, p_progress * 1000, p_force_refresh); } -bool BakedLightmapEditorPlugin::bake_func_step(int p_step, const String &p_description) { - - ERR_FAIL_COND_V(tmp_progress == NULL, false); - return tmp_progress->step(p_description, p_step, false); +bool BakedLightmapEditorPlugin::bake_func_substep(float p_progress, const String &p_description, void *, bool p_force_refresh) { + if (!tmp_subprogress) { + tmp_subprogress = memnew(EditorProgress("bake_lightmaps_substep", "", 1000, true)); + ERR_FAIL_COND_V(tmp_subprogress == nullptr, false); + } + return tmp_subprogress->step(p_description, p_progress * 1000, p_force_refresh); } void BakedLightmapEditorPlugin::bake_func_end() { - ERR_FAIL_COND(tmp_progress == NULL); - memdelete(tmp_progress); - tmp_progress = NULL; + if (tmp_progress != nullptr) { + memdelete(tmp_progress); + tmp_progress = nullptr; + } + + if (tmp_subprogress != nullptr) { + memdelete(tmp_subprogress); + tmp_subprogress = nullptr; + } } void BakedLightmapEditorPlugin::_bind_methods() { ClassDB::bind_method("_bake", &BakedLightmapEditorPlugin::_bake); + ClassDB::bind_method("_bake_select_file", &BakedLightmapEditorPlugin::_bake_select_file); } BakedLightmapEditorPlugin::BakedLightmapEditorPlugin(EditorNode *p_node) { @@ -114,12 +151,19 @@ BakedLightmapEditorPlugin::BakedLightmapEditorPlugin(EditorNode *p_node) { bake->set_text(TTR("Bake Lightmaps")); bake->hide(); bake->connect("pressed", this, "_bake"); + + file_dialog = memnew(EditorFileDialog); + file_dialog->set_mode(EditorFileDialog::MODE_SAVE_FILE); + file_dialog->add_filter("*.lmbake ; LightMap Bake"); + 
file_dialog->set_title(TTR("Select lightmap bake file:")); + file_dialog->connect("file_selected", this, "_bake_select_file"); + bake->add_child(file_dialog); + add_control_to_container(CONTAINER_SPATIAL_EDITOR_MENU, bake); lightmap = NULL; - BakedLightmap::bake_begin_function = bake_func_begin; BakedLightmap::bake_step_function = bake_func_step; - BakedLightmap::bake_end_function = bake_func_end; + BakedLightmap::bake_substep_function = bake_func_substep; } BakedLightmapEditorPlugin::~BakedLightmapEditorPlugin() { diff --git a/editor/plugins/baked_lightmap_editor_plugin.h b/editor/plugins/baked_lightmap_editor_plugin.h index abd39b5ef446..5a4a6189b43b 100644 --- a/editor/plugins/baked_lightmap_editor_plugin.h +++ b/editor/plugins/baked_lightmap_editor_plugin.h @@ -45,11 +45,15 @@ class BakedLightmapEditorPlugin : public EditorPlugin { ToolButton *bake; EditorNode *editor; + EditorFileDialog *file_dialog; static EditorProgress *tmp_progress; - static void bake_func_begin(int p_steps); - static bool bake_func_step(int p_step, const String &p_description); + static EditorProgress *tmp_subprogress; + + static bool bake_func_step(float p_progress, const String &p_description, void *, bool p_force_refresh); + static bool bake_func_substep(float p_progress, const String &p_description, void *, bool p_force_refresh); static void bake_func_end(); + void _bake_select_file(const String &p_file); void _bake(); protected: diff --git a/editor/progress_dialog.cpp b/editor/progress_dialog.cpp index 7a09380387db..42aebcb2e04a 100644 --- a/editor/progress_dialog.cpp +++ b/editor/progress_dialog.cpp @@ -180,6 +180,7 @@ void ProgressDialog::add_task(const String &p_task, const String &p_label, int p t.progress = memnew(ProgressBar); t.progress->set_max(p_steps); t.progress->set_value(p_steps); + t.last_progress_tick = 0; vb2->add_child(t.progress); t.state = memnew(Label); t.state->set_clip_text(true); @@ -204,20 +205,20 @@ bool ProgressDialog::task_step(const String &p_task, const String &p_state, int ERR_FAIL_COND_V(!tasks.has(p_task), cancelled); + Task &t = tasks[p_task]; if (!p_force_redraw) { uint64_t tus = OS::get_singleton()->get_ticks_usec(); - if (tus - last_progress_tick < 200000) //200ms + if (tus - t.last_progress_tick < 200000) //200ms return cancelled; } - Task &t = tasks[p_task]; if (p_step < 0) t.progress->set_value(t.progress->get_value() + 1); else t.progress->set_value(p_step); t.state->set_text(p_state); - last_progress_tick = OS::get_singleton()->get_ticks_usec(); + t.last_progress_tick = OS::get_singleton()->get_ticks_usec(); if (cancel_hb->is_visible()) { OS::get_singleton()->force_process_input(); } @@ -254,7 +255,6 @@ ProgressDialog::ProgressDialog() { add_child(main); main->set_anchors_and_margins_preset(Control::PRESET_WIDE); set_exclusive(true); - last_progress_tick = 0; singleton = this; cancel_hb = memnew(HBoxContainer); main->add_child(cancel_hb); diff --git a/editor/progress_dialog.h b/editor/progress_dialog.h index bdd4c0ffd1d2..0a100a2a7f2d 100644 --- a/editor/progress_dialog.h +++ b/editor/progress_dialog.h @@ -77,13 +77,13 @@ class ProgressDialog : public Popup { VBoxContainer *vb; ProgressBar *progress; Label *state; + uint64_t last_progress_tick; }; HBoxContainer *cancel_hb; Button *cancel; Map tasks; VBoxContainer *main; - uint64_t last_progress_tick; static ProgressDialog *singleton; void _popup(); diff --git a/modules/denoise/SCsub b/modules/denoise/SCsub new file mode 100644 index 000000000000..bf3bd7d0739e --- /dev/null +++ b/modules/denoise/SCsub @@ -0,0 +1,118 @@ 
+#!/usr/bin/env python + +import resource_to_cpp + +Import("env") +Import("env_modules") + +env_oidn = env_modules.Clone() + +# Thirdparty source files +thirdparty_dir = "#thirdparty/oidn/" +thirdparty_sources = [ + "core/api.cpp", + "core/device.cpp", + "core/filter.cpp", + "core/network.cpp", + "core/autoencoder.cpp", + "core/transfer_function.cpp", + "weights/rtlightmap_hdr.gen.cpp", + "mkl-dnn/src/common/batch_normalization.cpp", + "mkl-dnn/src/common/concat.cpp", + "mkl-dnn/src/common/convolution.cpp", + "mkl-dnn/src/common/convolution_pd.cpp", + "mkl-dnn/src/common/deconvolution.cpp", + "mkl-dnn/src/common/eltwise.cpp", + "mkl-dnn/src/common/engine.cpp", + "mkl-dnn/src/common/inner_product.cpp", + "mkl-dnn/src/common/inner_product_pd.cpp", + "mkl-dnn/src/common/lrn.cpp", + "mkl-dnn/src/common/memory.cpp", + "mkl-dnn/src/common/memory_desc_wrapper.cpp", + "mkl-dnn/src/common/mkldnn_debug.cpp", + "mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp", + "mkl-dnn/src/common/pooling.cpp", + "mkl-dnn/src/common/primitive.cpp", + "mkl-dnn/src/common/primitive_attr.cpp", + "mkl-dnn/src/common/primitive_desc.cpp", + "mkl-dnn/src/common/primitive_exec_types.cpp", + "mkl-dnn/src/common/primitive_iterator.cpp", + "mkl-dnn/src/common/query.cpp", + "mkl-dnn/src/common/reorder.cpp", + "mkl-dnn/src/common/rnn.cpp", + "mkl-dnn/src/common/scratchpad.cpp", + "mkl-dnn/src/common/shuffle.cpp", + "mkl-dnn/src/common/softmax.cpp", + "mkl-dnn/src/common/stream.cpp", + "mkl-dnn/src/common/sum.cpp", + "mkl-dnn/src/common/utils.cpp", + "mkl-dnn/src/common/verbose.cpp", + "mkl-dnn/src/cpu/cpu_barrier.cpp", + "mkl-dnn/src/cpu/cpu_concat.cpp", + "mkl-dnn/src/cpu/cpu_engine.cpp", + "mkl-dnn/src/cpu/cpu_memory.cpp", + "mkl-dnn/src/cpu/cpu_reducer.cpp", + "mkl-dnn/src/cpu/cpu_reorder.cpp", + "mkl-dnn/src/cpu/cpu_sum.cpp", + "mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_avx2_convolution.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp", + "mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_sse42_convolution.cpp", + "mkl-dnn/src/cpu/jit_transpose_src_utils.cpp", + "mkl-dnn/src/cpu/jit_uni_eltwise.cpp", + "mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_uni_pooling.cpp", + "mkl-dnn/src/cpu/jit_uni_reorder.cpp", + "mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp", + "mkl-dnn/src/cpu/jit_utils/jit_utils.cpp", + "mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c", + "common/platform.cpp", + "common/thread.cpp", + "common/tensor.cpp", +] +thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources] + +thirdparty_include_dirs = [ + "", + "include", + "mkl-dnn/include", + "mkl-dnn/src", + "mkl-dnn/src/common", + "mkl-dnn/src/cpu/xbyak", + "mkl-dnn/src/cpu", +] +thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs] + + +env_oidn.Prepend(CPPPATH=thirdparty_include_dirs) +env_oidn.Append( + CPPDEFINES=[ + "MKLDNN_THR=MKLDNN_THR_SEQ", + "OIDN_STATIC_LIB", + "__STDC_CONSTANT_MACROS", + "__STDC_LIMIT_MACROS", + "DISABLE_VERBOSE", + "MKLDNN_ENABLE_CONCURRENT_EXEC", + "NDEBUG", + ] +) + +env_thirdparty = env_oidn.Clone() 
+env_thirdparty.disable_warnings() +env_thirdparty.add_source_files(env.modules_sources, thirdparty_sources) + +weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza" +weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp" + +env_thirdparty.Depends(weights_out_path, weights_in_path) +env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp) + +env_oidn.add_source_files(env.modules_sources, "denoise_wrapper.cpp") +env_modules.add_source_files(env.modules_sources, ["register_types.cpp", "lightmap_denoiser.cpp"]) diff --git a/modules/denoise/config.py b/modules/denoise/config.py new file mode 100644 index 000000000000..211b201fec7c --- /dev/null +++ b/modules/denoise/config.py @@ -0,0 +1,15 @@ +def can_build(env, platform): + # Thirdparty dependency OpenImage Denoise includes oneDNN library + # which only supports 64-bit architectures. + # It's also only relevant for tools build and desktop platforms, + # as doing lightmap generation and denoising on Android or HTML5 + # would be a bit far-fetched. + # Note: oneDNN doesn't support ARM64, OIDN needs updating to the latest version + supported_platform = platform in ["x11", "osx", "windows", "server"] + supported_bits = env["bits"] == "64" + supported_arch = env["arch"] != "arm64" + return env["tools"] and supported_platform and supported_bits and supported_arch + + +def configure(env): + pass diff --git a/modules/denoise/denoise_wrapper.cpp b/modules/denoise/denoise_wrapper.cpp new file mode 100644 index 000000000000..0ef33e7ba522 --- /dev/null +++ b/modules/denoise/denoise_wrapper.cpp @@ -0,0 +1,67 @@ +/*************************************************************************/ +/* denoise_wrapper.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#include "denoise_wrapper.h" +#include "core/os/copymem.h" +#include "core/os/memory.h" +#include "thirdparty/oidn/include/OpenImageDenoise/oidn.h" +#include + +void *oidn_denoiser_init() { + OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU); + oidnCommitDevice(device); + return device; +} + +bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) { + OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr; + OIDNFilter filter = oidnNewFilter(device, "RTLightmap"); + void *input_buffer = memalloc(p_width * p_height * 3 * sizeof(float)); + copymem(input_buffer, p_floats, p_width * p_height * 3 * sizeof(float)); + oidnSetSharedFilterImage(filter, "color", input_buffer, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); + oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); + oidnSetFilter1b(filter, "hdr", true); + oidnCommitFilter(filter); + oidnExecuteFilter(filter); + + const char *msg; + bool success = true; + if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) { + printf("LightmapDenoiser: %s\n", msg); + success = false; + } + + oidnReleaseFilter(filter); + return success; +} + +void oidn_denoiser_finish(void *device) { + oidnReleaseDevice((OIDNDeviceImpl *)device); +} diff --git a/modules/denoise/denoise_wrapper.h b/modules/denoise/denoise_wrapper.h new file mode 100644 index 000000000000..25e342bc93f1 --- /dev/null +++ b/modules/denoise/denoise_wrapper.h @@ -0,0 +1,38 @@ +/*************************************************************************/ +/* denoise_wrapper.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#ifndef DENOISE_WRAPPER_H +#define DENOISE_WRAPPER_H + +void *oidn_denoiser_init(); +bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height); +void oidn_denoiser_finish(void *device); + +#endif // DENOISE_WRAPPER_H diff --git a/modules/denoise/lightmap_denoiser.cpp b/modules/denoise/lightmap_denoiser.cpp new file mode 100644 index 000000000000..6023ebf05a2d --- /dev/null +++ b/modules/denoise/lightmap_denoiser.cpp @@ -0,0 +1,66 @@ +/*************************************************************************/ +/* lightmap_denoiser.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#include "lightmap_denoiser.h" + +#include "denoise_wrapper.h" + +LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() { + return memnew(LightmapDenoiserOIDN); +} + +void LightmapDenoiserOIDN::make_default_denoiser() { + create_function = create_oidn_denoiser; +} + +Ref LightmapDenoiserOIDN::denoise_image(const Ref &p_image) { + Ref img = p_image->duplicate(); + + img->convert(Image::FORMAT_RGBF); + + PoolByteArray data = img->get_data(); + { + PoolByteArray::Write w = data.write(); + if (!oidn_denoise(device, (float *)w.ptr(), img->get_width(), img->get_height())) { + return p_image; + } + } + + img->create(img->get_width(), img->get_height(), false, img->get_format(), data); + return img; +} + +LightmapDenoiserOIDN::LightmapDenoiserOIDN() { + device = oidn_denoiser_init(); +} + +LightmapDenoiserOIDN::~LightmapDenoiserOIDN() { + oidn_denoiser_finish(device); +} diff --git a/modules/denoise/lightmap_denoiser.h b/modules/denoise/lightmap_denoiser.h new file mode 100644 index 000000000000..21d5ca3f328d --- /dev/null +++ b/modules/denoise/lightmap_denoiser.h @@ -0,0 +1,56 @@ +/*************************************************************************/ +/* lightmap_denoiser.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#ifndef LIGHTMAP_DENOISER_H +#define LIGHTMAP_DENOISER_H + +#include "core/class_db.h" +#include "scene/3d/lightmapper.h" + +struct OIDNDeviceImpl; + +class LightmapDenoiserOIDN : public LightmapDenoiser { + GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser); + +protected: + void *device = nullptr; + +public: + static LightmapDenoiser *create_oidn_denoiser(); + + Ref denoise_image(const Ref &p_image); + + static void make_default_denoiser(); + + LightmapDenoiserOIDN(); + ~LightmapDenoiserOIDN(); +}; + +#endif // LIGHTMAP_DENOISER_H diff --git a/modules/denoise/register_types.cpp b/modules/denoise/register_types.cpp new file mode 100644 index 000000000000..5fe835a365ec --- /dev/null +++ b/modules/denoise/register_types.cpp @@ -0,0 +1,41 @@ +/*************************************************************************/ +/* register_types.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "register_types.h" + +#include "core/engine.h" +#include "lightmap_denoiser.h" + +void register_denoise_types() { + LightmapDenoiserOIDN::make_default_denoiser(); +} + +void unregister_denoise_types() { +} diff --git a/modules/denoise/register_types.h b/modules/denoise/register_types.h new file mode 100644 index 000000000000..516a91b134a5 --- /dev/null +++ b/modules/denoise/register_types.h @@ -0,0 +1,37 @@ +/*************************************************************************/ +/* register_types.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). 
*/ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef DENOISE_REGISTER_TYPES_H +#define DENOISE_REGISTER_TYPES_H + +void register_denoise_types(); +void unregister_denoise_types(); + +#endif // DENOISE_REGISTER_TYPES_H diff --git a/modules/denoise/resource_to_cpp.py b/modules/denoise/resource_to_cpp.py new file mode 100644 index 000000000000..6c8327735599 --- /dev/null +++ b/modules/denoise/resource_to_cpp.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +## ======================================================================== ## +## Copyright 2009-2019 Intel Corporation ## +## ## +## Licensed under the Apache License, Version 2.0 (the "License"); ## +## you may not use this file except in compliance with the License. ## +## You may obtain a copy of the License at ## +## ## +## http://www.apache.org/licenses/LICENSE-2.0 ## +## ## +## Unless required by applicable law or agreed to in writing, software ## +## distributed under the License is distributed on an "AS IS" BASIS, ## +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ## +## See the License for the specific language governing permissions and ## +## limitations under the License. 
## +## ======================================================================== ## + +import os +from array import array + +# Generates a C++ file from the specified binary resource file +def generate(in_path, out_path): + + namespace = "oidn::weights" + scopes = namespace.split("::") + + file_name = os.path.basename(in_path) + var_name = os.path.splitext(file_name)[0] + + with open(in_path, "rb") as in_file, open(out_path, "w") as out_file: + # Header + out_file.write("// Generated from: %s\n" % file_name) + out_file.write("#include \n\n") + + # Open the namespaces + for s in scopes: + out_file.write("namespace %s {\n" % s) + if scopes: + out_file.write("\n") + + # Read the file + in_data = array("B", in_file.read()) + + # Write the size + out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data))) + + # Write the data + out_file.write("unsigned char %s[] = {" % var_name) + for i in range(len(in_data)): + c = in_data[i] + if i > 0: + out_file.write(",") + if (i + 1) % 20 == 1: + out_file.write("\n") + out_file.write("%d" % c) + out_file.write("\n};\n") + + # Close the namespaces + if scopes: + out_file.write("\n") + for scope in reversed(scopes): + out_file.write("} // namespace %s\n" % scope) + + +def tza_to_cpp(target, source, env): + for x in zip(source, target): + generate(str(x[0]), str(x[1])) diff --git a/modules/lightmapper_cpu/SCsub b/modules/lightmapper_cpu/SCsub new file mode 100644 index 000000000000..e27beafe3dfc --- /dev/null +++ b/modules/lightmapper_cpu/SCsub @@ -0,0 +1,9 @@ +#!/usr/bin/env python + +Import("env") +Import("env_modules") + +env_lightmapper_rd = env_modules.Clone() +# Godot source files +env_lightmapper_rd.Prepend(CPPPATH=["#thirdparty/embree/include"]) +env_lightmapper_rd.add_source_files(env.modules_sources, "*.cpp") diff --git a/modules/lightmapper_cpu/config.py b/modules/lightmapper_cpu/config.py new file mode 100644 index 000000000000..d01c1726dd34 --- /dev/null +++ b/modules/lightmapper_cpu/config.py @@ -0,0 +1,6 @@ +def can_build(env, platform): + return env["tools"] and env["module_raycast_enabled"] + + +def configure(env): + pass diff --git a/modules/lightmapper_cpu/lightmapper_cpu.cpp b/modules/lightmapper_cpu/lightmapper_cpu.cpp new file mode 100644 index 000000000000..3cbdc7dc9efa --- /dev/null +++ b/modules/lightmapper_cpu/lightmapper_cpu.cpp @@ -0,0 +1,1622 @@ +/*************************************************************************/ +/* lightmapper_cpu.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. 
*/ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "lightmapper_cpu.h" + +#include "core/math/geometry.h" +#include "core/os/os.h" +#include "core/os/threaded_array_processor.h" +#include "core/project_settings.h" +#include "modules/raycast/lightmap_raycaster.h" + +Error LightmapperCPU::_layout_atlas(int p_max_size, Vector2i *r_atlas_size, int *r_atlas_slices) { + + Vector2i atlas_size; + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (mesh_instances[i].generate_lightmap) { + Vector2i size = mesh_instances[i].size; + atlas_size.width = MAX(atlas_size.width, size.width + 2); + atlas_size.height = MAX(atlas_size.height, size.height + 2); + } + } + + int max = nearest_power_of_2_templated(atlas_size.width); + max = MAX(max, nearest_power_of_2_templated(atlas_size.height)); + + if (max > p_max_size) { + return ERR_INVALID_DATA; + } + + Vector2i best_atlas_size; + int best_atlas_slices = 0; + int best_atlas_memory = 0x7FFFFFFF; + float best_atlas_mem_utilization = 0; + Vector best_atlas_offsets; + Vector best_scaled_sizes; + + int first_try_mem_occupied = 0; + int first_try_mem_used = 0; + for (int recovery_percent = 0; recovery_percent <= 100; recovery_percent += 10) { + // These only make sense from the second round of the loop + float recovery_scale = 1; + int target_mem_occupied = 0; + if (recovery_percent != 0) { + target_mem_occupied = first_try_mem_occupied + (first_try_mem_used - first_try_mem_occupied) * recovery_percent * 0.01f; + float new_squared_recovery_scale = static_cast(target_mem_occupied) / first_try_mem_occupied; + if (new_squared_recovery_scale > 1.0f) { + recovery_scale = Math::sqrt(new_squared_recovery_scale); + } + } + + atlas_size = Vector2i(max, max); + while (atlas_size.x <= p_max_size && atlas_size.y <= p_max_size) { + + if (recovery_percent != 0) { + // Find out how much memory is not recoverable (because of lightmaps that can't grow), + // to compute a greater recovery scale for those that can. 
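+ // Worked example (illustrative only, hypothetical numbers): if the first packing attempt
+ // allocated 1048576 texels of atlas area (first_try_mem_used) but lightmaps only covered
+ // 614400 of them (first_try_mem_occupied), a 50% recovery pass targets
+ // 614400 + (1048576 - 614400) * 0.5 = 831488 occupied texels. Lightmap sizes scale in two
+ // dimensions, so the per-axis factor is the square root of the area ratio:
+ // recovery_scale = sqrt(831488 / 614400) ~= 1.16. Any lightmap that would no longer fit in
+ // the atlas after scaling counts as unrecoverable and is excluded from that ratio below.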
+ int mem_unrecoverable = 0; + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (mesh_instances[i].generate_lightmap) { + Vector2i scaled_size = Vector2i( + static_cast(recovery_scale * mesh_instances[i].size.x), + static_cast(recovery_scale * mesh_instances[i].size.y)); + if (scaled_size.x + 2 > atlas_size.x || scaled_size.y + 2 > atlas_size.y) { + mem_unrecoverable += scaled_size.x * scaled_size.y - mesh_instances[i].size.x * mesh_instances[i].size.y; + } + } + } + float new_squared_recovery_scale = static_cast(target_mem_occupied - mem_unrecoverable) / (first_try_mem_occupied - mem_unrecoverable); + if (new_squared_recovery_scale > 1.0f) { + recovery_scale = Math::sqrt(new_squared_recovery_scale); + } + } + + Vector scaled_sizes; + scaled_sizes.resize(mesh_instances.size()); + { + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (mesh_instances[i].generate_lightmap) { + if (recovery_percent == 0) { + scaled_sizes.write[i] = mesh_instances[i].size; + } else { + Vector2i scaled_size = Vector2i( + static_cast(recovery_scale * mesh_instances[i].size.x), + static_cast(recovery_scale * mesh_instances[i].size.y)); + if (scaled_size.x + 2 <= atlas_size.x && scaled_size.y + 2 <= atlas_size.y) { + scaled_sizes.write[i] = scaled_size; + } else { + scaled_sizes.write[i] = mesh_instances[i].size; + } + } + } else { + // Don't consider meshes with no generated lightmap here; will compensate later + scaled_sizes.write[i] = Vector2i(); + } + } + } + + Vector source_sizes; + source_sizes.resize(scaled_sizes.size()); + Vector source_indices; + source_indices.resize(scaled_sizes.size()); + for (int i = 0; i < source_sizes.size(); i++) { + source_sizes.write[i] = scaled_sizes[i] + Vector2i(2, 2); // Add padding between lightmaps + source_indices.write[i] = i; + } + + Vector curr_atlas_offsets; + curr_atlas_offsets.resize(source_sizes.size()); + + int slices = 0; + + while (source_sizes.size() > 0) { + + Vector offsets = Geometry::partial_pack_rects(source_sizes, atlas_size); + Vector new_indices; + Vector new_sources; + for (int i = 0; i < offsets.size(); i++) { + Geometry::PackRectsResult ofs = offsets[i]; + int sidx = source_indices[i]; + if (ofs.packed) { + curr_atlas_offsets.write[sidx] = { slices, ofs.x + 1, ofs.y + 1 }; + } else { + new_indices.push_back(sidx); + new_sources.push_back(source_sizes[i]); + } + } + + source_sizes = new_sources; + source_indices = new_indices; + slices++; + } + + int mem_used = atlas_size.x * atlas_size.y * slices; + int mem_occupied = 0; + for (int i = 0; i < curr_atlas_offsets.size(); i++) { + mem_occupied += scaled_sizes[i].x * scaled_sizes[i].y; + } + + float mem_utilization = static_cast(mem_occupied) / mem_used; + if (slices * atlas_size.y < 16384) { // Maximum Image size + if (mem_used < best_atlas_memory || (mem_used == best_atlas_memory && mem_utilization > best_atlas_mem_utilization)) { + best_atlas_size = atlas_size; + best_atlas_offsets = curr_atlas_offsets; + best_atlas_slices = slices; + best_atlas_memory = mem_used; + best_atlas_mem_utilization = mem_utilization; + best_scaled_sizes = scaled_sizes; + } + } + + if (recovery_percent == 0) { + first_try_mem_occupied = mem_occupied; + first_try_mem_used = mem_used; + } + + if (atlas_size.width == atlas_size.height) { + atlas_size.width *= 2; + } else { + atlas_size.height *= 2; + } + } + } + + if (best_atlas_size == Vector2i()) { + return ERR_INVALID_DATA; + } + + *r_atlas_size = best_atlas_size; + *r_atlas_slices = best_atlas_slices; + + for (unsigned int i = 0; i < 
mesh_instances.size(); i++) { + if (best_scaled_sizes[i] != Vector2i()) { + mesh_instances[i].size = best_scaled_sizes[i]; + mesh_instances[i].offset = Vector2i(best_atlas_offsets[i].x, best_atlas_offsets[i].y); + mesh_instances[i].slice = best_atlas_offsets[i].slice; + } + } + return OK; +} + +void LightmapperCPU::_thread_func_callback(void *p_thread_data) { + ThreadData *thread_data = reinterpret_cast(p_thread_data); + thread_process_array(thread_data->count, thread_data->instance, &LightmapperCPU::_thread_func_wrapper, thread_data); +} + +void LightmapperCPU::_thread_func_wrapper(uint32_t p_idx, ThreadData *p_thread_data) { + + if (thread_cancelled) { + return; + } + + (p_thread_data->instance->*p_thread_data->thread_func)(p_idx, p_thread_data->userdata); + + thread_progress++; +} + +bool LightmapperCPU::_parallel_run(int p_count, const String &p_description, BakeThreadFunc p_thread_func, void *p_userdata, BakeStepFunc p_substep_func) { + + bool cancelled = false; + if (p_substep_func) { + cancelled = p_substep_func(0.0f, vformat("%s (%d/%d)", p_description, 0, p_count), nullptr, false); + } + + thread_progress = 0; + thread_cancelled = false; + +#ifdef NO_THREAD + for (int i = 0; !cancelled && i < p_count; i++) { + (this->*p_thread_func)(i, p_userdata); + float p = float(i) / p_count; + if (p_substep_func) { + cancelled = p_substep_func(p, vformat("%s (%d/%d)", p_description, i + 1, p_count), nullptr, false); + } + } +#else + + if (p_count == 0) { + return cancelled; + } + + ThreadData td; + td.instance = this; + td.count = p_count; + td.thread_func = p_thread_func; + td.userdata = p_userdata; + Thread *runner_thread = Thread::create(_thread_func_callback, &td); + + int progress = thread_progress; + + while (!cancelled && progress < p_count) { + float p = float(progress) / p_count; + if (p_substep_func) { + cancelled = p_substep_func(p, vformat("%s (%d/%d)", p_description, progress + 1, p_count), nullptr, false); + } + progress = thread_progress; + } + thread_cancelled = cancelled; + Thread::wait_to_finish(runner_thread); +#endif + + thread_cancelled = false; + + return cancelled; +} + +void LightmapperCPU::_generate_buffer(uint32_t p_idx, void *p_unused) { + + const Size2i &size = mesh_instances[p_idx].size; + + int buffer_size = size.x * size.y; + + LocalVector &lightmap = scene_lightmaps[p_idx]; + LocalVector &lightmap_indices = scene_lightmap_indices[p_idx]; + + lightmap_indices.resize(buffer_size); + + for (unsigned int i = 0; i < lightmap_indices.size(); i++) { + lightmap_indices[i] = -1; + } + + MeshData &md = mesh_instances[p_idx].data; + + LocalVector > albedo_images; + LocalVector > emission_images; + + for (int surface_id = 0; surface_id < md.albedo.size(); surface_id++) { + albedo_images.push_back(_init_bake_texture(md.albedo[surface_id], albedo_textures, Image::FORMAT_RGBA8)); + emission_images.push_back(_init_bake_texture(md.emission[surface_id], emission_textures, Image::FORMAT_RGBH)); + } + + int surface_id = 0; + int surface_facecount = 0; + const Vector3 *points_ptr = md.points.ptr(); + const Vector3 *normals_ptr = md.normal.ptr(); + const Vector2 *uvs_ptr = md.uv.empty() ? nullptr : md.uv.ptr(); + const Vector2 *uv2s_ptr = md.uv2.ptr(); + + for (int i = 0; i < md.points.size() / 3; i++) { + + Ref albedo = albedo_images[surface_id]; + Ref emission = emission_images[surface_id]; + + albedo->lock(); + emission->lock(); + _plot_triangle(&(uv2s_ptr[i * 3]), &(points_ptr[i * 3]), &(normals_ptr[i * 3]), uvs_ptr ? 
&(uvs_ptr[i * 3]) : nullptr, albedo, emission, size, lightmap, lightmap_indices); + albedo->unlock(); + emission->unlock(); + + surface_facecount++; + if (surface_facecount == md.surface_facecounts[surface_id]) { + surface_id++; + surface_facecount = 0; + } + } +} + +Ref LightmapperCPU::_init_bake_texture(const MeshData::TextureDef &p_texture_def, const Map > &p_tex_cache, Image::Format p_default_format) { + Ref ret; + if (p_texture_def.tex_rid.is_valid()) { + ret = p_tex_cache[p_texture_def.tex_rid]->duplicate(); + ret->lock(); + for (int j = 0; j < ret->get_height(); j++) { + for (int i = 0; i < ret->get_width(); i++) { + ret->set_pixel(i, j, ret->get_pixel(i, j) * p_texture_def.mul + p_texture_def.add); + } + } + ret->unlock(); + } else { + ret.instance(); + ret->create(8, 8, false, p_default_format); + ret->fill(p_texture_def.add * p_texture_def.mul); + } + return ret; +} + +Color LightmapperCPU::_bilinear_sample(const Ref &p_img, const Vector2 &p_uv, bool p_clamp_x, bool p_clamp_y) { + + int width = p_img->get_width(); + int height = p_img->get_height(); + + Vector2 uv; + uv.x = p_clamp_x ? p_uv.x : Math::fposmod(p_uv.x, 1.0f); + uv.y = p_clamp_y ? p_uv.y : Math::fposmod(p_uv.y, 1.0f); + + float xf = uv.x * width; + float yf = uv.y * height; + + int xi = (int)xf; + int yi = (int)yf; + + Color texels[4]; + for (int i = 0; i < 4; i++) { + int sample_x = xi + i % 2; + int sample_y = yi + i / 2; + + sample_x = CLAMP(sample_x, 0, width - 1); + sample_y = CLAMP(sample_y, 0, height - 1); + + texels[i] = p_img->get_pixel(sample_x, sample_y); + } + + float tx = xf - xi; + float ty = yf - yi; + + Color c = Color(0, 0, 0, 0); + for (int i = 0; i < 4; i++) { + c[i] = Math::lerp(Math::lerp(texels[0][i], texels[1][i], tx), Math::lerp(texels[2][i], texels[3][i], tx), ty); + } + return c; +} + +Vector3 LightmapperCPU::_fix_sample_position(const Vector3 &p_position, const Vector3 &p_texel_center, const Vector3 &p_normal, const Vector3 &p_tangent, const Vector3 &p_bitangent, const Vector2 &p_texel_size) { + + Basis tangent_basis(p_tangent, p_bitangent, p_normal); + tangent_basis.orthonormalize(); + Vector2 half_size = p_texel_size / 2.0f; + Vector3 corrected = p_position; + + for (int i = -1; i <= 1; i += 1) { + for (int j = -1; j <= 1; j += 1) { + if (i == 0 && j == 0) continue; + Vector3 offset = Vector3(half_size.x * i, half_size.y * j, 0.0); + Vector3 rotated_offset = tangent_basis.xform_inv(offset); + Vector3 target = p_texel_center + rotated_offset; + Vector3 ray_vector = target - corrected; + + Vector3 ray_back_offset = -ray_vector.normalized() * parameters.bias; + Vector3 ray_origin = corrected + ray_back_offset; + ray_vector = target - ray_origin; + float ray_length = ray_vector.length(); + LightmapRaycaster::Ray ray(ray_origin + p_normal * parameters.bias, ray_vector.normalized(), 0.0f, ray_length + parameters.bias); + + bool hit = raycaster->intersect(ray); + if (hit) { + ray.normal.normalize(); + if (ray.normal.dot(ray_vector.normalized()) > 0.0f) { + corrected = ray_origin + ray.dir * ray.tfar + ray.normal * (parameters.bias * 2.0f); + } + } + } + } + + return corrected; +} + +void LightmapperCPU::_plot_triangle(const Vector2 *p_vertices, const Vector3 *p_positions, const Vector3 *p_normals, const Vector2 *p_uvs, const Ref &p_albedo, const Ref &p_emission, Vector2i p_size, LocalVector &r_lightmap, LocalVector &r_lightmap_indices) { + Vector2 pv0 = p_vertices[0]; + Vector2 pv1 = p_vertices[1]; + Vector2 pv2 = p_vertices[2]; + + Vector2 v0 = pv0 * p_size; + Vector2 v1 = pv1 * p_size; + 
Vector2 v2 = pv2 * p_size; + + Vector3 p0 = p_positions[0]; + Vector3 p1 = p_positions[1]; + Vector3 p2 = p_positions[2]; + + Vector3 n0 = p_normals[0]; + Vector3 n1 = p_normals[1]; + Vector3 n2 = p_normals[2]; + + Vector2 uv0 = p_uvs == nullptr ? Vector2(0.5f, 0.5f) : p_uvs[0]; + Vector2 uv1 = p_uvs == nullptr ? Vector2(0.5f, 0.5f) : p_uvs[1]; + Vector2 uv2 = p_uvs == nullptr ? Vector2(0.5f, 0.5f) : p_uvs[2]; + +#define edgeFunction(a, b, c) ((c)[0] - (a)[0]) * ((b)[1] - (a)[1]) - ((c)[1] - (a)[1]) * ((b)[0] - (a)[0]) + + if (edgeFunction(v0, v1, v2) < 0.0) { + SWAP(pv1, pv2); + SWAP(v1, v2); + SWAP(p1, p2); + SWAP(n1, n2); + SWAP(uv1, uv2); + } + + Vector3 edge1 = p1 - p0; + Vector3 edge2 = p2 - p0; + + Vector2 uv_edge1 = pv1 - pv0; + Vector2 uv_edge2 = pv2 - pv0; + + float r = 1.0f / (uv_edge1.x * uv_edge2.y - uv_edge1.y * uv_edge2.x); + + Vector3 tangent = (edge1 * uv_edge2.y - edge2 * uv_edge1.y) * r; + Vector3 bitangent = (edge2 * uv_edge1.x - edge1 * uv_edge2.x) * r; + + tangent.normalize(); + bitangent.normalize(); + + // Compute triangle bounding box + Vector2 bbox_min = Vector2(MIN(v0.x, MIN(v1.x, v2.x)), MIN(v0.y, MIN(v1.y, v2.y))); + Vector2 bbox_max = Vector2(MAX(v0.x, MAX(v1.x, v2.x)), MAX(v0.y, MAX(v1.y, v2.y))); + + bbox_min = bbox_min.floor(); + bbox_max = bbox_max.ceil(); + + uint32_t min_x = MAX(bbox_min.x - 2, 0); + uint32_t min_y = MAX(bbox_min.y - 2, 0); + uint32_t max_x = MIN(bbox_max.x, p_size.x - 1); + uint32_t max_y = MIN(bbox_max.y, p_size.y - 1); + + Vector2 texel_size; + Vector2 centroid = (v0 + v1 + v2) / 3.0f; + Vector3 centroid_pos = (p0 + p1 + p2) / 3.0f; + for (int i = 0; i < 2; i++) { + Vector2 p = centroid; + p[i] += 1; + Vector3 bary = Geometry::barycentric_coordinates_2d(p, v0, v1, v2); + Vector3 pos = p0 * bary[0] + p1 * bary[1] + p2 * bary[2]; + texel_size[i] = centroid_pos.distance_to(pos); + } + + Vector pixel_polygon; + pixel_polygon.resize(4); + static const Vector2 corners[4] = { Vector2(0, 0), Vector2(0, 1), Vector2(1, 1), Vector2(1, 0) }; + + Vector triangle_polygon; + triangle_polygon.push_back(v0); + triangle_polygon.push_back(v1); + triangle_polygon.push_back(v2); + + for (uint32_t j = min_y; j <= max_y; ++j) { + for (uint32_t i = min_x; i <= max_x; i++) { + + int ofs = j * p_size.x + i; + int texel_idx = r_lightmap_indices[ofs]; + + if (texel_idx >= 0 && r_lightmap[texel_idx].area_coverage >= 0.5f) { + continue; + } + + Vector3 barycentric_coords; + float area_coverage = 0.0f; + bool intersected = false; + + for (int k = 0; k < 4; k++) { + pixel_polygon.write[k] = Vector2(i, j) + corners[k]; + } + + const float max_dist = 0.05; + bool v0eqv1 = v0.distance_squared_to(v1) < max_dist; + bool v1eqv2 = v1.distance_squared_to(v2) < max_dist; + bool v2eqv0 = v2.distance_squared_to(v0) < max_dist; + if (v0eqv1 && v1eqv2 && v2eqv0) { + intersected = true; + barycentric_coords = Vector3(1, 0, 0); + } else if (v0eqv1 || v1eqv2 || v2eqv0) { + + Vector segment; + segment.resize(2); + if (v0eqv1) { + segment.write[0] = v0; + segment.write[1] = v2; + } else if (v1eqv2) { + segment.write[0] = v1; + segment.write[1] = v0; + } else { + segment.write[0] = v0; + segment.write[1] = v1; + } + + Vector > intersected_segments = Geometry::intersect_polyline_with_polygon_2d(segment, pixel_polygon); + ERR_FAIL_COND_MSG(intersected_segments.size() > 1, "[Lightmapper] Itersecting a segment and a convex polygon should give at most one segment."); + if (!intersected_segments.empty()) { + const Vector &intersected_segment = intersected_segments[0]; + 
ERR_FAIL_COND_MSG(intersected_segment.size() != 2, "[Lightmapper] Itersecting a segment and a convex polygon should give at most one segment."); + Vector2 sample_pos = (intersected_segment[0] + intersected_segment[1]) / 2.0f; + + float u = (segment[0].distance_to(sample_pos)) / (segment[0].distance_to(segment[1])); + float v = (1.0f - u) / 2.0f; + intersected = true; + if (v0eqv1) { + barycentric_coords = Vector3(v, v, u); + } else if (v1eqv2) { + barycentric_coords = Vector3(u, v, v); + } else { + barycentric_coords = Vector3(v, u, v); + } + } + + } else if (edgeFunction(v0, v1, v2) < 0.005) { + Vector2 direction = v0 - v1; + Vector2 perpendicular = Vector2(direction.y, -direction.x); + + Vector line; + int middle_vertex; + + if (SGN(edgeFunction(v0, v0 + perpendicular, v1)) != SGN(edgeFunction(v0, v0 + perpendicular, v2))) { + line.push_back(v1); + line.push_back(v2); + middle_vertex = 0; + } else if (SGN(edgeFunction(v1, v1 + perpendicular, v0)) != SGN(edgeFunction(v1, v1 + perpendicular, v2))) { + line.push_back(v0); + line.push_back(v2); + middle_vertex = 1; + } else { + line.push_back(v0); + line.push_back(v1); + middle_vertex = 2; + } + + Vector > intersected_lines = Geometry::intersect_polyline_with_polygon_2d(line, pixel_polygon); + + ERR_FAIL_COND_MSG(intersected_lines.size() > 1, "[Lightmapper] Itersecting a line and a convex polygon should give at most one line."); + + if (!intersected_lines.empty()) { + intersected = true; + const Vector &intersected_line = intersected_lines[0]; + Vector2 sample_pos = (intersected_line[0] + intersected_line[1]) / 2.0f; + + float line_length = line[0].distance_to(line[1]); + float norm = line[0].distance_to(sample_pos) / line_length; + + if (middle_vertex == 0) { + barycentric_coords = Vector3(0.0f, 1.0f - norm, norm); + } else if (middle_vertex == 1) { + barycentric_coords = Vector3(1.0f - norm, 0.0f, norm); + } else { + barycentric_coords = Vector3(1.0f - norm, norm, 0.0f); + } + } + } else { + + Vector > intersected_polygons = Geometry::intersect_polygons_2d(pixel_polygon, triangle_polygon); + + ERR_FAIL_COND_MSG(intersected_polygons.size() > 1, "[Lightmapper] Itersecting two convex polygons should give at most one polygon."); + + if (!intersected_polygons.empty()) { + const Vector &intersected_polygon = intersected_polygons[0]; + + // do centroid sampling + Vector2 sample_pos = intersected_polygon[0]; + Vector2 area_center = Vector2(i, j) + Vector2(0.5f, 0.5f); + float intersected_area = (intersected_polygon[0] - area_center).cross(intersected_polygon[intersected_polygon.size() - 1] - area_center); + for (int k = 1; k < intersected_polygon.size(); k++) { + sample_pos += intersected_polygon[k]; + intersected_area += (intersected_polygon[k] - area_center).cross(intersected_polygon[k - 1] - area_center); + } + + if (intersected_area != 0.0f) { + sample_pos /= intersected_polygon.size(); + barycentric_coords = Geometry::barycentric_coordinates_2d(sample_pos, v0, v1, v2); + intersected = true; + area_coverage = ABS(intersected_area) / 2.0f; + } + } + + if (!intersected) { + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + Vector2 intersection_point; + if (Geometry::segment_intersects_segment_2d(pixel_polygon[k], pixel_polygon[(k + 1) % 4], triangle_polygon[l], triangle_polygon[(l + 1) % 3], &intersection_point)) { + intersected = true; + barycentric_coords = Geometry::barycentric_coordinates_2d(intersection_point, v0, v1, v2); + break; + } + } + if (intersected) { + break; + } + } + } + } + + if (texel_idx >= 0 && area_coverage < 
r_lightmap[texel_idx].area_coverage) { + continue; // A previous triangle gives better pixel coverage + } + + Vector2 pixel = Vector2(i, j); + if (!intersected && v0.floor() == pixel) { + intersected = true; + barycentric_coords = Vector3(1, 0, 0); + } + + if (!intersected && v1.floor() == pixel) { + intersected = true; + barycentric_coords = Vector3(0, 1, 0); + } + + if (!intersected && v2.floor() == pixel) { + intersected = true; + barycentric_coords = Vector3(0, 0, 1); + } + + if (!intersected) { + continue; + } + + if (Math::is_nan(barycentric_coords.x) || Math::is_nan(barycentric_coords.y) || Math::is_nan(barycentric_coords.z)) { + continue; + } + + if (Math::is_inf(barycentric_coords.x) || Math::is_inf(barycentric_coords.y) || Math::is_inf(barycentric_coords.z)) { + continue; + } + + r_lightmap_indices[ofs] = r_lightmap.size(); + + Vector3 pos = p0 * barycentric_coords[0] + p1 * barycentric_coords[1] + p2 * barycentric_coords[2]; + Vector3 normal = n0 * barycentric_coords[0] + n1 * barycentric_coords[1] + n2 * barycentric_coords[2]; + + Vector2 uv = uv0 * barycentric_coords[0] + uv1 * barycentric_coords[1] + uv2 * barycentric_coords[2]; + Color c = _bilinear_sample(p_albedo, uv); + Color e = _bilinear_sample(p_emission, uv); + + Vector2 texel_center = Vector2(i, j) + Vector2(0.5f, 0.5f); + Vector3 texel_center_bary = Geometry::barycentric_coordinates_2d(texel_center, v0, v1, v2); + + if (!Math::is_nan(texel_center_bary.x) && !Math::is_nan(texel_center_bary.y) && !Math::is_nan(texel_center_bary.z) && !Math::is_inf(texel_center_bary.x) && !Math::is_inf(texel_center_bary.y) && !Math::is_inf(texel_center_bary.z)) { + Vector3 texel_center_pos = p0 * texel_center_bary[0] + p1 * texel_center_bary[1] + p2 * texel_center_bary[2]; + pos = _fix_sample_position(pos, texel_center_pos, normal, tangent, bitangent, texel_size); + } + + LightmapTexel texel; + texel.normal = normal.normalized(); + texel.pos = pos; + texel.albedo = Vector3(c.r, c.g, c.b); + texel.alpha = c.a; + texel.emission = Vector3(e.r, e.g, e.b); + texel.area_coverage = area_coverage; + r_lightmap.push_back(texel); + } + } +} + +void LightmapperCPU::_compute_direct_light(uint32_t p_idx, void *r_lightmap) { + + LightmapTexel *lightmap = (LightmapTexel *)r_lightmap; + for (unsigned int i = 0; i < lights.size(); ++i) { + + const Light &light = lights[i]; + Vector3 normal = lightmap[p_idx].normal; + Vector3 position = lightmap[p_idx].pos; + Vector3 final_energy; + Color c = light.color; + Vector3 light_energy = Vector3(c.r, c.g, c.b) * light.energy; + + if (light.type == LIGHT_TYPE_OMNI) { + Vector3 light_direction = (position - light.position).normalized(); + if (normal.dot(light_direction) >= 0.0) { + continue; + } + float dist = position.distance_to(light.position); + + if (dist <= light.range) { + LightmapRaycaster::Ray ray = LightmapRaycaster::Ray(position, -light_direction, parameters.bias, dist - parameters.bias); + if (raycaster->intersect(ray)) { + continue; + } + float att = powf(1.0 - dist / light.range, light.attenuation); + final_energy = light_energy * att * MAX(0, normal.dot(-light_direction)); + } + } + + if (light.type == LIGHT_TYPE_SPOT) { + + Vector3 light_direction = (position - light.position).normalized(); + if (normal.dot(light_direction) >= 0.0) { + continue; + } + + float angle = Math::acos(light.direction.dot(light_direction)); + + if (angle > light.spot_angle) { + continue; + } + + float dist = position.distance_to(light.position); + if (dist > light.range) { + continue; + } + + LightmapRaycaster::Ray ray = 
LightmapRaycaster::Ray(position, -light_direction, parameters.bias, dist); + if (raycaster->intersect(ray)) { + continue; + } + + float normalized_dist = dist * (1.0f / MAX(0.001f, light.range)); + float norm_light_attenuation = Math::pow(MAX(1.0f - normalized_dist, 0.001f), light.attenuation); + + float spot_cutoff = Math::cos(light.spot_angle); + float scos = MAX(light_direction.dot(light.direction), spot_cutoff); + float spot_rim = (1.0f - scos) / (1.0f - spot_cutoff); + norm_light_attenuation *= 1.0f - pow(MAX(spot_rim, 0.001f), light.spot_attenuation); + final_energy = light_energy * norm_light_attenuation * MAX(0, normal.dot(-light_direction)); + } + + if (light.type == LIGHT_TYPE_DIRECTIONAL) { + if (normal.dot(light.direction) >= 0.0) { + continue; + } + + LightmapRaycaster::Ray ray = LightmapRaycaster::Ray(position + normal * parameters.bias, -light.direction, parameters.bias); + if (raycaster->intersect(ray)) { + continue; + } + + final_energy = light_energy * MAX(0, normal.dot(-light.direction)); + } + + lightmap[p_idx].direct_light += final_energy * light.indirect_multiplier; + if (light.bake_direct) { + lightmap[p_idx].output_light += final_energy; + } + } +} + +_ALWAYS_INLINE_ float uniform_rand() { + /* Algorithm "xor" from p. 4 of Marsaglia, "Xorshift RNGs" */ + static thread_local uint32_t state = rand(); + state ^= state << 13; + state ^= state >> 17; + state ^= state << 5; + return float(state) / UINT32_MAX; +} + +void LightmapperCPU::_compute_indirect_light(uint32_t p_idx, void *r_lightmap) { + + LightmapTexel *lightmap = (LightmapTexel *)r_lightmap; + LightmapTexel &texel = lightmap[p_idx]; + + Vector3 accum; + + const Vector3 const_forward = Vector3(0, 0, 1); + const Vector3 const_up = Vector3(0, 1, 0); + + for (int i = 0; i < parameters.samples; i++) { + + Vector3 color; + Vector3 throughput = Vector3(1.0f, 1.0f, 1.0f); + + Vector3 position = texel.pos; + Vector3 normal = texel.normal; + Vector3 direction; + + for (int depth = 0; depth < parameters.bounces; depth++) { + + Vector3 tangent = const_forward.cross(normal); + if (unlikely(tangent.length_squared() < 0.005f)) { + tangent = const_up.cross(normal); + } + tangent.normalize(); + Vector3 bitangent = tangent.cross(normal); + bitangent.normalize(); + + Basis normal_xform = Basis(tangent, bitangent, normal); + normal_xform.transpose(); + + float u1 = uniform_rand(); + float u2 = uniform_rand(); + + float radius = Math::sqrt(u1); + float theta = Math_TAU * u2; + + Vector3 axis = Vector3(radius * Math::cos(theta), radius * Math::sin(theta), Math::sqrt(MAX(0.0f, 1.0f - u1))); + + direction = normal_xform.xform(axis); + + // We can skip multiplying throughput by cos(theta) because the sampling PDF is also cos(theta) and they cancel each other + //float pdf = normal.dot(direction); + //throughput *= normal.dot(direction)/pdf; + + LightmapRaycaster::Ray ray(position, direction, parameters.bias); + bool hit = raycaster->intersect(ray); + + if (!hit) { + if (parameters.environment_panorama.is_valid()) { + direction = parameters.environment_transform.xform_inv(direction); + Vector2 st = Vector2(Math::atan2(direction.z, direction.x), Math::acos(direction.y)); + + if (Math::is_nan(st.y)) { + st.y = direction.y > 0.0 ?
0.0 : Math_PI; + } + + st.x += Math_PI; + st /= Vector2(Math_TAU, Math_PI); + st.x = Math::fmod(st.x + 0.75, 1.0); + Color c = _bilinear_sample(parameters.environment_panorama, st, false, true); + color += throughput * Vector3(c.r, c.g, c.b) * c.a; + } + break; + } + + unsigned int hit_mesh_id = ray.geomID; + const Vector2i &size = mesh_instances[hit_mesh_id].size; + + int x = ray.u * size.x; + int y = ray.v * size.y; + + const int idx = scene_lightmap_indices[hit_mesh_id][y * size.x + x]; + + if (idx < 0) { + break; + } + + const LightmapTexel &sample = scene_lightmaps[hit_mesh_id][idx]; + + if (sample.normal.dot(ray.dir) > 0.0 && !no_shadow_meshes.has(hit_mesh_id)) { + // We hit a back-face + break; + } + + color += throughput * sample.emission; + throughput *= sample.albedo; + color += throughput * sample.direct_light; + + // Russian Roulette + // https://computergraphics.stackexchange.com/questions/2316/is-russian-roulette-really-the-answer + const float p = throughput[throughput.max_axis()]; + if (uniform_rand() > p) { + break; + } + throughput *= 1.0f / p; + + position = sample.pos; + normal = sample.normal; + } + accum += color; + } + + texel.output_light += accum / parameters.samples; +} + +void LightmapperCPU::_post_process(uint32_t p_idx, void *r_output) { + + const MeshInstance &mesh = mesh_instances[p_idx]; + + if (!mesh.generate_lightmap) { + return; + } + + LocalVector &indices = scene_lightmap_indices[p_idx]; + LocalVector &lightmap = scene_lightmaps[p_idx]; + Vector3 *output = ((LocalVector *)r_output)[p_idx].ptr(); + Vector2i size = mesh.size; + + // Blit texels to buffer + const int margin = 4; + for (int i = 0; i < size.y; i++) { + for (int j = 0; j < size.x; j++) { + int idx = indices[i * size.x + j]; + if (idx >= 0) { + output[i * size.x + j] = lightmap[idx].output_light; + continue; // filled, skip + } + + int closest_idx = -1; + float closest_dist = 1e20; + + for (int y = i - margin; y <= i + margin; y++) { + for (int x = j - margin; x <= j + margin; x++) { + + if (x == j && y == i) + continue; + if (x < 0 || x >= size.x) + continue; + if (y < 0 || y >= size.y) + continue; + int cell_idx = indices[y * size.x + x]; + if (cell_idx < 0) { + continue; //also ensures that blitted stuff is not reused + } + + float dist = Vector2(i - y, j - x).length_squared(); + if (dist < closest_dist) { + closest_dist = dist; + closest_idx = cell_idx; + } + } + } + + if (closest_idx != -1) { + output[i * size.x + j] = lightmap[closest_idx].output_light; + } + } + } + + lightmap.clear(); + + LocalVector seams; + _compute_seams(mesh, seams); + + _fix_seams(seams, output, size); + _dilate_lightmap(output, indices, size, margin); + + if (parameters.use_denoiser) { + Ref denoiser = LightmapDenoiser::create(); + + if (denoiser.is_valid()) { + int data_size = size.x * size.y * sizeof(Vector3); + Ref current_image; + current_image.instance(); + { + PoolByteArray data; + data.resize(data_size); + PoolByteArray::Write w = data.write(); + copymem(w.ptr(), output, data_size); + current_image->create(size.x, size.y, false, Image::FORMAT_RGBF, data); + } + + Ref denoised_image = denoiser->denoise_image(current_image); + + PoolByteArray denoised_data = denoised_image->get_data(); + denoised_image.unref(); + PoolByteArray::Read r = denoised_data.read(); + copymem(output, r.ptr(), data_size); + } + } + + _dilate_lightmap(output, indices, size, margin); + _fix_seams(seams, output, size); + _dilate_lightmap(output, indices, size, margin); + + indices.clear(); +} + +void LightmapperCPU::_compute_seams(const 
MeshInstance &p_mesh, LocalVector &r_seams) { + float max_uv_distance = 1.0f / MAX(p_mesh.size.x, p_mesh.size.y); + max_uv_distance *= max_uv_distance; // We use distance_to_squared(), so wee need to square the max distance as well + float max_pos_distance = 0.0005f; + float max_normal_distance = 0.05f; + + const Vector &points = p_mesh.data.points; + const Vector &uv2s = p_mesh.data.uv2; + const Vector &normals = p_mesh.data.normal; + + LocalVector edges; + edges.resize(points.size()); // One edge per vertex + + for (int i = 0; i < points.size(); i += 3) { + Vector3 triangle_vtxs[3] = { points[i + 0], points[i + 1], points[i + 2] }; + Vector2 triangle_uvs[3] = { uv2s[i + 0], uv2s[i + 1], uv2s[i + 2] }; + Vector3 triangle_normals[3] = { normals[i + 0], normals[i + 1], normals[i + 2] }; + + for (int k = 0; k < 3; k++) { + int idx[2]; + idx[0] = k; + idx[1] = (k + 1) % 3; + + if (triangle_vtxs[idx[1]] < triangle_vtxs[idx[0]]) { + SWAP(idx[0], idx[1]); + } + + SeamEdge e; + for (int l = 0; l < 2; ++l) { + e.pos[l] = triangle_vtxs[idx[l]]; + e.uv[l] = triangle_uvs[idx[l]]; + e.normal[l] = triangle_normals[idx[l]]; + } + edges[i + k] = e; + } + } + + edges.sort(); + + for (unsigned int j = 0; j < edges.size(); j++) { + const SeamEdge &edge0 = edges[j]; + for (unsigned int k = j + 1; k < edges.size() && edges[k].pos[0].x < (edge0.pos[0].x + max_pos_distance * 1.1f); k++) { + const SeamEdge &edge1 = edges[k]; + + if (edge0.uv[0].distance_squared_to(edge1.uv[0]) < max_uv_distance && edge0.uv[1].distance_squared_to(edge1.uv[1]) < max_uv_distance) { + continue; + } + + if (edge0.pos[0].distance_squared_to(edge1.pos[0]) > max_pos_distance || edge0.pos[1].distance_squared_to(edge1.pos[1]) > max_pos_distance) { + continue; + } + + if (edge0.normal[0].distance_squared_to(edge1.normal[0]) > max_normal_distance || edge0.normal[1].distance_squared_to(edge1.normal[1]) > max_normal_distance) { + continue; + } + + UVSeam s; + s.edge0[0] = edge0.uv[0]; + s.edge0[1] = edge0.uv[1]; + s.edge1[0] = edge1.uv[0]; + s.edge1[1] = edge1.uv[1]; + r_seams.push_back(s); + } + } +} + +void LightmapperCPU::_fix_seams(const LocalVector &p_seams, Vector3 *r_lightmap, Vector2i p_size) { + LocalVector extra_buffer; + extra_buffer.resize(p_size.x * p_size.y); + + copymem(extra_buffer.ptr(), r_lightmap, p_size.x * p_size.y * sizeof(Vector3)); + + Vector3 *read_ptr = extra_buffer.ptr(); + Vector3 *write_ptr = r_lightmap; + + for (int i = 0; i < 5; i++) { + for (unsigned int j = 0; j < p_seams.size(); j++) { + _fix_seam(p_seams[j].edge0[0], p_seams[j].edge0[1], p_seams[j].edge1[0], p_seams[j].edge1[1], read_ptr, write_ptr, p_size); + _fix_seam(p_seams[j].edge1[0], p_seams[j].edge1[1], p_seams[j].edge0[0], p_seams[j].edge0[1], read_ptr, write_ptr, p_size); + } + copymem(read_ptr, write_ptr, p_size.x * p_size.y * sizeof(Vector3)); + } +} + +void LightmapperCPU::_fix_seam(const Vector2 &p_pos0, const Vector2 &p_pos1, const Vector2 &p_uv0, const Vector2 &p_uv1, const Vector3 *p_read_buffer, Vector3 *r_write_buffer, const Vector2i &p_size) { + + Vector2 line[2]; + line[0] = p_pos0 * p_size; + line[1] = p_pos1 * p_size; + + const Vector2i start_pixel = line[0].floor(); + const Vector2i end_pixel = line[1].floor(); + + Vector2 seam_dir = (line[1] - line[0]).normalized(); + Vector2 t_delta = Vector2(1.0f / Math::abs(seam_dir.x), 1.0f / Math::abs(seam_dir.y)); + Vector2i step = Vector2(seam_dir.x > 0 ? 1 : (seam_dir.x < 0 ? -1 : 0), seam_dir.y > 0 ? 1 : (seam_dir.y < 0 ? 
-1 : 0)); + + Vector2 t_next = Vector2(Math::fmod(line[0].x, 1.0f), Math::fmod(line[0].y, 1.0f)); + + if (step.x == 1) { + t_next.x = 1.0f - t_next.x; + } + + if (step.y == 1) { + t_next.y = 1.0f - t_next.y; + } + + t_next.x /= Math::abs(seam_dir.x); + t_next.y /= Math::abs(seam_dir.y); + + if (Math::is_nan(t_next.x)) { + t_next.x = 1e20f; + } + + if (Math::is_nan(t_next.y)) { + t_next.y = 1e20f; + } + + Vector2i pixel = start_pixel; + Vector2 start_p = start_pixel; + float line_length = line[0].distance_to(line[1]); + + if (line_length == 0.0f) { + return; + } + + while (start_p.distance_to(pixel) < line_length + 1.0f) { + + Vector2 current_point = Vector2(pixel) + Vector2(0.5f, 0.5f); + current_point = Geometry::get_closest_point_to_segment_2d(current_point, line); + float t = line[0].distance_to(current_point) / line_length; + + Vector2 current_uv = p_uv0 * (1.0 - t) + p_uv1 * t; + Vector2i sampled_point = (current_uv * p_size).floor(); + + Vector3 current_color = r_write_buffer[pixel.y * p_size.x + pixel.x]; + Vector3 sampled_color = p_read_buffer[sampled_point.y * p_size.x + sampled_point.x]; + + r_write_buffer[pixel.y * p_size.x + pixel.x] = current_color * 0.6f + sampled_color * 0.4f; + + if (pixel == end_pixel) { + break; + } + + if (t_next.x < t_next.y) { + pixel.x += step.x; + t_next.x += t_delta.x; + } else { + pixel.y += step.y; + t_next.y += t_delta.y; + } + } +} + +void LightmapperCPU::_dilate_lightmap(Vector3 *r_lightmap, const LocalVector p_indices, Vector2i p_size, int margin) { + for (int i = 0; i < p_size.y; i++) { + for (int j = 0; j < p_size.x; j++) { + int idx = p_indices[i * p_size.x + j]; + if (idx >= 0) { + continue; //filled, skip + } + + Vector2i closest; + float closest_dist = 1e20; + + for (int y = i - margin; y <= i + margin; y++) { + for (int x = j - margin; x <= j + margin; x++) { + + if (x == j && y == i) + continue; + if (x < 0 || x >= p_size.x) + continue; + if (y < 0 || y >= p_size.y) + continue; + int cell_idx = p_indices[y * p_size.x + x]; + if (cell_idx < 0) { + continue; //also ensures that blitted stuff is not reused + } + + float dist = Vector2(i - y, j - x).length_squared(); + if (dist < closest_dist) { + closest_dist = dist; + closest = Vector2(x, y); + } + } + } + + if (closest_dist < 1e20) { + r_lightmap[i * p_size.x + j] = r_lightmap[closest.y * p_size.x + closest.x]; + } + } + } +} + +void LightmapperCPU::_blit_lightmap(const Vector &p_src, const Vector2i &p_size, Ref &p_dst, int p_x, int p_y, bool p_with_padding) { + int padding = p_with_padding ? 
1 : 0; + ERR_FAIL_COND(p_x < padding || p_y < padding); + ERR_FAIL_COND(p_x + p_size.x > p_dst->get_width() - padding); + ERR_FAIL_COND(p_y + p_size.y > p_dst->get_height() - padding); + + p_dst->lock(); + for (int y = 0; y < p_size.y; y++) { + const Vector3 *__restrict src = p_src.ptr() + y * p_size.x; + for (int x = 0; x < p_size.x; x++) { + p_dst->set_pixel(p_x + x, p_y + y, Color(src->x, src->y, src->z)); + src++; + } + } + + if (p_with_padding) { + for (int y = -1; y < p_size.y + 1; y++) { + int yy = CLAMP(y, 0, p_size.y - 1); + int idx_left = yy * p_size.x; + int idx_right = idx_left + p_size.x - 1; + p_dst->set_pixel(p_x - 1, p_y + y, Color(p_src[idx_left].x, p_src[idx_left].y, p_src[idx_left].z)); + p_dst->set_pixel(p_x + p_size.x, p_y + y, Color(p_src[idx_right].x, p_src[idx_right].y, p_src[idx_right].z)); + } + + for (int x = -1; x < p_size.x + 1; x++) { + int xx = CLAMP(x, 0, p_size.x - 1); + int idx_top = xx; + int idx_bot = idx_top + (p_size.y - 1) * p_size.x; + p_dst->set_pixel(p_x + x, p_y - 1, Color(p_src[idx_top].x, p_src[idx_top].y, p_src[idx_top].z)); + p_dst->set_pixel(p_x + x, p_y + p_size.y, Color(p_src[idx_bot].x, p_src[idx_bot].y, p_src[idx_bot].z)); + } + } + p_dst->unlock(); +} + +LightmapperCPU::BakeError LightmapperCPU::bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, bool p_generate_atlas, int p_max_texture_size, const Ref &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function, void *p_bake_userdata, BakeStepFunc p_substep_function) { + + if (p_step_function) { + bool cancelled = p_step_function(0.0, TTR("Begin Bake"), p_bake_userdata, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + raycaster = LightmapRaycaster::create(); + + ERR_FAIL_COND_V(raycaster.is_null(), BAKE_ERROR_NO_RAYCASTER); + + // Collect parameters + parameters.use_denoiser = p_use_denoiser; + parameters.bias = p_bias; + parameters.bounces = p_bounces; + parameters.environment_transform = p_environment_transform; + parameters.environment_panorama = p_environment_panorama; + + switch (p_quality) { + case BAKE_QUALITY_LOW: { + parameters.samples = GLOBAL_GET("rendering/cpu_lightmapper/quality/low_quality_ray_count"); + } break; + case BAKE_QUALITY_MEDIUM: { + parameters.samples = GLOBAL_GET("rendering/cpu_lightmapper/quality/medium_quality_ray_count"); + } break; + case BAKE_QUALITY_HIGH: { + parameters.samples = GLOBAL_GET("rendering/cpu_lightmapper/quality/high_quality_ray_count"); + } break; + case BAKE_QUALITY_ULTRA: { + parameters.samples = GLOBAL_GET("rendering/cpu_lightmapper/quality/ultra_quality_ray_count"); + } break; + } + + bake_textures.clear(); + + if (p_step_function) { + bool cancelled = p_step_function(0.1, TTR("Preparing data structures"), p_bake_userdata, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + raycaster->add_mesh(mesh_instances[i].data.points, mesh_instances[i].data.normal, mesh_instances[i].data.uv2, i); + } + raycaster->commit(); + + scene_lightmaps.resize(mesh_instances.size()); + scene_lightmap_indices.resize(mesh_instances.size()); + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (!mesh_instances[i].cast_shadows) { + no_shadow_meshes.insert(i); + } + } + + raycaster->set_mesh_filter(no_shadow_meshes); + + Vector2i atlas_size = Vector2i(-1, -1); + int atlas_slices = -1; + if (p_generate_atlas) { + Error err = _layout_atlas(p_max_texture_size, &atlas_size, &atlas_slices); + if 
(err != OK) { + return BAKE_ERROR_LIGHTMAP_TOO_SMALL; + } + } + + if (p_step_function) { + bool cancelled = p_step_function(0.2, TTR("Generate buffers"), p_bake_userdata, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + if (_parallel_run(mesh_instances.size(), "Rasterizing meshes", &LightmapperCPU::_generate_buffer, nullptr, p_substep_function)) { + return BAKE_ERROR_USER_ABORTED; + } + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + const Size2i &size = mesh_instances[i].size; + bool has_alpha = false; + PoolVector alpha_data; + alpha_data.resize(size.x * size.y); + { + PoolVector::Write w = alpha_data.write(); + for (unsigned int j = 0; j < scene_lightmap_indices[i].size(); ++j) { + int idx = scene_lightmap_indices[i][j]; + uint8_t alpha = 0; + if (idx >= 0) { + alpha = CLAMP(scene_lightmaps[i][idx].alpha * 255, 0, 255); + if (alpha < 255) { + has_alpha = true; + } + } + w[j] = alpha; + } + } + + if (has_alpha) { + Ref alpha_texture; + alpha_texture.instance(); + alpha_texture->create(size.x, size.y, false, Image::FORMAT_L8, alpha_data); + raycaster->set_mesh_alpha_texture(alpha_texture, i); + } + } + + albedo_textures.clear(); + emission_textures.clear(); + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (p_step_function) { + float p = float(i) / mesh_instances.size(); + bool cancelled = p_step_function(0.2 + p * 0.2, vformat("%s (%d/%d)", TTR("Direct lighting"), i, mesh_instances.size()), p_bake_userdata, false); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + if (_parallel_run(scene_lightmaps[i].size(), "Computing direct light", &LightmapperCPU::_compute_direct_light, scene_lightmaps[i].ptr(), p_substep_function)) { + return BAKE_ERROR_USER_ABORTED; + } + } + raycaster->clear_mesh_filter(); + + int n_lit_meshes = 0; + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (mesh_instances[i].generate_lightmap) { + n_lit_meshes++; + } + } + + if (parameters.environment_panorama.is_valid()) { + parameters.environment_panorama->lock(); + } + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + + if (!mesh_instances[i].generate_lightmap) { + continue; + } + + if (p_step_function) { + float p = float(i) / n_lit_meshes; + bool cancelled = p_step_function(0.4 + p * 0.4, vformat("%s (%d/%d)", TTR("Indirect lighting"), i, mesh_instances.size()), p_bake_userdata, false); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + if (!scene_lightmaps[i].empty()) { + if (_parallel_run(scene_lightmaps[i].size(), "Computing indirect light", &LightmapperCPU::_compute_indirect_light, scene_lightmaps[i].ptr(), p_substep_function)) { + return BAKE_ERROR_USER_ABORTED; + } + } + } + if (parameters.environment_panorama.is_valid()) { + parameters.environment_panorama->unlock(); + } + + raycaster.unref(); // Not needed anymore, free some memory. 
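+ // From this point on only image-space work remains: per-mesh post-processing (seam fixing,
+ // dilation and an optional denoise pass) followed by blitting each lightmap either into the
+ // shared atlas slices or into one texture per mesh, as done below.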
+ + LocalVector > lightmaps_data; + lightmaps_data.resize(mesh_instances.size()); + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (mesh_instances[i].generate_lightmap) { + const Vector2i size = mesh_instances[i].size; + lightmaps_data[i].resize(size.x * size.y); + } + } + + if (p_step_function) { + bool cancelled = p_step_function(0.8, TTR("Post processing"), p_bake_userdata, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + if (_parallel_run(mesh_instances.size(), "Denoise & fix seams", &LightmapperCPU::_post_process, lightmaps_data.ptr(), p_substep_function)) { + return BAKE_ERROR_USER_ABORTED; + } + + if (p_generate_atlas) { + bake_textures.resize(atlas_slices); + + for (int i = 0; i < atlas_slices; i++) { + Ref image; + image.instance(); + image->create(atlas_size.x, atlas_size.y, false, Image::FORMAT_RGBH); + bake_textures[i] = image; + } + } else { + bake_textures.resize(mesh_instances.size()); + + Set used_mesh_names; + + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (!mesh_instances[i].generate_lightmap) { + continue; + } + + String mesh_name = mesh_instances[i].node_name; + if (mesh_name == "" || mesh_name.find(":") != -1 || mesh_name.find("/") != -1) { + mesh_name = "LightMap"; + } + + if (used_mesh_names.has(mesh_name)) { + int idx = 2; + String base = mesh_name; + while (true) { + mesh_name = base + itos(idx); + if (!used_mesh_names.has(mesh_name)) + break; + idx++; + } + } + used_mesh_names.insert(mesh_name); + + Ref image; + image.instance(); + image->create(mesh_instances[i].size.x, mesh_instances[i].size.y, false, Image::FORMAT_RGBH); + image->set_name(mesh_name); + bake_textures[i] = image; + } + } + + if (p_step_function) { + bool cancelled = p_step_function(0.9, TTR("Plotting lightmaps"), p_bake_userdata, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + { + int j = 0; + for (unsigned int i = 0; i < mesh_instances.size(); i++) { + if (!mesh_instances[i].generate_lightmap) { + continue; + } + + if (p_generate_atlas) { + _blit_lightmap(lightmaps_data[i], mesh_instances[i].size, bake_textures[mesh_instances[i].slice], mesh_instances[i].offset.x, mesh_instances[i].offset.y, true); + } else { + _blit_lightmap(lightmaps_data[i], mesh_instances[i].size, bake_textures[j], 0, 0, false); + } + j++; + } + } + + return BAKE_OK; +} + +int LightmapperCPU::get_bake_texture_count() const { + return bake_textures.size(); +} + +Ref LightmapperCPU::get_bake_texture(int p_index) const { + ERR_FAIL_INDEX_V(p_index, (int)bake_textures.size(), Ref()); + return bake_textures[p_index]; +} + +int LightmapperCPU::get_bake_mesh_count() const { + return mesh_instances.size(); +} + +Variant LightmapperCPU::get_bake_mesh_userdata(int p_index) const { + ERR_FAIL_INDEX_V(p_index, (int)mesh_instances.size(), Variant()); + return mesh_instances[p_index].data.userdata; +} + +Rect2 LightmapperCPU::get_bake_mesh_uv_scale(int p_index) const { + ERR_FAIL_COND_V(bake_textures.size() == 0, Rect2()); + Rect2 uv_ofs; + Vector2 atlas_size = Vector2(bake_textures[0]->get_width(), bake_textures[0]->get_height()); + uv_ofs.position = Vector2(mesh_instances[p_index].offset) / atlas_size; + uv_ofs.size = Vector2(mesh_instances[p_index].size) / atlas_size; + return uv_ofs; +} + +int LightmapperCPU::get_bake_mesh_texture_slice(int p_index) const { + ERR_FAIL_INDEX_V(p_index, (int)mesh_instances.size(), Variant()); + return mesh_instances[p_index].slice; +} + +void LightmapperCPU::add_albedo_texture(Ref p_texture) { + if (p_texture.is_null()) 
{ + return; + } + + RID texture_rid = p_texture->get_rid(); + if (!texture_rid.is_valid() || albedo_textures.has(texture_rid)) { + return; + } + + Ref texture_data = p_texture->get_data(); + + if (texture_data.is_null()) { + return; + } + + if (texture_data->is_compressed()) { + texture_data->decompress(); + } + + texture_data->convert(Image::FORMAT_RGBA8); + + albedo_textures.insert(texture_rid, texture_data); +} + +void LightmapperCPU::add_emission_texture(Ref p_texture) { + if (p_texture.is_null()) { + return; + } + + RID texture_rid = p_texture->get_rid(); + if (!texture_rid.is_valid() || emission_textures.has(texture_rid)) { + return; + } + + Ref texture_data = p_texture->get_data(); + + if (texture_data.is_null()) { + return; + } + + if (texture_data->is_compressed()) { + texture_data->decompress(); + } + + texture_data->convert(Image::FORMAT_RGBH); + + emission_textures.insert(texture_rid, texture_data); +} + +void LightmapperCPU::add_mesh(const MeshData &p_mesh, Vector2i p_size) { + ERR_FAIL_COND(p_mesh.points.size() == 0); + ERR_FAIL_COND(p_mesh.points.size() != p_mesh.uv2.size()); + ERR_FAIL_COND(p_mesh.points.size() != p_mesh.normal.size()); + ERR_FAIL_COND(!p_mesh.uv.empty() && p_mesh.points.size() != p_mesh.uv.size()); + ERR_FAIL_COND(p_mesh.surface_facecounts.size() != p_mesh.albedo.size()); + ERR_FAIL_COND(p_mesh.surface_facecounts.size() != p_mesh.emission.size()); + + MeshInstance mi; + mi.data = p_mesh; + mi.size = p_size; + mi.generate_lightmap = true; + mi.cast_shadows = true; + mi.node_name = ""; + + Dictionary userdata = p_mesh.userdata; + if (userdata.has("cast_shadows")) { + mi.cast_shadows = userdata["cast_shadows"]; + } + if (userdata.has("generate_lightmap")) { + mi.generate_lightmap = userdata["generate_lightmap"]; + } + if (userdata.has("node_name")) { + mi.node_name = userdata["node_name"]; + } + + mesh_instances.push_back(mi); +} + +void LightmapperCPU::add_directional_light(bool p_bake_direct, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_indirect_multiplier) { + Light l; + l.type = LIGHT_TYPE_DIRECTIONAL; + l.direction = p_direction; + l.color = p_color; + l.energy = p_energy; + l.indirect_multiplier = p_indirect_multiplier; + l.bake_direct = p_bake_direct; + lights.push_back(l); +} + +void LightmapperCPU::add_omni_light(bool p_bake_direct, const Vector3 &p_position, const Color &p_color, float p_energy, float p_indirect_multiplier, float p_range, float p_attenuation) { + Light l; + l.type = LIGHT_TYPE_OMNI; + l.position = p_position; + l.range = p_range; + l.attenuation = p_attenuation; + l.color = p_color; + l.energy = p_energy; + l.indirect_multiplier = p_indirect_multiplier; + l.bake_direct = p_bake_direct; + lights.push_back(l); +} + +void LightmapperCPU::add_spot_light(bool p_bake_direct, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_indirect_multiplier, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation) { + Light l; + l.type = LIGHT_TYPE_SPOT; + l.position = p_position; + l.direction = p_direction; + l.range = p_range; + l.attenuation = p_attenuation; + l.spot_angle = Math::deg2rad(p_spot_angle); + l.spot_attenuation = p_spot_attenuation; + l.color = p_color; + l.energy = p_energy; + l.indirect_multiplier = p_indirect_multiplier; + l.bake_direct = p_bake_direct; + lights.push_back(l); +} + +LightmapperCPU::LightmapperCPU() { + thread_progress = 0; + thread_cancelled = false; +} diff --git a/modules/lightmapper_cpu/lightmapper_cpu.h 
b/modules/lightmapper_cpu/lightmapper_cpu.h new file mode 100644 index 000000000000..eb9a3a33ea09 --- /dev/null +++ b/modules/lightmapper_cpu/lightmapper_cpu.h @@ -0,0 +1,183 @@ +/*************************************************************************/ +/* lightmapper_cpu.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#ifndef LIGHTMAPPER_CPU_H +#define LIGHTMAPPER_CPU_H + +#include "core/local_vector.h" +#include "scene/3d/lightmapper.h" +#include "scene/resources/mesh.h" +#include "scene/resources/surface_tool.h" + +#include <atomic> + +class LightmapperCPU : public Lightmapper { + GDCLASS(LightmapperCPU, Lightmapper) + + struct MeshInstance { + MeshData data; + int slice = 0; + Vector2i offset; + Vector2i size; + bool cast_shadows; + bool generate_lightmap; + String node_name; + }; + + struct Light { + Vector3 position; + uint32_t type = LIGHT_TYPE_DIRECTIONAL; + Vector3 direction; + float energy; + float indirect_multiplier; + Color color; + float range; + float attenuation; + float spot_angle; + float spot_attenuation; + bool bake_direct; + }; + + struct LightmapTexel { + Vector3 albedo; + float alpha; + Vector3 emission; + Vector3 pos; + Vector3 normal; + + Vector3 direct_light; + Vector3 output_light; + + float area_coverage; + }; + + struct BakeParams { + float bias; + int bounces; + int samples; + bool use_denoiser = true; + Ref<Image> environment_panorama; + Basis environment_transform; + }; + + struct UVSeam { + Vector2 edge0[2]; + Vector2 edge1[2]; + }; + + struct SeamEdge { + Vector3 pos[2]; + Vector3 normal[2]; + Vector2 uv[2]; + + _FORCE_INLINE_ bool operator<(const SeamEdge &p_edge) const { + return pos[0].x < p_edge.pos[0].x; + } + }; + + struct AtlasOffset { + int slice; + int x; + int y; + }; + + struct ThreadData; + + typedef void (LightmapperCPU::*BakeThreadFunc)(uint32_t, void *); + + struct ThreadData { + LightmapperCPU *instance; + uint32_t count; + BakeThreadFunc thread_func; + void *userdata; + }; + + BakeParams parameters; + + LocalVector<Ref<Image> > bake_textures; + Map<RID, Ref<Image> > albedo_textures; + Map<RID, Ref<Image> > emission_textures; + + LocalVector<MeshInstance> mesh_instances; + LocalVector<Light> lights; + + LocalVector<LocalVector<LightmapTexel> > scene_lightmaps; + LocalVector<LocalVector<int> > scene_lightmap_indices; + Set<int> no_shadow_meshes; + + std::atomic<uint32_t> thread_progress; + std::atomic<bool> thread_cancelled; + + Ref<LightmapRaycaster> raycaster; + + Error _layout_atlas(int p_max_size, Vector2i *r_atlas_size, int *r_atlas_slices); + + static void _thread_func_callback(void *p_thread_data); + void _thread_func_wrapper(uint32_t p_idx, ThreadData *p_thread_data); + bool _parallel_run(int p_count, const String &p_description, BakeThreadFunc p_thread_func, void *p_userdata, BakeStepFunc p_substep_func = nullptr); + + void _generate_buffer(uint32_t p_idx, void *p_unused); + Ref<Image> _init_bake_texture(const MeshData::TextureDef &p_texture_def, const Map<RID, Ref<Image> > &p_tex_cache, Image::Format p_default_format); + Color _bilinear_sample(const Ref<Image> &p_img, const Vector2 &p_uv, bool p_clamp_x = false, bool p_clamp_y = false); + Vector3 _fix_sample_position(const Vector3 &p_position, const Vector3 &p_texel_center, const Vector3 &p_normal, const Vector3 &p_tangent, const Vector3 &p_bitangent, const Vector2 &p_texel_size); + void _plot_triangle(const Vector2 *p_vertices, const Vector3 *p_positions, const Vector3 *p_normals, const Vector2 *p_uvs, const Ref<Image> &p_albedo_texture, const Ref<Image> &p_emission_texture, Vector2i p_size, LocalVector<LightmapTexel> &r_texels, LocalVector<int> &r_lightmap_indices); + + void _compute_direct_light(uint32_t p_idx, void *r_lightmap); + + void _compute_indirect_light(uint32_t p_idx, void *r_lightmap); + + void _post_process(uint32_t p_idx, void *r_output); + void _compute_seams(const MeshInstance &p_mesh, LocalVector<UVSeam> &r_seams); + void _fix_seams(const LocalVector<UVSeam> &p_seams, Vector3 *r_lightmap, Vector2i p_size); + void _fix_seam(const Vector2 &p_pos0, const Vector2 &p_pos1, const Vector2 &p_uv0, const Vector2 &p_uv1, const Vector3 *p_read_buffer, Vector3 *r_write_buffer, const Vector2i &p_size); + void _dilate_lightmap(Vector3 *r_lightmap, const LocalVector<int> p_indices, Vector2i p_size, int margin); + + void _blit_lightmap(const Vector<Vector3> &p_src, const Vector2i &p_size, Ref<Image> &p_dst, int p_x, int p_y, bool p_with_padding); + +public: + virtual void add_albedo_texture(Ref<Texture> p_texture); + virtual void add_emission_texture(Ref<Texture> p_texture); + virtual void add_mesh(const MeshData &p_mesh, Vector2i p_size); + virtual void add_directional_light(bool p_bake_direct, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_indirect_multiplier); + virtual void add_omni_light(bool p_bake_direct, const Vector3 &p_position, const Color &p_color, float p_energy, float p_indirect_multiplier, float p_range, float p_attenuation); + virtual void add_spot_light(bool p_bake_direct, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_indirect_multiplier, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation); + virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, bool p_generate_atlas, int p_max_texture_size, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_bake_userdata = nullptr, BakeStepFunc p_substep_function = nullptr); + + int get_bake_texture_count() const; + Ref<Image> get_bake_texture(int p_index) const; + int get_bake_mesh_count() const; + Variant get_bake_mesh_userdata(int p_index) const; + Rect2 get_bake_mesh_uv_scale(int p_index) const; + int get_bake_mesh_texture_slice(int p_index) const; + + LightmapperCPU(); +}; + +#endif // LIGHTMAPPER_CPU_H
diff --git a/modules/lightmapper_cpu/register_types.cpp b/modules/lightmapper_cpu/register_types.cpp new file mode 100644 index 000000000000..f5861311b928 --- /dev/null +++ b/modules/lightmapper_cpu/register_types.cpp @@ -0,0 +1,54 @@ +/*************************************************************************/ +/* register_types.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software.
*/ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "register_types.h" + +#include "core/project_settings.h" +#include "lightmapper_cpu.h" +#include "scene/3d/lightmapper.h" + +#ifndef _3D_DISABLED +static Lightmapper *create_lightmapper_cpu() { + return memnew(LightmapperCPU); +} +#endif + +void register_lightmapper_cpu_types() { + GLOBAL_DEF("rendering/cpu_lightmapper/quality/low_quality_ray_count", 64); + GLOBAL_DEF("rendering/cpu_lightmapper/quality/medium_quality_ray_count", 256); + GLOBAL_DEF("rendering/cpu_lightmapper/quality/high_quality_ray_count", 512); + GLOBAL_DEF("rendering/cpu_lightmapper/quality/ultra_quality_ray_count", 1024); +#ifndef _3D_DISABLED + Lightmapper::create_cpu = create_lightmapper_cpu; +#endif +} + +void unregister_lightmapper_cpu_types() { +} diff --git a/modules/lightmapper_cpu/register_types.h b/modules/lightmapper_cpu/register_types.h new file mode 100644 index 000000000000..3033e838bee2 --- /dev/null +++ b/modules/lightmapper_cpu/register_types.h @@ -0,0 +1,37 @@ +/*************************************************************************/ +/* register_types.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#ifndef LIGHTMAPPER_CPU_REGISTER_TYPES_H +#define LIGHTMAPPER_CPU_REGISTER_TYPES_H + +void register_lightmapper_cpu_types(); +void unregister_lightmapper_cpu_types(); + +#endif // LIGHTMAPPER_CPU_REGISTER_TYPES_H diff --git a/modules/raycast/SCsub b/modules/raycast/SCsub new file mode 100644 index 000000000000..84ab13358818 --- /dev/null +++ b/modules/raycast/SCsub @@ -0,0 +1,93 @@ +#!/usr/bin/env python + +Import("env") +Import("env_modules") + +embree_src = [ + "common/sys/sysinfo.cpp", + "common/sys/alloc.cpp", + "common/sys/filename.cpp", + "common/sys/library.cpp", + "common/sys/thread.cpp", + "common/sys/string.cpp", + "common/sys/regression.cpp", + "common/sys/mutex.cpp", + "common/sys/condition.cpp", + "common/sys/barrier.cpp", + "common/math/constants.cpp", + "common/simd/sse.cpp", + "common/lexers/stringstream.cpp", + "common/lexers/tokenstream.cpp", + "common/tasking/taskschedulerinternal.cpp", + "common/algorithms/parallel_for.cpp", + "common/algorithms/parallel_reduce.cpp", + "common/algorithms/parallel_prefix_sum.cpp", + "common/algorithms/parallel_for_for.cpp", + "common/algorithms/parallel_for_for_prefix_sum.cpp", + "common/algorithms/parallel_partition.cpp", + "common/algorithms/parallel_sort.cpp", + "common/algorithms/parallel_set.cpp", + "common/algorithms/parallel_map.cpp", + "common/algorithms/parallel_filter.cpp", + "kernels/common/device.cpp", + "kernels/common/stat.cpp", + "kernels/common/acceln.cpp", + "kernels/common/accelset.cpp", + "kernels/common/state.cpp", + "kernels/common/rtcore.cpp", + "kernels/common/rtcore_builder.cpp", + "kernels/common/scene.cpp", + "kernels/common/alloc.cpp", + "kernels/common/geometry.cpp", + "kernels/common/scene_triangle_mesh.cpp", + "kernels/geometry/primitive4.cpp", + "kernels/builders/primrefgen.cpp", + "kernels/bvh/bvh.cpp", + "kernels/bvh/bvh_statistics.cpp", + "kernels/bvh/bvh4_factory.cpp", + "kernels/bvh/bvh8_factory.cpp", + "kernels/bvh/bvh_collider.cpp", + "kernels/bvh/bvh_rotate.cpp", + "kernels/bvh/bvh_refit.cpp", + "kernels/bvh/bvh_builder.cpp", + "kernels/bvh/bvh_builder_morton.cpp", + "kernels/bvh/bvh_builder_sah.cpp", + "kernels/bvh/bvh_builder_sah_spatial.cpp", + "kernels/bvh/bvh_builder_sah_mb.cpp", + "kernels/bvh/bvh_builder_twolevel.cpp", + "kernels/bvh/bvh_intersector1_bvh4.cpp", +] + +embree_dir = "#thirdparty/embree/" + +env_embree = env_modules.Clone() +embree_sources = [embree_dir + file for file in embree_src] +env_embree.Prepend(CPPPATH=[embree_dir, embree_dir + "include"]) +env_embree.Append( + CPPFLAGS=[ + "-DEMBREE_TARGET_SSE2", + "-DEMBREE_LOWEST_ISA", + "-msse2", + "-DTASKING_INTERNAL", + "-DNDEBUG", + "-D__SSE2__", + "-D__SSE__", + ] +) + +if not env_embree.msvc: + env_embree.Append(CPPFLAGS=["-mxsave"]) + +if env["platform"] == "windows": + if env.msvc: + env.Append(LINKFLAGS=["psapi.lib"]) + else: + env.Append(LIBS=["psapi"]) + +env_embree.disable_warnings() +env_embree.add_source_files(env.modules_sources, embree_sources) + +env_raycast = env_modules.Clone() +env_raycast.Prepend(CPPPATH=[embree_dir, embree_dir + "include", embree_dir + "common"]) + +env_raycast.add_source_files(env.modules_sources, "*.cpp") diff --git a/modules/raycast/config.py b/modules/raycast/config.py new file mode 100644 index 000000000000..85416c16cc8d --- /dev/null +++ b/modules/raycast/config.py @@ -0,0 +1,13 @@ +def can_build(env, platform): + # Embree requires at least SSE2 to be available, so 32-bit and ARM64 builds are + # not 
supported. + # It's also only relevant for tools build and desktop platforms, + # as doing lightmap generation on Android or HTML5 would be a bit far-fetched. + supported_platform = platform in ["x11", "osx", "windows", "server"] + supported_bits = env["bits"] == "64" + supported_arch = env["arch"] != "arm64" + return env["tools"] and supported_platform and supported_bits and supported_arch + + +def configure(env): + pass diff --git a/modules/raycast/godot_update_embree.py b/modules/raycast/godot_update_embree.py new file mode 100644 index 000000000000..92649bbf7427 --- /dev/null +++ b/modules/raycast/godot_update_embree.py @@ -0,0 +1,259 @@ +import glob, os, shutil, subprocess, re + +include_dirs = [ + "common/tasking", + "kernels/bvh", + "kernels/builders", + "common/sys", + "kernels", + "kernels/common", + "common/math", + "common/algorithms", + "common/lexers", + "common/simd", + "include/embree3", + "kernels/subdiv", + "kernels/geometry", +] + +cpp_files = [ + "common/sys/sysinfo.cpp", + "common/sys/alloc.cpp", + "common/sys/filename.cpp", + "common/sys/library.cpp", + "common/sys/thread.cpp", + "common/sys/string.cpp", + "common/sys/regression.cpp", + "common/sys/mutex.cpp", + "common/sys/condition.cpp", + "common/sys/barrier.cpp", + "common/math/constants.cpp", + "common/simd/sse.cpp", + "common/lexers/stringstream.cpp", + "common/lexers/tokenstream.cpp", + "common/tasking/taskschedulerinternal.cpp", + "common/algorithms/parallel_for.cpp", + "common/algorithms/parallel_reduce.cpp", + "common/algorithms/parallel_prefix_sum.cpp", + "common/algorithms/parallel_for_for.cpp", + "common/algorithms/parallel_for_for_prefix_sum.cpp", + "common/algorithms/parallel_partition.cpp", + "common/algorithms/parallel_sort.cpp", + "common/algorithms/parallel_set.cpp", + "common/algorithms/parallel_map.cpp", + "common/algorithms/parallel_filter.cpp", + "kernels/common/device.cpp", + "kernels/common/stat.cpp", + "kernels/common/acceln.cpp", + "kernels/common/accelset.cpp", + "kernels/common/state.cpp", + "kernels/common/rtcore.cpp", + "kernels/common/rtcore_builder.cpp", + "kernels/common/scene.cpp", + "kernels/common/alloc.cpp", + "kernels/common/geometry.cpp", + "kernels/common/scene_triangle_mesh.cpp", + "kernels/geometry/primitive4.cpp", + "kernels/builders/primrefgen.cpp", + "kernels/bvh/bvh.cpp", + "kernels/bvh/bvh_statistics.cpp", + "kernels/bvh/bvh4_factory.cpp", + "kernels/bvh/bvh8_factory.cpp", + "kernels/bvh/bvh_collider.cpp", + "kernels/bvh/bvh_rotate.cpp", + "kernels/bvh/bvh_refit.cpp", + "kernels/bvh/bvh_builder.cpp", + "kernels/bvh/bvh_builder_morton.cpp", + "kernels/bvh/bvh_builder_sah.cpp", + "kernels/bvh/bvh_builder_sah_spatial.cpp", + "kernels/bvh/bvh_builder_sah_mb.cpp", + "kernels/bvh/bvh_builder_twolevel.cpp", + "kernels/bvh/bvh_intersector1.cpp", + "kernels/bvh/bvh_intersector1_bvh4.cpp", +] + +os.chdir("../../thirdparty") + +if os.path.exists("embree"): + shutil.rmtree("embree") + +subprocess.run(["git", "clone", "https://github.com/embree/embree.git", "embree-tmp"]) +os.chdir("embree-tmp") + +commit_hash = str(subprocess.check_output(["git", "rev-parse", "HEAD"], universal_newlines=True)).strip() + +dest_dir = "../embree" +all_files = set(cpp_files) + +for include_dir in include_dirs: + headers = glob.iglob(os.path.join(include_dir, "*.h")) + all_files.update(headers) + +for f in all_files: + d = os.path.join(dest_dir, os.path.dirname(f)) + if not os.path.exists(d): + os.makedirs(d) + shutil.copy2(f, d) + +with open(os.path.join(dest_dir, "kernels/hash.h"), "w") as hash_file: + 
hash_file.write( + f""" +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#define RTC_HASH "{commit_hash}" +""" + ) + +with open(os.path.join(dest_dir, "kernels/config.h"), "w") as config_file: + config_file.write( + """ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +/* #undef EMBREE_RAY_MASK */ +/* #undef EMBREE_STAT_COUNTERS */ +/* #undef EMBREE_BACKFACE_CULLING */ +/* #undef EMBREE_BACKFACE_CULLING_CURVES */ +#define EMBREE_FILTER_FUNCTION +/* #undef EMBREE_IGNORE_INVALID_RAYS */ +#define EMBREE_GEOMETRY_TRIANGLE +/* #undef EMBREE_GEOMETRY_QUAD */ +/* #undef EMBREE_GEOMETRY_CURVE */ +/* #undef EMBREE_GEOMETRY_SUBDIVISION */ +/* #undef EMBREE_GEOMETRY_USER */ +/* #undef EMBREE_GEOMETRY_INSTANCE */ +/* #undef EMBREE_GEOMETRY_GRID */ +/* #undef EMBREE_GEOMETRY_POINT */ +/* #undef EMBREE_RAY_PACKETS */ +/* #undef EMBREE_COMPACT_POLYS */ + +#define EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR 2.0 + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + #define IF_ENABLED_TRIS(x) x +#else + #define IF_ENABLED_TRIS(x) +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + #define IF_ENABLED_QUADS(x) x +#else + #define IF_ENABLED_QUADS(x) +#endif + +#if defined(EMBREE_GEOMETRY_CURVE) || defined(EMBREE_GEOMETRY_POINT) + #define IF_ENABLED_CURVES_OR_POINTS(x) x +#else + #define IF_ENABLED_CURVES_OR_POINTS(x) +#endif + +#if defined(EMBREE_GEOMETRY_CURVE) + #define IF_ENABLED_CURVES(x) x +#else + #define IF_ENABLED_CURVES(x) +#endif + +#if defined(EMBREE_GEOMETRY_POINT) + #define IF_ENABLED_POINTS(x) x +#else + #define IF_ENABLED_POINTS(x) +#endif + +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + #define IF_ENABLED_SUBDIV(x) x +#else + #define IF_ENABLED_SUBDIV(x) +#endif + +#if defined(EMBREE_GEOMETRY_USER) + #define IF_ENABLED_USER(x) x +#else + #define IF_ENABLED_USER(x) +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + #define IF_ENABLED_INSTANCE(x) x +#else + #define IF_ENABLED_INSTANCE(x) +#endif + +#if defined(EMBREE_GEOMETRY_GRID) + #define IF_ENABLED_GRIDS(x) x +#else + #define IF_ENABLED_GRIDS(x) +#endif +""" + ) + + +with open("CMakeLists.txt", "r") as cmake_file: + cmake_content = cmake_file.read() + major_version = int(re.compile(r"EMBREE_VERSION_MAJOR\s(\d+)").findall(cmake_content)[0]) + minor_version = int(re.compile(r"EMBREE_VERSION_MINOR\s(\d+)").findall(cmake_content)[0]) + patch_version = int(re.compile(r"EMBREE_VERSION_PATCH\s(\d+)").findall(cmake_content)[0]) + +with open(os.path.join(dest_dir, "include/embree3/rtcore_config.h"), "w") as config_file: + config_file.write( + f""" +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#define RTC_VERSION_MAJOR {major_version} +#define RTC_VERSION_MINOR {minor_version} +#define RTC_VERSION_PATCH {patch_version} +#define RTC_VERSION {major_version}{minor_version:02d}{patch_version:02d} +#define RTC_VERSION_STRING "{major_version}.{minor_version}.{patch_version}" + +#define RTC_MAX_INSTANCE_LEVEL_COUNT 1 + +#define EMBREE_MIN_WIDTH 0 +#define RTC_MIN_WIDTH EMBREE_MIN_WIDTH + +#define EMBREE_STATIC_LIB +/* #undef EMBREE_API_NAMESPACE */ + +#if defined(EMBREE_API_NAMESPACE) +# define RTC_NAMESPACE +# define RTC_NAMESPACE_BEGIN namespace {{ +# define RTC_NAMESPACE_END }} +# define RTC_NAMESPACE_USE using namespace ; +# define RTC_API_EXTERN_C +# undef EMBREE_API_NAMESPACE +#else +# define RTC_NAMESPACE_BEGIN +# define RTC_NAMESPACE_END +# define RTC_NAMESPACE_USE +# if defined(__cplusplus) +# define RTC_API_EXTERN_C extern "C" +# else +# 
define RTC_API_EXTERN_C +# endif +#endif + +#if defined(ISPC) +# define RTC_API_IMPORT extern "C" unmasked +# define RTC_API_EXPORT extern "C" unmasked +#elif defined(EMBREE_STATIC_LIB) +# define RTC_API_IMPORT RTC_API_EXTERN_C +# define RTC_API_EXPORT RTC_API_EXTERN_C +#elif defined(_WIN32) +# define RTC_API_IMPORT RTC_API_EXTERN_C __declspec(dllimport) +# define RTC_API_EXPORT RTC_API_EXTERN_C __declspec(dllexport) +#else +# define RTC_API_IMPORT RTC_API_EXTERN_C +# define RTC_API_EXPORT RTC_API_EXTERN_C __attribute__ ((visibility ("default"))) +#endif + +#if defined(RTC_EXPORT_API) +# define RTC_API RTC_API_EXPORT +#else +# define RTC_API RTC_API_IMPORT +#endif +""" + ) + +os.chdir("..") +shutil.rmtree("embree-tmp") diff --git a/modules/raycast/lightmap_raycaster.cpp b/modules/raycast/lightmap_raycaster.cpp new file mode 100644 index 000000000000..2a0a216f7694 --- /dev/null +++ b/modules/raycast/lightmap_raycaster.cpp @@ -0,0 +1,198 @@ +/*************************************************************************/ +/* lightmap_raycaster.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "lightmap_raycaster.h" + +// From Embree. 
+#include +#include + +using namespace embree; + +LightmapRaycaster *LightmapRaycasterEmbree::create_embree_raycaster() { + return memnew(LightmapRaycasterEmbree); +} + +void LightmapRaycasterEmbree::make_default_raycaster() { + create_function = create_embree_raycaster; +} + +void LightmapRaycasterEmbree::filter_function(const struct RTCFilterFunctionNArguments *p_args) { + + RTCHit *hit = (RTCHit *)p_args->hit; + + unsigned int geomID = hit->geomID; + float u = hit->u; + float v = hit->v; + + LightmapRaycasterEmbree *scene = (LightmapRaycasterEmbree *)p_args->geometryUserPtr; + RTCGeometry geom = rtcGetGeometry(scene->embree_scene, geomID); + + rtcInterpolate0(geom, hit->primID, hit->u, hit->v, RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE, 0, &hit->u, 2); + + if (scene->alpha_textures.has(geomID)) { + const AlphaTextureData &alpha_texture = scene->alpha_textures[geomID]; + + if (alpha_texture.sample(hit->u, hit->v) < 128) { + p_args->valid[0] = 0; + return; + } + } + + rtcInterpolate0(geom, hit->primID, u, v, RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE, 1, &hit->Ng_x, 3); +} + +bool LightmapRaycasterEmbree::intersect(Ray &r_ray) { + RTCIntersectContext context; + + rtcInitIntersectContext(&context); + + rtcIntersect1(embree_scene, &context, (RTCRayHit *)&r_ray); + return r_ray.geomID != RTC_INVALID_GEOMETRY_ID; +} + +void LightmapRaycasterEmbree::intersect(Vector &r_rays) { + Ray *rays = r_rays.ptrw(); + for (int i = 0; i < r_rays.size(); ++i) { + intersect(rays[i]); + } +} + +void LightmapRaycasterEmbree::set_mesh_alpha_texture(Ref p_alpha_texture, unsigned int p_id) { + if (p_alpha_texture.is_valid() && p_alpha_texture->get_size() != Vector2i()) { + AlphaTextureData tex; + tex.size = p_alpha_texture->get_size(); + tex.data.resize(tex.size.x * tex.size.y); + + { + PoolVector::Read r = p_alpha_texture->get_data().read(); + uint8_t *ptrw = tex.data.ptrw(); + for (int i = 0; i < tex.size.x * tex.size.y; ++i) { + ptrw[i] = r[i]; + } + } + + alpha_textures.insert(p_id, tex); + } +} + +float blerp(float c00, float c10, float c01, float c11, float tx, float ty) { + return Math::lerp(Math::lerp(c00, c10, tx), Math::lerp(c01, c11, tx), ty); +} + +uint8_t LightmapRaycasterEmbree::AlphaTextureData::sample(float u, float v) const { + float x = u * size.x; + float y = v * size.y; + int xi = (int)x; + int yi = (int)y; + + uint8_t texels[4]; + + for (int i = 0; i < 4; ++i) { + int sample_x = CLAMP(xi + i % 2, 0, size.x - 1); + int sample_y = CLAMP(yi + i / 2, 0, size.y - 1); + texels[i] = data[sample_y * size.x + sample_x]; + } + + return Math::round(blerp(texels[0], texels[1], texels[2], texels[3], x - xi, y - yi)); +} + +void LightmapRaycasterEmbree::add_mesh(const Vector &p_vertices, const Vector &p_normals, const Vector &p_uv2s, unsigned int p_id) { + + RTCGeometry embree_mesh = rtcNewGeometry(embree_device, RTC_GEOMETRY_TYPE_TRIANGLE); + + rtcSetGeometryVertexAttributeCount(embree_mesh, 2); + + int vertex_count = p_vertices.size(); + + ERR_FAIL_COND(vertex_count % 3 != 0); + ERR_FAIL_COND(vertex_count != p_uv2s.size()); + + Vec3fa *embree_vertices = (Vec3fa *)rtcSetNewGeometryBuffer(embree_mesh, RTC_BUFFER_TYPE_VERTEX, 0, RTC_FORMAT_FLOAT3, sizeof(Vec3fa), vertex_count); + Vec2fa *embree_light_uvs = (Vec2fa *)rtcSetNewGeometryBuffer(embree_mesh, RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE, 0, RTC_FORMAT_FLOAT2, sizeof(Vec2fa), vertex_count); + uint32_t *embree_triangles = (uint32_t *)rtcSetNewGeometryBuffer(embree_mesh, RTC_BUFFER_TYPE_INDEX, 0, RTC_FORMAT_UINT3, sizeof(uint32_t) * 3, vertex_count / 3); + + Vec3fa *embree_normals 
= nullptr; + if (!p_normals.empty()) { + embree_normals = (Vec3fa *)rtcSetNewGeometryBuffer(embree_mesh, RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE, 1, RTC_FORMAT_FLOAT3, sizeof(Vec3fa), vertex_count); + } + + for (uint32_t i = 0; i < vertex_count; i++) { + embree_vertices[i] = Vec3fa(p_vertices[i].x, p_vertices[i].y, p_vertices[i].z); + embree_light_uvs[i] = Vec2fa(p_uv2s[i].x, p_uv2s[i].y); + if (embree_normals != nullptr) { + embree_normals[i] = Vec3fa(p_normals[i].x, p_normals[i].y, p_normals[i].z); + } + embree_triangles[i] = i; + } + + rtcCommitGeometry(embree_mesh); + rtcSetGeometryIntersectFilterFunction(embree_mesh, filter_function); + rtcSetGeometryUserData(embree_mesh, this); + rtcAttachGeometryByID(embree_scene, embree_mesh, p_id); + rtcReleaseGeometry(embree_mesh); +} + +void LightmapRaycasterEmbree::commit() { + rtcCommitScene(embree_scene); +} + +void LightmapRaycasterEmbree::set_mesh_filter(const Set &p_mesh_ids) { + for (Set::Element *E = p_mesh_ids.front(); E; E = E->next()) { + rtcDisableGeometry(rtcGetGeometry(embree_scene, E->get())); + } + rtcCommitScene(embree_scene); + filter_meshes = p_mesh_ids; +} + +void LightmapRaycasterEmbree::clear_mesh_filter() { + for (Set::Element *E = filter_meshes.front(); E; E = E->next()) { + rtcEnableGeometry(rtcGetGeometry(embree_scene, E->get())); + } + rtcCommitScene(embree_scene); + filter_meshes.clear(); +} + +void embree_error_handler(void *p_user_data, RTCError p_code, const char *p_str) { + print_error("Embree error: " + String(p_str)); +} + +LightmapRaycasterEmbree::LightmapRaycasterEmbree() { + embree_device = rtcNewDevice(nullptr); + rtcSetDeviceErrorFunction(embree_device, &embree_error_handler, nullptr); + embree_scene = rtcNewScene(embree_device); +} + +LightmapRaycasterEmbree::~LightmapRaycasterEmbree() { + if (embree_scene != nullptr) + rtcReleaseScene(embree_scene); + if (embree_device != nullptr) + rtcReleaseDevice(embree_device); +} diff --git a/modules/raycast/lightmap_raycaster.h b/modules/raycast/lightmap_raycaster.h new file mode 100644 index 000000000000..b116cf6466bc --- /dev/null +++ b/modules/raycast/lightmap_raycaster.h @@ -0,0 +1,73 @@ +/*************************************************************************/ +/* lightmap_raycaster.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. 
*/ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "core/object.h" +#include "scene/3d/lightmapper.h" +#include "scene/resources/mesh.h" + +#include <embree3/rtcore.h> + +class LightmapRaycasterEmbree : public LightmapRaycaster { + GDCLASS(LightmapRaycasterEmbree, LightmapRaycaster); + +private: + struct AlphaTextureData { + Vector<uint8_t> data; + Vector2i size; + + uint8_t sample(float u, float v) const; + }; + + RTCDevice embree_device; + RTCScene embree_scene; + + static void filter_function(const struct RTCFilterFunctionNArguments *p_args); + + Map<unsigned int, AlphaTextureData> alpha_textures; + Set<int> filter_meshes; + +public: + virtual bool intersect(Ray &p_ray); + + virtual void intersect(Vector<Ray> &r_rays); + + virtual void add_mesh(const Vector<Vector3> &p_vertices, const Vector<Vector3> &p_normals, const Vector<Vector2> &p_uv2s, unsigned int p_id); + virtual void set_mesh_alpha_texture(Ref<Image> p_alpha_texture, unsigned int p_id); + virtual void commit(); + + virtual void set_mesh_filter(const Set<int> &p_mesh_ids); + virtual void clear_mesh_filter(); + + static LightmapRaycaster *create_embree_raycaster(); + static void make_default_raycaster(); + + LightmapRaycasterEmbree(); + ~LightmapRaycasterEmbree(); +};
diff --git a/modules/raycast/register_types.cpp b/modules/raycast/register_types.cpp new file mode 100644 index 000000000000..b6799ce62969 --- /dev/null +++ b/modules/raycast/register_types.cpp @@ -0,0 +1,40 @@ +/*************************************************************************/ +/* register_types.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ +/*************************************************************************/ + +#include "register_types.h" + +#include "lightmap_raycaster.h" + +void register_raycast_types() { + LightmapRaycasterEmbree::make_default_raycaster(); +} + +void unregister_raycast_types() { +} diff --git a/modules/raycast/register_types.h b/modules/raycast/register_types.h new file mode 100644 index 000000000000..789604a4917d --- /dev/null +++ b/modules/raycast/register_types.h @@ -0,0 +1,32 @@ +/*************************************************************************/ +/* register_types.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +void register_raycast_types(); +void unregister_raycast_types(); diff --git a/modules/xatlas_unwrap/register_types.cpp b/modules/xatlas_unwrap/register_types.cpp index 95339277e2ab..72026a2a4118 100644 --- a/modules/xatlas_unwrap/register_types.cpp +++ b/modules/xatlas_unwrap/register_types.cpp @@ -41,7 +41,7 @@ extern bool (*array_mesh_lightmap_unwrap_callback)(float p_texel_size, const flo bool xatlas_mesh_lightmap_unwrap_callback(float p_texel_size, const float *p_vertices, const float *p_normals, int p_vertex_count, const int *p_indices, const int *p_face_materials, int p_index_count, float **r_uv, int **r_vertex, int *r_vertex_count, int **r_index, int *r_index_count, int *r_size_hint_x, int *r_size_hint_y) { - //set up input mesh + // set up input mesh xatlas::MeshDecl input_mesh; input_mesh.indexData = p_indices; input_mesh.indexCount = p_index_count; @@ -56,18 +56,19 @@ bool xatlas_mesh_lightmap_unwrap_callback(float p_texel_size, const float *p_ver input_mesh.vertexUvStride = 0; xatlas::ChartOptions chart_options; - xatlas::PackOptions pack_options; + chart_options.fixWinding = true; - pack_options.maxChartSize = 4096; + xatlas::PackOptions pack_options; + pack_options.padding = 1; + pack_options.maxChartSize = 4094; // Lightmap atlassing needs 2 for padding between meshes, so 4096-2 pack_options.blockAlign = true; pack_options.texelsPerUnit = 1.0 / p_texel_size; xatlas::Atlas *atlas = xatlas::Create(); - printf("Adding mesh..\n"); + xatlas::AddMeshError err = xatlas::AddMesh(atlas, input_mesh, 1); ERR_FAIL_COND_V_MSG(err != xatlas::AddMeshError::Success, false, xatlas::StringForEnum(err)); - printf("Generate..\n"); xatlas::Generate(atlas, chart_options, pack_options); *r_size_hint_x = atlas->width; @@ -96,7 +97,6 @@ bool xatlas_mesh_lightmap_unwrap_callback(float p_texel_size, const float *p_ver max_y = MAX(max_y, output.vertexArray[i].uv[1]); } - printf("Final texture size: %f,%f - max %f,%f\n", w, h, max_x, max_y); *r_vertex_count = output.vertexCount; for (uint32_t i = 0; i < output.indexCount; i++) { @@ -106,7 +106,7 @@ bool xatlas_mesh_lightmap_unwrap_callback(float p_texel_size, const float *p_ver *r_index_count = output.indexCount; xatlas::Destroy(atlas); - printf("Done\n"); + return true; } diff --git a/scene/3d/baked_lightmap.cpp b/scene/3d/baked_lightmap.cpp index 5fbec7b3b254..425857dde42e 100644 --- a/scene/3d/baked_lightmap.cpp +++ b/scene/3d/baked_lightmap.cpp @@ -31,6 +31,7 @@ #include "baked_lightmap.h" #include "core/io/config_file.h" #include "core/io/resource_saver.h" +#include "core/math/math_defs.h" #include "core/os/dir_access.h" #include "core/os/os.h" #include "voxel_light_baker.h" @@ -86,12 +87,21 @@ float BakedLightmapData::get_energy() const { return energy; } -void BakedLightmapData::add_user(const NodePath &p_path, const Ref &p_lightmap, int p_instance) { +void BakedLightmapData::add_user(const NodePath &p_path, const Ref &p_lightmap, int p_lightmap_slice, const Rect2 &p_lightmap_uv_rect, int p_instance) { ERR_FAIL_COND_MSG(p_lightmap.is_null(), "It's not a reference to a valid Texture object."); + ERR_FAIL_COND(p_lightmap_slice == -1 && !Object::cast_to(p_lightmap.ptr())); + ERR_FAIL_COND(p_lightmap_slice != -1 && !Object::cast_to(p_lightmap.ptr())); + User user; user.path = p_path; - user.lightmap = p_lightmap; + if (p_lightmap_slice == -1) { + user.lightmap.single = p_lightmap; + } else { + user.lightmap.layered = p_lightmap; + } + user.lightmap_slice = 
p_lightmap_slice; + user.lightmap_uv_rect = p_lightmap_uv_rect; user.instance_index = p_instance; users.push_back(user); } @@ -105,10 +115,26 @@ NodePath BakedLightmapData::get_user_path(int p_user) const { ERR_FAIL_INDEX_V(p_user, users.size(), NodePath()); return users[p_user].path; } -Ref BakedLightmapData::get_user_lightmap(int p_user) const { +Ref BakedLightmapData::get_user_lightmap(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), Ref()); + if (users[p_user].lightmap_slice == -1) { + return users[p_user].lightmap.single; + } else { + return users[p_user].lightmap.layered; + } +} + +int BakedLightmapData::get_user_lightmap_slice(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), -1); + return users[p_user].lightmap_slice; +} - ERR_FAIL_INDEX_V(p_user, users.size(), Ref()); - return users[p_user].lightmap; +Rect2 BakedLightmapData::get_user_lightmap_uv_rect(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), Rect2(0, 0, 1, 1)); + return users[p_user].lightmap_uv_rect; } int BakedLightmapData::get_user_instance(int p_user) const { @@ -123,10 +149,39 @@ void BakedLightmapData::clear_users() { void BakedLightmapData::_set_user_data(const Array &p_data) { - ERR_FAIL_COND((p_data.size() % 3) != 0); + // Detect old lightmapper format + if (p_data.size() % 3 == 0) { + bool is_old_format = true; + for (int i = 0; i < p_data.size(); i += 3) { + is_old_format = is_old_format && p_data[i + 0].get_type() == Variant::NODE_PATH; + is_old_format = is_old_format && p_data[i + 1].is_ref(); + is_old_format = is_old_format && p_data[i + 2].get_type() == Variant::INT; + if (!is_old_format) { + break; + } + } + if (is_old_format) { +#ifdef DEBUG_ENABLED + WARN_PRINTS("Geometry at path " + String(p_data[0]) + " is using old lightmapper data. Please re-bake."); +#endif + Array adapted_data; + adapted_data.resize((p_data.size() / 3) * 5); + for (int i = 0; i < p_data.size() / 3; i++) { + adapted_data[i * 5 + 0] = p_data[i * 3 + 0]; + adapted_data[i * 5 + 1] = p_data[i * 3 + 1]; + adapted_data[i * 5 + 2] = -1; + adapted_data[i * 5 + 3] = Rect2(0, 0, 1, 1); + adapted_data[i * 5 + 4] = p_data[i * 3 + 2]; + } + _set_user_data(adapted_data); + return; + } + } + + ERR_FAIL_COND((p_data.size() % 5) != 0); - for (int i = 0; i < p_data.size(); i += 3) { - add_user(p_data[i], p_data[i + 1], p_data[i + 2]); + for (int i = 0; i < p_data.size(); i += 5) { + add_user(p_data[i], p_data[i + 1], p_data[i + 2], p_data[i + 3], p_data[i + 4]); } } @@ -135,7 +190,9 @@ Array BakedLightmapData::_get_user_data() const { Array ret; for (int i = 0; i < users.size(); i++) { ret.push_back(users[i].path); - ret.push_back(users[i].lightmap); + ret.push_back(users[i].lightmap_slice == -1 ? 
Ref(users[i].lightmap.single) : Ref(users[i].lightmap.layered)); + ret.push_back(users[i].lightmap_slice); + ret.push_back(users[i].lightmap_uv_rect); ret.push_back(users[i].instance_index); } return ret; @@ -164,7 +221,7 @@ void BakedLightmapData::_bind_methods() { ClassDB::bind_method(D_METHOD("set_energy", "energy"), &BakedLightmapData::set_energy); ClassDB::bind_method(D_METHOD("get_energy"), &BakedLightmapData::get_energy); - ClassDB::bind_method(D_METHOD("add_user", "path", "lightmap", "instance"), &BakedLightmapData::add_user); + ClassDB::bind_method(D_METHOD("add_user", "path", "lightmap", "lightmap_slice", "lightmap_uv_rect", "instance"), &BakedLightmapData::add_user); ClassDB::bind_method(D_METHOD("get_user_count"), &BakedLightmapData::get_user_count); ClassDB::bind_method(D_METHOD("get_user_path", "user_idx"), &BakedLightmapData::get_user_path); ClassDB::bind_method(D_METHOD("get_user_lightmap", "user_idx"), &BakedLightmapData::get_user_lightmap); @@ -192,79 +249,117 @@ BakedLightmapData::~BakedLightmapData() { /////////////////////////// -BakedLightmap::BakeBeginFunc BakedLightmap::bake_begin_function = NULL; -BakedLightmap::BakeStepFunc BakedLightmap::bake_step_function = NULL; -BakedLightmap::BakeEndFunc BakedLightmap::bake_end_function = NULL; - -void BakedLightmap::set_bake_cell_size(float p_cell_size) { - bake_cell_size = p_cell_size; -} - -float BakedLightmap::get_bake_cell_size() const { - return bake_cell_size; -} +Lightmapper::BakeStepFunc BakedLightmap::bake_step_function; +Lightmapper::BakeStepFunc BakedLightmap::bake_substep_function; + +Size2i BakedLightmap::_compute_lightmap_size(const MeshesFound &p_mesh) { + double area = 0; + double uv_area = 0; + for (int i = 0; i < p_mesh.mesh->get_surface_count(); i++) { + Array arrays = p_mesh.mesh->surface_get_arrays(i); + PoolVector vertices = arrays[Mesh::ARRAY_VERTEX]; + PoolVector uv2 = arrays[Mesh::ARRAY_TEX_UV2]; + PoolVector indices = arrays[Mesh::ARRAY_INDEX]; + + ERR_FAIL_COND_V(vertices.size() == 0, Vector2()); + ERR_FAIL_COND_V(uv2.size() == 0, Vector2()); + + int vc = vertices.size(); + PoolVector::Read vr = vertices.read(); + PoolVector::Read u2r = uv2.read(); + PoolVector::Read ir; + int ic = 0; + + if (indices.size()) { + ic = indices.size(); + ir = indices.read(); + } -void BakedLightmap::set_capture_cell_size(float p_cell_size) { - capture_cell_size = p_cell_size; -} + int faces = ic ? ic / 3 : vc / 3; + for (int j = 0; j < faces; j++) { + Vector3 vertex[3]; + Vector2 uv[3]; -float BakedLightmap::get_capture_cell_size() const { - return capture_cell_size; -} - -void BakedLightmap::set_extents(const Vector3 &p_extents) { - extents = p_extents; - update_gizmo(); - _change_notify("bake_extents"); -} + for (int k = 0; k < 3; k++) { + int idx = ic ? 
ir[j * 3 + k] : j * 3 + k; + vertex[k] = p_mesh.xform.xform(vr[idx]); + uv[k] = u2r[idx]; + } -Vector3 BakedLightmap::get_extents() const { - return extents; -} + Vector3 p1 = vertex[0]; + Vector3 p2 = vertex[1]; + Vector3 p3 = vertex[2]; + double a = p1.distance_to(p2); + double b = p2.distance_to(p3); + double c = p3.distance_to(p1); + double halfPerimeter = (a + b + c) / 2.0; + area += sqrt(halfPerimeter * (halfPerimeter - a) * (halfPerimeter - b) * (halfPerimeter - c)); + + Vector2 uv_p1 = uv[0]; + Vector2 uv_p2 = uv[1]; + Vector2 uv_p3 = uv[2]; + double uv_a = uv_p1.distance_to(uv_p2); + double uv_b = uv_p2.distance_to(uv_p3); + double uv_c = uv_p3.distance_to(uv_p1); + double uv_halfPerimeter = (uv_a + uv_b + uv_c) / 2.0; + uv_area += sqrt( + uv_halfPerimeter * (uv_halfPerimeter - uv_a) * (uv_halfPerimeter - uv_b) * (uv_halfPerimeter - uv_c)); + } + } -void BakedLightmap::set_bake_default_texels_per_unit(const float &p_bake_texels_per_unit) { - bake_default_texels_per_unit = p_bake_texels_per_unit; - update_gizmo(); -} + if (uv_area < 0.0001f) { + uv_area = 1.0; + } -float BakedLightmap::get_bake_default_texels_per_unit() const { - return bake_default_texels_per_unit; + int pixels = Math::round(ceil((1.0 / sqrt(uv_area)) * sqrt(area * default_texels_per_unit))); + int size = CLAMP(pixels, 2, 4096); + return Vector2i(size, size); } -void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, List &plot_meshes, List &plot_lights) { - +void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, Vector &meshes, Vector &lights) { MeshInstance *mi = Object::cast_to(p_at_node); if (mi && mi->get_flag(GeometryInstance::FLAG_USE_BAKED_LIGHT) && mi->is_visible_in_tree()) { Ref mesh = mi->get_mesh(); if (mesh.is_valid()) { - - bool all_have_uv2 = true; + bool all_have_uv2_and_normal = true; + bool surfaces_found = false; for (int i = 0; i < mesh->get_surface_count(); i++) { + if (mesh->surface_get_primitive_type(i) != Mesh::PRIMITIVE_TRIANGLES) { + continue; + } if (!(mesh->surface_get_format(i) & Mesh::ARRAY_FORMAT_TEX_UV2)) { - all_have_uv2 = false; + all_have_uv2_and_normal = false; break; } + if (!(mesh->surface_get_format(i) & Mesh::ARRAY_FORMAT_NORMAL)) { + all_have_uv2_and_normal = false; + break; + } + surfaces_found = true; } - if (all_have_uv2) { - //READY TO BAKE! size hint could be computed if not found, actually.. 
- - AABB aabb = mesh->get_aabb(); - - Transform xf = get_global_transform().affine_inverse() * mi->get_global_transform(); - - if (AABB(-extents, extents * 2).intersects(xf.xform(aabb))) { - PlotMesh pm; - pm.local_xform = xf; - pm.mesh = mesh; - pm.path = get_path_to(mi); - pm.instance_idx = -1; - for (int i = 0; i < mesh->get_surface_count(); i++) { - pm.instance_materials.push_back(mi->get_surface_material(i)); + if (surfaces_found && all_have_uv2_and_normal) { + MeshesFound mf; + mf.cast_shadows = mi->get_cast_shadows_setting() != GeometryInstance::SHADOW_CASTING_SETTING_OFF; + mf.generate_lightmap = mi->get_generate_lightmap(); + mf.xform = get_global_transform().affine_inverse() * mi->get_global_transform(); + mf.node_path = get_path_to(mi); + mf.subindex = -1; + mf.mesh = mesh; + + static const int lightmap_scale[4] = { 1, 2, 4, 8 }; //GeometryInstance3D::LIGHTMAP_SCALE_MAX = { 1, 2, 4, 8 }; + mf.lightmap_scale = lightmap_scale[mi->get_lightmap_scale()]; + + Ref all_override = mi->get_material_override(); + for (int i = 0; i < mesh->get_surface_count(); i++) { + if (all_override.is_valid()) { + mf.overrides.push_back(all_override); + } else { + mf.overrides.push_back(mi->get_surface_material(i)); } - pm.override_material = mi->get_material_override(); - plot_meshes.push_back(pm); } + + meshes.push_back(mf); } } } @@ -272,19 +367,44 @@ void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, List &plo Spatial *s = Object::cast_to(p_at_node); if (!mi && s) { - Array meshes = p_at_node->call("get_bake_meshes"); - if (meshes.size() && (meshes.size() & 1) == 0) { + Array bmeshes = p_at_node->call("get_bake_meshes"); + if (bmeshes.size() && (bmeshes.size() & 1) == 0) { Transform xf = get_global_transform().affine_inverse() * s->get_global_transform(); - for (int i = 0; i < meshes.size(); i += 2) { - PlotMesh pm; - Transform mesh_xf = meshes[i + 1]; - pm.local_xform = xf * mesh_xf; - pm.mesh = meshes[i]; - pm.instance_idx = i / 2; - if (!pm.mesh.is_valid()) + Ref all_override; + + GeometryInstance *gi = Object::cast_to(p_at_node); + if (gi) { + all_override = mi->get_material_override(); + } + + for (int i = 0; i < bmeshes.size(); i += 2) { + Ref mesh = bmeshes[i]; + if (!mesh.is_valid()) { continue; - pm.path = get_path_to(s); - plot_meshes.push_back(pm); + } + + MeshesFound mf; + + Transform mesh_xf = bmeshes[i + 1]; + mf.xform = xf * mesh_xf; + mf.node_path = get_path_to(s); + mf.subindex = i / 2; + mf.lightmap_scale = 1; + mf.mesh = mesh; + + if (gi) { + mf.cast_shadows = mi->get_cast_shadows_setting() != GeometryInstance::SHADOW_CASTING_SETTING_OFF; + mf.generate_lightmap = mi->get_generate_lightmap(); + } else { + mf.cast_shadows = true; + mf.generate_lightmap = true; + } + + for (int j = 0; j < mesh->get_surface_count(); j++) { + mf.overrides.push_back(all_override); + } + + meshes.push_back(mf); } } } @@ -292,290 +412,590 @@ void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, List &plo Light *light = Object::cast_to(p_at_node); if (light && light->get_bake_mode() != Light::BAKE_DISABLED) { - PlotLight pl; - Transform xf = get_global_transform().affine_inverse() * light->get_global_transform(); - - pl.local_xform = xf; - pl.light = light; - plot_lights.push_back(pl); + LightsFound lf; + lf.xform = get_global_transform().affine_inverse() * light->get_global_transform(); + lf.light = light; + lights.push_back(lf); } - for (int i = 0; i < p_at_node->get_child_count(); i++) { + for (int i = 0; i < p_at_node->get_child_count(); i++) { Node *child = p_at_node->get_child(i); - 
if (!child->get_owner()) + if (!child->get_owner()) { continue; //maybe a helper + } - _find_meshes_and_lights(child, plot_meshes, plot_lights); + _find_meshes_and_lights(child, meshes, lights); } } -void BakedLightmap::set_hdr(bool p_enable) { - hdr = p_enable; -} +void BakedLightmap::_get_material_images(const MeshesFound &p_found_mesh, Lightmapper::MeshData &r_mesh_data, Vector > &r_albedo_textures, Vector > &r_emission_textures) { -bool BakedLightmap::is_hdr() const { - return hdr; -} + for (int i = 0; i < p_found_mesh.mesh->get_surface_count(); ++i) { + Ref mat = p_found_mesh.overrides[i]; + + if (mat.is_null()) { + mat = p_found_mesh.mesh->surface_get_material(i); + } + + Ref albedo_texture; + Color albedo_add = Color(0, 0, 0, 0); + Color albedo_mul = Color(1, 1, 1, 1); + + Ref emission_texture; + Color emission_add = Color(0, 0, 0, 0); + Color emission_mul = Color(1, 1, 1, 1); + + if (mat.is_valid()) { + + albedo_texture = mat->get_texture(SpatialMaterial::TEXTURE_ALBEDO); + + if (albedo_texture.is_valid()) { + albedo_mul = mat->get_albedo(); + } else { + albedo_add = mat->get_albedo(); + } + + emission_texture = mat->get_texture(SpatialMaterial::TEXTURE_EMISSION); + Color emission_color = mat->get_emission(); + float emission_energy = mat->get_emission_energy(); + + if (mat->get_emission_operator() == SpatialMaterial::EMISSION_OP_ADD) { + emission_mul = Color(1, 1, 1) * emission_energy; + emission_add = emission_color * emission_energy; + } else { + emission_mul = emission_color * emission_energy; + emission_add = Color(0, 0, 0); + } + } -bool BakedLightmap::_bake_time(void *ud, float p_secs, float p_progress) { + Lightmapper::MeshData::TextureDef albedo; + albedo.mul = albedo_mul; + albedo.add = albedo_add; - uint64_t time = OS::get_singleton()->get_ticks_usec(); - BakeTimeData *btd = (BakeTimeData *)ud; + if (albedo_texture.is_valid()) { + albedo.tex_rid = albedo_texture->get_rid(); + r_albedo_textures.push_back(albedo_texture); + } + + r_mesh_data.albedo.push_back(albedo); - if (time - btd->last_step > 1000000) { + Lightmapper::MeshData::TextureDef emission; + emission.mul = emission_mul; + emission.add = emission_add; - int mins_left = p_secs / 60; - int secs_left = Math::fmod(p_secs, 60.0f); - int percent = p_progress * 100; - bool abort = bake_step_function(btd->pass + percent, btd->text + " " + vformat(RTR("%d%%"), percent) + " " + vformat(RTR("(Time Left: %d:%02d s)"), mins_left, secs_left)); - btd->last_step = time; - if (abort) - return true; + if (emission_texture.is_valid()) { + emission.tex_rid = emission_texture->get_rid(); + r_emission_textures.push_back(emission_texture); + } + r_mesh_data.emission.push_back(emission); } +} - return false; +bool BakedLightmap::_lightmap_bake_step_function(float p_completion, const String &p_text, void *ud, bool p_refresh) { + BakeStepUD *bsud = (BakeStepUD *)ud; + bool ret = false; + if (bsud->func) { + ret = bsud->func(bsud->from_percent + p_completion * (bsud->to_percent - bsud->from_percent), p_text, bsud->ud, p_refresh); + } + return ret; } -BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, bool p_create_visual_debug) { +BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, String p_data_save_path) { - String save_path; + bool no_save_path = false; + if (p_data_save_path == "" && (get_light_data().is_null() || !get_light_data()->get_path().is_resource_file())) { + no_save_path = true; + } - if (image_path.begins_with("res://")) { - save_path = image_path; - } else { - if (get_filename() != "") { - 
save_path = get_filename().get_base_dir(); - } else if (get_owner() && get_owner()->get_filename() != "") { - save_path = get_owner()->get_filename().get_base_dir(); + if (p_data_save_path == "") { + if (get_light_data().is_null()) { + no_save_path = true; + } else { + p_data_save_path = get_light_data()->get_path(); + if (!p_data_save_path.is_resource_file()) { + no_save_path = true; + } } + } - if (save_path == "") { + if (no_save_path) { + + if (image_path == "") { return BAKE_ERROR_NO_SAVE_PATH; + } else { + p_data_save_path = image_path; } - if (image_path != "") { - save_path.plus_file(image_path); - } + + WARN_PRINT("Using the deprecated property \"image_path\" as a save path, consider providing a better save path via the \"data_save_path\" parameter"); + p_data_save_path = image_path.plus_file("BakedLightmap.lmbake"); } + { //check for valid save path - DirAccessRef d = DirAccess::open(save_path); + DirAccessRef d = DirAccess::open(p_data_save_path.get_base_dir()); if (!d) { - ERR_PRINTS("Invalid Save Path '" + save_path + "'."); - return BAKE_ERROR_NO_SAVE_PATH; + ERR_FAIL_V_MSG(BAKE_ERROR_NO_SAVE_PATH, "Invalid save path '" + p_data_save_path + "'."); } } - Ref new_light_data; - new_light_data.instance(); + if (bake_step_function) { + bool cancelled = bake_step_function(0.0, TTR("Finding meshes and lights"), nullptr, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } - VoxelLightBaker baker; + Ref lightmapper = Lightmapper::create(); + ERR_FAIL_COND_V(lightmapper.is_null(), BAKE_ERROR_NO_LIGHTMAPPER); - int bake_subdiv; - int capture_subdiv; - AABB bake_bounds; - { - bake_bounds = AABB(-extents, extents * 2.0); - int subdiv = nearest_power_of_2_templated(int(bake_bounds.get_longest_axis_size() / bake_cell_size)); - bake_bounds.size[bake_bounds.get_longest_axis_index()] = subdiv * bake_cell_size; - bake_subdiv = nearest_shift(subdiv) + 1; + Vector lights_found; + Vector meshes_found; - capture_subdiv = bake_subdiv; - float css = bake_cell_size; - while (css < capture_cell_size && capture_subdiv > 2) { - capture_subdiv--; - css *= 2.0; - } + _find_meshes_and_lights(p_from_node ? p_from_node : get_parent(), meshes_found, lights_found); + + if (meshes_found.size() == 0) { + return BAKE_ERROR_NO_MESHES; } - baker.begin_bake(bake_subdiv, bake_bounds); + for (int m_i = 0; m_i < meshes_found.size(); m_i++) { + if (bake_step_function) { + float p = (float)(m_i) / meshes_found.size(); + bool cancelled = bake_step_function(p * 0.05, vformat(TTR("Preparing geometry (%d/%d)"), m_i + 1, meshes_found.size()), nullptr, false); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } - List mesh_list; - List light_list; + MeshesFound &mf = meshes_found.write[m_i]; - _find_meshes_and_lights(p_from_node ? 
p_from_node : get_parent(), mesh_list, light_list); + Size2i lightmap_size = mf.mesh->get_lightmap_size_hint(); - if (bake_begin_function) { - bake_begin_function(mesh_list.size() + light_list.size() + 1 + mesh_list.size() * 100); - } + if (lightmap_size == Vector2i(0, 0)) { + lightmap_size = _compute_lightmap_size(mf); + } + lightmap_size *= mf.lightmap_scale; - int step = 0; + Lightmapper::MeshData md; - int pmc = 0; + { + Dictionary d; + d["path"] = mf.node_path; + if (mf.subindex >= 0) { + d["subindex"] = mf.subindex; + } + d["cast_shadows"] = mf.cast_shadows; + d["generate_lightmap"] = mf.generate_lightmap; + d["node_name"] = mf.node_path.get_name(mf.node_path.get_name_count() - 1); + md.userdata = d; + } - for (List::Element *E = mesh_list.front(); E; E = E->next()) { + Basis normal_xform = mf.xform.basis.inverse().transposed(); - if (bake_step_function) { - bake_step_function(step++, RTR("Plotting Meshes: ") + " (" + itos(pmc + 1) + "/" + itos(mesh_list.size()) + ")"); + for (int i = 0; i < mf.mesh->get_surface_count(); i++) { + if (mf.mesh->surface_get_primitive_type(i) != Mesh::PRIMITIVE_TRIANGLES) { + continue; + } + Array a = mf.mesh->surface_get_arrays(i); + + Vector vertices = a[Mesh::ARRAY_VERTEX]; + const Vector3 *vr = vertices.ptr(); + Vector uv2 = a[Mesh::ARRAY_TEX_UV2]; + const Vector2 *uv2r = nullptr; + Vector uv = a[Mesh::ARRAY_TEX_UV]; + const Vector2 *uvr = nullptr; + Vector normals = a[Mesh::ARRAY_NORMAL]; + const Vector3 *nr = nullptr; + Vector index = a[Mesh::ARRAY_INDEX]; + + ERR_CONTINUE(uv2.size() == 0); + ERR_CONTINUE(normals.size() == 0); + + if (!uv.empty()) { + uvr = uv.ptr(); + } + + uv2r = uv2.ptr(); + nr = normals.ptr(); + + int facecount; + const int *ir = nullptr; + + if (index.size()) { + facecount = index.size() / 3; + ir = index.ptr(); + } else { + facecount = vertices.size() / 3; + } + + md.surface_facecounts.push_back(facecount); + + for (int j = 0; j < facecount; j++) { + uint32_t vidx[3]; + + if (ir) { + for (int k = 0; k < 3; k++) { + vidx[k] = ir[j * 3 + k]; + } + } else { + for (int k = 0; k < 3; k++) { + vidx[k] = j * 3 + k; + } + } + + for (int k = 0; k < 3; k++) { + Vector3 v = mf.xform.xform(vr[vidx[k]]); + md.points.push_back(v); + + md.uv2.push_back(uv2r[vidx[k]]); + md.normal.push_back(normal_xform.xform(nr[vidx[k]]).normalized()); + + if (uvr != nullptr) { + md.uv.push_back(uvr[vidx[k]]); + } + } + } + } + + Vector > albedo_textures; + Vector > emission_textures; + + _get_material_images(mf, md, albedo_textures, emission_textures); + + for (int j = 0; j < albedo_textures.size(); j++) { + lightmapper->add_albedo_texture(albedo_textures[j]); } - pmc++; - baker.plot_mesh(E->get().local_xform, E->get().mesh, E->get().instance_materials, E->get().override_material); + for (int j = 0; j < emission_textures.size(); j++) { + lightmapper->add_emission_texture(emission_textures[j]); + } + + lightmapper->add_mesh(md, lightmap_size); } - pmc = 0; - baker.begin_bake_light(VoxelLightBaker::BakeQuality(bake_quality), VoxelLightBaker::BakeMode(bake_mode), propagation, energy); + for (int i = 0; i < lights_found.size(); i++) { + Light *light = lights_found[i].light; + Transform xf = lights_found[i].xform; + + if (Object::cast_to(light)) { + DirectionalLight *l = Object::cast_to(light); + lightmapper->add_directional_light(light->get_bake_mode() == Light::BAKE_ALL, -xf.basis.get_axis(Vector3::AXIS_Z).normalized(), l->get_color(), l->get_param(Light::PARAM_ENERGY), l->get_param(Light::PARAM_INDIRECT_ENERGY)); + } else if (Object::cast_to(light)) { + 
OmniLight *l = Object::cast_to(light); + lightmapper->add_omni_light(light->get_bake_mode() == Light::BAKE_ALL, xf.origin, l->get_color(), l->get_param(Light::PARAM_ENERGY), l->get_param(Light::PARAM_INDIRECT_ENERGY), l->get_param(Light::PARAM_RANGE), l->get_param(Light::PARAM_ATTENUATION)); + } else if (Object::cast_to(light)) { + SpotLight *l = Object::cast_to(light); + lightmapper->add_spot_light(light->get_bake_mode() == Light::BAKE_ALL, xf.origin, -xf.basis.get_axis(Vector3::AXIS_Z).normalized(), l->get_color(), l->get_param(Light::PARAM_ENERGY), l->get_param(Light::PARAM_INDIRECT_ENERGY), l->get_param(Light::PARAM_RANGE), l->get_param(Light::PARAM_ATTENUATION), l->get_param(Light::PARAM_SPOT_ANGLE), l->get_param(Light::PARAM_SPOT_ATTENUATION)); + } + } - for (List::Element *E = light_list.front(); E; E = E->next()) { + Ref environment_image; + Basis environment_xform; + if (environment_mode != ENVIRONMENT_MODE_DISABLED) { if (bake_step_function) { - bake_step_function(step++, RTR("Plotting Lights:") + " (" + itos(pmc + 1) + "/" + itos(light_list.size()) + ")"); + bake_step_function(0.1, TTR("Preparing environment"), nullptr, true); } - pmc++; - PlotLight pl = E->get(); - switch (pl.light->get_light_type()) { - case VS::LIGHT_DIRECTIONAL: { - baker.plot_light_directional(-pl.local_xform.basis.get_axis(2), pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_bake_mode() == Light::BAKE_ALL); + switch (environment_mode) { + case ENVIRONMENT_MODE_DISABLED: { + //nothing } break; - case VS::LIGHT_OMNI: { - baker.plot_light_omni(pl.local_xform.origin, pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_param(Light::PARAM_RANGE), pl.light->get_param(Light::PARAM_ATTENUATION), pl.light->get_bake_mode() == Light::BAKE_ALL); + case ENVIRONMENT_MODE_SCENE: { + Ref world = get_world(); + if (world.is_valid()) { + Ref env = world->get_environment(); + if (env.is_null()) { + env = world->get_fallback_environment(); + } + + if (env.is_valid()) { + environment_image = _get_irradiance_map(env, Vector2i(128, 64)); + environment_xform = get_global_transform().affine_inverse().basis * env->get_sky_orientation(); + } + } } break; - case VS::LIGHT_SPOT: { - baker.plot_light_spot(pl.local_xform.origin, pl.local_xform.basis.get_axis(2), pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_param(Light::PARAM_RANGE), pl.light->get_param(Light::PARAM_ATTENUATION), pl.light->get_param(Light::PARAM_SPOT_ANGLE), pl.light->get_param(Light::PARAM_SPOT_ATTENUATION), pl.light->get_bake_mode() == Light::BAKE_ALL); + case ENVIRONMENT_MODE_CUSTOM_SKY: { + if (environment_custom_sky.is_valid()) { + environment_image = _get_irradiance_from_sky(environment_custom_sky, Vector2i(128, 64)); + print_line(vformat("env -> %s", environment_custom_sky_rotation_degrees)); + environment_xform.set_euler(environment_custom_sky_rotation_degrees * Math_PI / 180.0); + } + + } break; + case ENVIRONMENT_MODE_CUSTOM_COLOR: { + environment_image.instance(); + environment_image->create(128, 64, false, Image::FORMAT_RGBF); + Color c = environment_custom_color; + c.r *= environment_custom_energy; + c.g *= environment_custom_energy; + c.b *= environment_custom_energy; + for (int i = 0; i < 128; i++) { + for (int j = 0; j < 64; j++) { + environment_image->set_pixel(i, j, c); + } + } } break; } } - /*if 
(bake_step_function) { - bake_step_function(pmc++, RTR("Finishing Plot")); - }*/ - baker.end_bake(); + BakeStepUD bsud; + bsud.func = bake_step_function; + bsud.ud = nullptr; + bsud.from_percent = 0.1; + bsud.to_percent = 0.9; - Set used_mesh_names; + bool gen_atlas = OS::get_singleton()->get_current_video_driver() == OS::VIDEO_DRIVER_GLES2 ? false : generate_atlas; - pmc = 0; - for (List::Element *E = mesh_list.front(); E; E = E->next()) { + Lightmapper::BakeError bake_err = lightmapper->bake(Lightmapper::BakeQuality(bake_quality), use_denoiser, bounces, bias, gen_atlas, max_atlas_size, environment_image, environment_xform, _lightmap_bake_step_function, &bsud, bake_substep_function); - String mesh_name = E->get().mesh->get_name(); - if (mesh_name == "" || mesh_name.find(":") != -1 || mesh_name.find("/") != -1) { - mesh_name = "LightMap"; + if (bake_err != Lightmapper::BAKE_OK) { + switch (bake_err) { + case Lightmapper::BAKE_ERROR_USER_ABORTED: { + return BAKE_ERROR_USER_ABORTED; + } + case Lightmapper::BAKE_ERROR_LIGHTMAP_TOO_SMALL: { + return BAKE_ERROR_LIGHTMAP_SIZE; + } + case Lightmapper::BAKE_ERROR_NO_MESHES: { + return BAKE_ERROR_NO_MESHES; + } + default: { + } } + return BAKE_ERROR_NO_LIGHTMAPPER; + } - if (used_mesh_names.has(mesh_name)) { - int idx = 2; - String base = mesh_name; - while (true) { - mesh_name = base + itos(idx); - if (!used_mesh_names.has(mesh_name)) - break; - idx++; + Ref data; + if (get_light_data().is_valid()) { + data = get_light_data(); + set_light_data(Ref()); //clear + data->clear_users(); + } else { + data.instance(); + } + + if (capture_enabled) { + + if (bake_step_function) { + bool cancelled = bake_step_function(0.85, TTR("Generating capture"), nullptr, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; } } - used_mesh_names.insert(mesh_name); - pmc++; - VoxelLightBaker::LightMapData lm; + VoxelLightBaker voxel_baker; + + int bake_subdiv; + int capture_subdiv; + AABB bake_bounds; + { + bake_bounds = AABB(-extents, extents * 2.0); + int subdiv = nearest_power_of_2_templated(int(bake_bounds.get_longest_axis_size() / capture_cell_size)); + bake_bounds.size[bake_bounds.get_longest_axis_index()] = subdiv * capture_cell_size; + bake_subdiv = nearest_shift(subdiv) + 1; + + capture_subdiv = bake_subdiv; + float css = capture_cell_size; + while (css < capture_cell_size && capture_subdiv > 2) { + capture_subdiv--; + css *= 2.0; + } + } - Error err; - if (bake_step_function) { - BakeTimeData btd; - btd.text = RTR("Lighting Meshes: ") + mesh_name + " (" + itos(pmc) + "/" + itos(mesh_list.size()) + ")"; - btd.pass = step; - btd.last_step = 0; - err = baker.make_lightmap(E->get().local_xform, E->get().mesh, bake_default_texels_per_unit, lm, _bake_time, &btd); - if (err != OK) { - bake_end_function(); - if (err == ERR_SKIP) - return BAKE_ERROR_USER_ABORTED; - return BAKE_ERROR_CANT_CREATE_IMAGE; - } - step += 100; - } else { + voxel_baker.begin_bake(capture_subdiv + 1, bake_bounds); - err = baker.make_lightmap(E->get().local_xform, E->get().mesh, bake_default_texels_per_unit, lm); + for (int mesh_id = 0; mesh_id < meshes_found.size(); mesh_id++) { + MeshesFound &mf = meshes_found.write[mesh_id]; + voxel_baker.plot_mesh(mf.xform, mf.mesh, mf.overrides, Ref()); } - if (err == OK) { + VoxelLightBaker::BakeQuality capt_quality = VoxelLightBaker::BakeQuality::BAKE_QUALITY_HIGH; + if (capture_quality == BakedLightmap::BakeQuality::BAKE_QUALITY_LOW) { + capt_quality = VoxelLightBaker::BakeQuality::BAKE_QUALITY_LOW; + } else if (capture_quality == 
BakedLightmap::BakeQuality::BAKE_QUALITY_MEDIUM) { + capt_quality = VoxelLightBaker::BakeQuality::BAKE_QUALITY_MEDIUM; + } - Ref image; - image.instance(); + voxel_baker.begin_bake_light(capt_quality, capture_propagation); + + for (int i = 0; i < lights_found.size(); i++) { + LightsFound &lf = lights_found.write[i]; + switch (lf.light->get_light_type()) { + case VS::LIGHT_DIRECTIONAL: { + voxel_baker.plot_light_directional(-lf.xform.basis.get_axis(2), lf.light->get_color(), lf.light->get_param(Light::PARAM_ENERGY), lf.light->get_param(Light::PARAM_INDIRECT_ENERGY), lf.light->get_bake_mode() == Light::BAKE_ALL); + } break; + case VS::LIGHT_OMNI: { + voxel_baker.plot_light_omni(lf.xform.origin, lf.light->get_color(), lf.light->get_param(Light::PARAM_ENERGY), lf.light->get_param(Light::PARAM_INDIRECT_ENERGY), lf.light->get_param(Light::PARAM_RANGE), lf.light->get_param(Light::PARAM_ATTENUATION), lf.light->get_bake_mode() == Light::BAKE_ALL); + } break; + case VS::LIGHT_SPOT: { + voxel_baker.plot_light_spot(lf.xform.origin, lf.xform.basis.get_axis(2), lf.light->get_color(), lf.light->get_param(Light::PARAM_ENERGY), lf.light->get_param(Light::PARAM_INDIRECT_ENERGY), lf.light->get_param(Light::PARAM_RANGE), lf.light->get_param(Light::PARAM_ATTENUATION), lf.light->get_param(Light::PARAM_SPOT_ANGLE), lf.light->get_param(Light::PARAM_SPOT_ATTENUATION), lf.light->get_bake_mode() == Light::BAKE_ALL); + + } break; + } + } - uint32_t tex_flags = Texture::FLAGS_DEFAULT; - if (hdr) { + voxel_baker.end_bake(); - //just save a regular image - PoolVector data; - int s = lm.light.size(); - data.resize(lm.light.size() * 2); - { + AABB bounds = AABB(-extents, extents * 2); + data->set_cell_subdiv(capture_subdiv); + data->set_bounds(bounds); + data->set_octree(voxel_baker.create_capture_octree(capture_subdiv)); - PoolVector::Write w = data.write(); - PoolVector::Read r = lm.light.read(); - uint16_t *hfw = (uint16_t *)w.ptr(); - for (int i = 0; i < s; i++) { - hfw[i] = Math::make_half_float(r[i]); - } - } + { + float bake_bound_size = bake_bounds.get_longest_axis_size(); + Transform to_bounds; + to_bounds.basis.scale(Vector3(bake_bound_size, bake_bound_size, bake_bound_size)); + to_bounds.origin = bounds.position; + + Transform to_grid; + to_grid.basis.scale(Vector3(1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1))); + + Transform to_cell_space = to_grid * to_bounds.affine_inverse(); + data->set_cell_space_transform(to_cell_space); + } + } + + if (bake_step_function) { + bool cancelled = bake_step_function(0.9, TTR("Saving lightmaps"), nullptr, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } + + Vector > images; + for (int i = 0; i < lightmapper->get_bake_texture_count(); i++) { + images.push_back(lightmapper->get_bake_texture(i)); + } + + if (gen_atlas) { - image->create(lm.width, lm.height, false, Image::FORMAT_RGBH, data); + Ref large_image; + large_image.instance(); + large_image->create(images[0]->get_width(), images[0]->get_height() * images.size(), false, images[0]->get_format()); + for (int i = 0; i < images.size(); i++) { + large_image->blit_rect(images[i], Rect2(0, 0, images[0]->get_width(), images[0]->get_height()), Point2(0, images[0]->get_height() * i)); + } + + Ref texture; + String base_path = p_data_save_path.get_basename(); + + if (ResourceLoader::import) { + base_path += ".exr"; + String relative_path = base_path; + if (relative_path.begins_with("res://")) { + relative_path = relative_path.substr(6, relative_path.length()); + } + 
large_image->save_exr(relative_path, false); + + Ref config; + config.instance(); + if (FileAccess::exists(base_path + ".import")) { + config->load(base_path + ".import"); } else { + // Set only if settings don't exist, to keep user choice + config->set_value("params", "compress/mode", 0); + } + config->set_value("remap", "importer", "texture_array"); + config->set_value("remap", "type", "TextureArray"); + config->set_value("params", "detect_3d", false); + config->set_value("params", "flags/repeat", false); + config->set_value("params", "flags/filter", true); + config->set_value("params", "flags/mipmaps", false); + config->set_value("params", "flags/srgb", false); + config->set_value("params", "slices/horizontal", 1); + config->set_value("params", "slices/vertical", images.size()); + config->save(base_path + ".import"); + + ResourceLoader::import(base_path); + texture = ResourceLoader::load(base_path); //if already loaded, it will be updated on refocus? + } else { - //just save a regular image - PoolVector data; - int s = lm.light.size(); - data.resize(lm.light.size()); - { - - PoolVector::Write w = data.write(); - PoolVector::Read r = lm.light.read(); - for (int i = 0; i < s; i += 3) { - Color c(r[i + 0], r[i + 1], r[i + 2]); - c = c.to_srgb(); - w[i + 0] = CLAMP(c.r * 255, 0, 255); - w[i + 1] = CLAMP(c.g * 255, 0, 255); - w[i + 2] = CLAMP(c.b * 255, 0, 255); - } + base_path += ".texarr"; + Ref tex; + bool set_path = true; + if (ResourceCache::has(base_path)) { + tex = Ref((Resource *)ResourceCache::get(base_path)); + set_path = false; + } + + if (!tex.is_valid()) { + tex.instance(); + } + + tex->create(images[0]->get_width(), images[0]->get_height(), images.size(), images[0]->get_format(), Texture::FLAGS_DEFAULT); + for (int i = 0; i < images.size(); i++) { + tex->set_layer_data(images[i], i); + } + + ResourceSaver::save(base_path, tex, ResourceSaver::FLAG_CHANGE_PATH); + if (set_path) { + tex->set_path(base_path); + } + texture = tex; + } + + for (int i = 0; i < lightmapper->get_bake_mesh_count(); i++) { + if (meshes_found[i].generate_lightmap) { + Dictionary d = lightmapper->get_bake_mesh_userdata(i); + NodePath np = d["path"]; + int32_t subindex = -1; + if (d.has("subindex")) { + subindex = d["subindex"]; } - image->create(lm.width, lm.height, false, Image::FORMAT_RGB8, data); + Rect2 uv_rect = lightmapper->get_bake_mesh_uv_scale(i); + int slice_index = lightmapper->get_bake_mesh_texture_slice(i); + data->add_user(np, texture, slice_index, uv_rect, subindex); + } + } + } else { + for (int i = 0; i < lightmapper->get_bake_mesh_count(); i++) { - //This texture is saved to SRGB for two reasons: - // 1) first is so it looks better when doing the LINEAR->SRGB conversion (more accurate) - // 2) So it can be used in the GLES2 backend, which does not support linkear workflow - tex_flags |= Texture::FLAG_CONVERT_TO_LINEAR; + if (!meshes_found[i].generate_lightmap) { + continue; } - String image_path = save_path.plus_file(mesh_name); Ref texture; + String base_path = p_data_save_path.get_base_dir().plus_file(images[i]->get_name()); if (ResourceLoader::import) { - bool srgb = false; - if (false && hdr) { - //save hdr - } else { - image_path += ".png"; - print_line("image path saving png: " + image_path); - image->save_png(image_path); - srgb = true; + base_path += ".exr"; + String relative_path = base_path; + if (relative_path.begins_with("res://")) { + relative_path = relative_path.substr(6, relative_path.length()); } + images[i]->save_exr(relative_path, false); - if 
(!FileAccess::exists(image_path + ".import")) { - Ref config; - config.instance(); - config->set_value("remap", "importer", "texture"); - config->set_value("remap", "type", "StreamTexture"); - config->set_value("params", "compress/mode", 2); - config->set_value("params", "detect_3d", false); - config->set_value("params", "flags/repeat", false); - config->set_value("params", "flags/filter", true); - config->set_value("params", "flags/mipmaps", false); - config->set_value("params", "flags/srgb", srgb); - - config->save(image_path + ".import"); + Ref config; + config.instance(); + if (FileAccess::exists(base_path + ".import")) { + config->load(base_path + ".import"); + } else { + // Set only if settings don't exist, to keep user choice + config->set_value("params", "compress/mode", 0); } - - ResourceLoader::import(image_path); - texture = ResourceLoader::load(image_path); //if already loaded, it will be updated on refocus? + config->set_value("remap", "importer", "texture"); + config->set_value("remap", "type", "StreamTexture"); + config->set_value("params", "detect_3d", false); + config->set_value("params", "flags/repeat", false); + config->set_value("params", "flags/filter", true); + config->set_value("params", "flags/mipmaps", false); + config->set_value("params", "flags/srgb", false); + + config->save(base_path + ".import"); + + ResourceLoader::import(base_path); + texture = ResourceLoader::load(base_path); //if already loaded, it will be updated on refocus? } else { - image_path += ".text"; + base_path += ".tex"; Ref tex; bool set_path = true; - if (ResourceCache::has(image_path)) { - tex = Ref((Resource *)ResourceCache::get(image_path)); + if (ResourceCache::has(base_path)) { + tex = Ref((Resource *)ResourceCache::get(base_path)); set_path = false; } @@ -583,67 +1003,80 @@ BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, bool p_create_vi tex.instance(); } - tex->create_from_image(image, tex_flags); + tex->create_from_image(images[i], Texture::FLAGS_DEFAULT); - err = ResourceSaver::save(image_path, tex, ResourceSaver::FLAG_CHANGE_PATH); + ResourceSaver::save(base_path, tex, ResourceSaver::FLAG_CHANGE_PATH); if (set_path) { - tex->set_path(image_path); + tex->set_path(base_path); } texture = tex; } - if (err != OK) { - if (bake_end_function) { - bake_end_function(); - } - ERR_FAIL_COND_V(err != OK, BAKE_ERROR_CANT_CREATE_IMAGE); + + Dictionary d = lightmapper->get_bake_mesh_userdata(i); + NodePath np = d["path"]; + int32_t subindex = -1; + if (d.has("subindex")) { + subindex = d["subindex"]; } - new_light_data->add_user(E->get().path, texture, E->get().instance_idx); + Rect2 uv_rect = Rect2(0, 0, 1, 1); + int slice_index = -1; + data->add_user(np, texture, slice_index, uv_rect, subindex); } } - AABB bounds = AABB(-extents, extents * 2); - new_light_data->set_cell_subdiv(capture_subdiv); - new_light_data->set_bounds(bounds); - new_light_data->set_octree(baker.create_capture_octree(capture_subdiv)); - { - - float bake_bound_size = bake_bounds.get_longest_axis_size(); - Transform to_bounds; - to_bounds.basis.scale(Vector3(bake_bound_size, bake_bound_size, bake_bound_size)); - to_bounds.origin = bounds.position; + if (bake_step_function) { + bool cancelled = bake_step_function(1.0, TTR("Done"), nullptr, true); + if (cancelled) { + return BAKE_ERROR_USER_ABORTED; + } + } - Transform to_grid; - to_grid.basis.scale(Vector3(1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1))); + Error err = ResourceSaver::save(p_data_save_path, data); + 
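// ----------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: how the BakeStepUD window set up
// earlier in bake() (from_percent = 0.1, to_percent = 0.9) remaps a sub-task's
// completion into the overall progress handed to bake_step_function, matching
// _lightmap_bake_step_function() above.
static float remap_bake_progress(float p_completion, float p_from, float p_to) {
	return p_from + p_completion * (p_to - p_from);
}
// Example: a lightmapper report of 0.5 surfaces as 0.1 + 0.5 * (0.9 - 0.1) = 0.5,
// leaving 0.0-0.1 for scene scanning and 0.9-1.0 for the capture/saving steps
// reported just above.
// ----------------------------------------------------------------------------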
data->set_path(p_data_save_path); - Transform to_cell_space = to_grid * to_bounds.affine_inverse(); - new_light_data->set_cell_space_transform(to_cell_space); + if (err != OK) { + return BAKE_ERROR_CANT_CREATE_IMAGE; } - if (bake_end_function) { - bake_end_function(); - } + set_light_data(data); + + return BAKE_ERROR_OK; +} - //create the data for visual server +void BakedLightmap::set_capture_cell_size(float p_cell_size) { + capture_cell_size = MAX(0.1, p_cell_size); +} - if (p_create_visual_debug) { - MultiMeshInstance *mmi = memnew(MultiMeshInstance); - mmi->set_multimesh(baker.create_debug_multimesh(VoxelLightBaker::DEBUG_LIGHT)); - add_child(mmi); -#ifdef TOOLS_ENABLED - if (get_tree()->get_edited_scene_root() == this) { - mmi->set_owner(this); - } else { - mmi->set_owner(get_owner()); - } -#else - mmi->set_owner(get_owner()); -#endif - } +float BakedLightmap::get_capture_cell_size() const { + return capture_cell_size; +} + +void BakedLightmap::set_extents(const Vector3 &p_extents) { + extents = p_extents; + update_gizmo(); + _change_notify("bake_extents"); +} - set_light_data(new_light_data); +Vector3 BakedLightmap::get_extents() const { + return extents; +} - return BAKE_ERROR_OK; +void BakedLightmap::set_default_texels_per_unit(const float &p_bake_texels_per_unit) { + default_texels_per_unit = MAX(0.0, p_bake_texels_per_unit); +} + +float BakedLightmap::get_default_texels_per_unit() const { + return default_texels_per_unit; +} + +void BakedLightmap::set_capture_enabled(bool p_enable) { + capture_enabled = p_enable; + _change_notify(); +} + +bool BakedLightmap::get_capture_enabled() const { + return capture_enabled; } void BakedLightmap::_notification(int p_what) { @@ -667,23 +1100,34 @@ void BakedLightmap::_assign_lightmaps() { ERR_FAIL_COND(!light_data.is_valid()); + bool atlassed_on_gles2 = false; + for (int i = 0; i < light_data->get_user_count(); i++) { - Ref lightmap = light_data->get_user_lightmap(i); + Ref lightmap = light_data->get_user_lightmap(i); ERR_CONTINUE(!lightmap.is_valid()); + ERR_CONTINUE(!Object::cast_to(lightmap.ptr()) && !Object::cast_to(lightmap.ptr())); Node *node = get_node(light_data->get_user_path(i)); int instance_idx = light_data->get_user_instance(i); if (instance_idx >= 0) { RID instance = node->call("get_bake_mesh_instance", instance_idx); if (instance.is_valid()) { - VS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), lightmap->get_rid()); + int slice = light_data->get_user_lightmap_slice(i); + atlassed_on_gles2 = atlassed_on_gles2 || (slice != -1 && OS::get_singleton()->get_current_video_driver() == OS::VIDEO_DRIVER_GLES2); + VS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), lightmap->get_rid(), slice, light_data->get_user_lightmap_uv_rect(i)); } } else { VisualInstance *vi = Object::cast_to(node); ERR_CONTINUE(!vi); - VS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), lightmap->get_rid()); + int slice = light_data->get_user_lightmap_slice(i); + atlassed_on_gles2 = atlassed_on_gles2 || (slice != -1 && OS::get_singleton()->get_current_video_driver() == OS::VIDEO_DRIVER_GLES2); + VS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), lightmap->get_rid(), slice, light_data->get_user_lightmap_uv_rect(i)); } } + + if (atlassed_on_gles2) { + ERR_PRINT("GLES2 doesn't support layered textures, so lightmap atlassing is not supported. 
Please re-bake the lightmap or switch to GLES3."); + } } void BakedLightmap::_clear_lightmaps() { @@ -694,14 +1138,63 @@ void BakedLightmap::_clear_lightmaps() { if (instance_idx >= 0) { RID instance = node->call("get_bake_mesh_instance", instance_idx); if (instance.is_valid()) { - VS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), RID()); + VS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), RID(), -1, Rect2(0, 0, 1, 1)); } } else { VisualInstance *vi = Object::cast_to(node); ERR_CONTINUE(!vi); - VS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), RID()); + VS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), RID(), -1, Rect2(0, 0, 1, 1)); + } + } +} + +Ref BakedLightmap::_get_irradiance_from_sky(Ref p_sky, Vector2i p_size) { + if (p_sky.is_null()) { + return Ref(); + } + + Ref sky_image; + Ref panorama = p_sky; + if (panorama.is_valid()) { + sky_image = panorama->get_panorama()->get_data(); + } + Ref procedural = p_sky; + if (procedural.is_valid()) { + sky_image = procedural->get_panorama(); + } + + if (sky_image.is_null()) { + return Ref(); + } + + sky_image->convert(Image::FORMAT_RGBF); + sky_image->resize(p_size.x, p_size.y, Image::INTERPOLATE_CUBIC); + return sky_image; +} + +Ref BakedLightmap::_get_irradiance_map(Ref p_env, Vector2i p_size) { + Environment::BGMode bg_mode = p_env->get_background(); + switch (bg_mode) { + case Environment::BG_SKY: { + return _get_irradiance_from_sky(p_env->get_sky(), Vector2i(128, 64)); + } + case Environment::BG_CLEAR_COLOR: + case Environment::BG_COLOR: { + Color c = bg_mode == Environment::BG_CLEAR_COLOR ? Color(GLOBAL_GET("rendering/environment/default_clear_color")) : p_env->get_bg_color(); + c.r *= p_env->get_bg_energy(); + c.g *= p_env->get_bg_energy(); + c.b *= p_env->get_bg_energy(); + + Ref ret; + ret.instance(); + ret->create(p_size.x, p_size.y, false, Image::FORMAT_RGBF); + ret->fill(c); + return ret; + } + default: { } } + return Ref(); } void BakedLightmap::set_light_data(const Ref &p_data) { @@ -713,6 +1206,7 @@ void BakedLightmap::set_light_data(const Ref &p_data) { set_base(RID()); } light_data = p_data; + _change_notify(); if (light_data.is_valid()) { set_base(light_data->get_rid()); @@ -727,44 +1221,50 @@ Ref BakedLightmap::get_light_data() const { return light_data; } -void BakedLightmap::_debug_bake() { - bake(get_parent(), true); +void BakedLightmap::set_capture_propagation(float p_propagation) { + capture_propagation = p_propagation; } -void BakedLightmap::set_propagation(float p_propagation) { - propagation = p_propagation; +float BakedLightmap::get_capture_propagation() const { + + return capture_propagation; } -float BakedLightmap::get_propagation() const { +void BakedLightmap::set_capture_quality(BakeQuality p_quality) { + capture_quality = p_quality; +} - return propagation; +BakedLightmap::BakeQuality BakedLightmap::get_capture_quality() const { + return capture_quality; } -void BakedLightmap::set_energy(float p_energy) { - energy = p_energy; +void BakedLightmap::set_generate_atlas(bool p_enabled) { + generate_atlas = p_enabled; } -float BakedLightmap::get_energy() const { +bool BakedLightmap::is_generate_atlas_enabled() const { + return generate_atlas; +} - return energy; +void BakedLightmap::set_max_atlas_size(int p_size) { + ERR_FAIL_COND(p_size < 2048); + max_atlas_size = p_size; +} + +int BakedLightmap::get_max_atlas_size() const { + return max_atlas_size; } void BakedLightmap::set_bake_quality(BakeQuality 
p_quality) { bake_quality = p_quality; + _change_notify(); } BakedLightmap::BakeQuality BakedLightmap::get_bake_quality() const { return bake_quality; } -void BakedLightmap::set_bake_mode(BakeMode p_mode) { - bake_mode = p_mode; -} - -BakedLightmap::BakeMode BakedLightmap::get_bake_mode() const { - return bake_mode; -} - +#ifndef DISABLE_DEPRECATED void BakedLightmap::set_image_path(const String &p_path) { image_path = p_path; } @@ -772,6 +1272,74 @@ void BakedLightmap::set_image_path(const String &p_path) { String BakedLightmap::get_image_path() const { return image_path; } +#endif + +void BakedLightmap::set_use_denoiser(bool p_enable) { + + use_denoiser = p_enable; +} + +bool BakedLightmap::is_using_denoiser() const { + + return use_denoiser; +} + +void BakedLightmap::set_environment_mode(EnvironmentMode p_mode) { + environment_mode = p_mode; + _change_notify(); +} + +BakedLightmap::EnvironmentMode BakedLightmap::get_environment_mode() const { + return environment_mode; +} + +void BakedLightmap::set_environment_custom_sky(const Ref &p_sky) { + environment_custom_sky = p_sky; +} + +Ref BakedLightmap::get_environment_custom_sky() const { + return environment_custom_sky; +} + +void BakedLightmap::set_environment_custom_sky_rotation_degrees(const Vector3 &p_rotation) { + environment_custom_sky_rotation_degrees = p_rotation; +} + +Vector3 BakedLightmap::get_environment_custom_sky_rotation_degrees() const { + return environment_custom_sky_rotation_degrees; +} + +void BakedLightmap::set_environment_custom_color(const Color &p_color) { + environment_custom_color = p_color; +} +Color BakedLightmap::get_environment_custom_color() const { + return environment_custom_color; +} + +void BakedLightmap::set_environment_custom_energy(float p_energy) { + environment_custom_energy = p_energy; +} +float BakedLightmap::get_environment_custom_energy() const { + return environment_custom_energy; +} + +void BakedLightmap::set_bounces(int p_bounces) { + ERR_FAIL_COND(p_bounces < 0 || p_bounces > 16); + bounces = p_bounces; +} + +int BakedLightmap::get_bounces() const { + return bounces; +} + +void BakedLightmap::set_bias(float p_bias) { + ERR_FAIL_COND(p_bias < 0.00001); + bias = p_bias; +} + +float BakedLightmap::get_bias() const { + return bias; +} AABB BakedLightmap::get_aabb() const { return AABB(-extents, extents * 2); @@ -780,85 +1348,163 @@ PoolVector BakedLightmap::get_faces(uint32_t p_usage_flags) const { return PoolVector(); } +void BakedLightmap::_validate_property(PropertyInfo &property) const { + + if (property.name.begins_with("environment_custom_sky") && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) { + property.usage = 0; + } + + if (property.name == "environment_custom_color" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR) { + property.usage = 0; + } + + if (property.name == "environment_custom_energy" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) { + property.usage = 0; + } + + if (property.name.begins_with("atlas") && OS::get_singleton()->get_current_video_driver() == OS::VIDEO_DRIVER_GLES2) { + property.usage = PROPERTY_USAGE_NOEDITOR; + } + + if (property.name.begins_with("capture") && property.name != "capture_enabled" && !capture_enabled) { + property.usage = 0; + } +} + void BakedLightmap::_bind_methods() { ClassDB::bind_method(D_METHOD("set_light_data", "data"), &BakedLightmap::set_light_data); ClassDB::bind_method(D_METHOD("get_light_data"), &BakedLightmap::get_light_data); - 
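// ----------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: calling the reworked bake() from
// engine-side C++ (for example a hypothetical editor helper). The node name and
// save path below are made up; passing nullptr as from_node makes bake() scan
// from the BakedLightmap's parent, and the .lmbake extension matches the
// RES_BASE_EXTENSION("lmbake") declared on BakedLightmapData.
void bake_level_lightmaps(Node *p_scene_root) {
	BakedLightmap *bl = Object::cast_to<BakedLightmap>(p_scene_root->find_node("BakedLightmap"));
	ERR_FAIL_COND(!bl);
	BakedLightmap::BakeError err = bl->bake(nullptr, "res://lightmaps/level.lmbake");
	if (err == BakedLightmap::BAKE_ERROR_NO_LIGHTMAPPER) {
		WARN_PRINT("No lightmapper module is available in this build.");
	} else if (err != BakedLightmap::BAKE_ERROR_OK) {
		WARN_PRINT("Lightmap bake failed, check the returned BakeError for details.");
	}
}
// ----------------------------------------------------------------------------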
ClassDB::bind_method(D_METHOD("set_bake_cell_size", "bake_cell_size"), &BakedLightmap::set_bake_cell_size); - ClassDB::bind_method(D_METHOD("get_bake_cell_size"), &BakedLightmap::get_bake_cell_size); + ClassDB::bind_method(D_METHOD("set_bake_quality", "quality"), &BakedLightmap::set_bake_quality); + ClassDB::bind_method(D_METHOD("get_bake_quality"), &BakedLightmap::get_bake_quality); - ClassDB::bind_method(D_METHOD("set_capture_cell_size", "capture_cell_size"), &BakedLightmap::set_capture_cell_size); - ClassDB::bind_method(D_METHOD("get_capture_cell_size"), &BakedLightmap::get_capture_cell_size); + ClassDB::bind_method(D_METHOD("set_bounces", "bounces"), &BakedLightmap::set_bounces); + ClassDB::bind_method(D_METHOD("get_bounces"), &BakedLightmap::get_bounces); - ClassDB::bind_method(D_METHOD("set_bake_quality", "bake_quality"), &BakedLightmap::set_bake_quality); - ClassDB::bind_method(D_METHOD("get_bake_quality"), &BakedLightmap::get_bake_quality); + ClassDB::bind_method(D_METHOD("set_bias", "bias"), &BakedLightmap::set_bias); + ClassDB::bind_method(D_METHOD("get_bias"), &BakedLightmap::get_bias); + + ClassDB::bind_method(D_METHOD("set_environment_mode", "mode"), &BakedLightmap::set_environment_mode); + ClassDB::bind_method(D_METHOD("get_environment_mode"), &BakedLightmap::get_environment_mode); + + ClassDB::bind_method(D_METHOD("set_environment_custom_sky", "sky"), &BakedLightmap::set_environment_custom_sky); + ClassDB::bind_method(D_METHOD("get_environment_custom_sky"), &BakedLightmap::get_environment_custom_sky); + + ClassDB::bind_method(D_METHOD("set_environment_custom_sky_rotation_degrees", "rotation"), &BakedLightmap::set_environment_custom_sky_rotation_degrees); + ClassDB::bind_method(D_METHOD("get_environment_custom_sky_rotation_degrees"), &BakedLightmap::get_environment_custom_sky_rotation_degrees); + + ClassDB::bind_method(D_METHOD("set_environment_custom_color", "color"), &BakedLightmap::set_environment_custom_color); + ClassDB::bind_method(D_METHOD("get_environment_custom_color"), &BakedLightmap::get_environment_custom_color); - ClassDB::bind_method(D_METHOD("set_bake_mode", "bake_mode"), &BakedLightmap::set_bake_mode); - ClassDB::bind_method(D_METHOD("get_bake_mode"), &BakedLightmap::get_bake_mode); + ClassDB::bind_method(D_METHOD("set_environment_custom_energy", "energy"), &BakedLightmap::set_environment_custom_energy); + ClassDB::bind_method(D_METHOD("get_environment_custom_energy"), &BakedLightmap::get_environment_custom_energy); + + ClassDB::bind_method(D_METHOD("set_use_denoiser", "use_denoiser"), &BakedLightmap::set_use_denoiser); + ClassDB::bind_method(D_METHOD("is_using_denoiser"), &BakedLightmap::is_using_denoiser); + + ClassDB::bind_method(D_METHOD("set_generate_atlas", "enabled"), &BakedLightmap::set_generate_atlas); + ClassDB::bind_method(D_METHOD("is_generate_atlas_enabled"), &BakedLightmap::is_generate_atlas_enabled); + + ClassDB::bind_method(D_METHOD("set_max_atlas_size", "max_atlas_size"), &BakedLightmap::set_max_atlas_size); + ClassDB::bind_method(D_METHOD("get_max_atlas_size"), &BakedLightmap::get_max_atlas_size); + + ClassDB::bind_method(D_METHOD("set_capture_quality", "capture_quality"), &BakedLightmap::set_capture_quality); + ClassDB::bind_method(D_METHOD("get_capture_quality"), &BakedLightmap::get_capture_quality); ClassDB::bind_method(D_METHOD("set_extents", "extents"), &BakedLightmap::set_extents); ClassDB::bind_method(D_METHOD("get_extents"), &BakedLightmap::get_extents); - ClassDB::bind_method(D_METHOD("set_bake_default_texels_per_unit", "texels"), 
&BakedLightmap::set_bake_default_texels_per_unit); - ClassDB::bind_method(D_METHOD("get_bake_default_texels_per_unit"), &BakedLightmap::get_bake_default_texels_per_unit); - - ClassDB::bind_method(D_METHOD("set_propagation", "propagation"), &BakedLightmap::set_propagation); - ClassDB::bind_method(D_METHOD("get_propagation"), &BakedLightmap::get_propagation); + ClassDB::bind_method(D_METHOD("set_default_texels_per_unit", "texels"), &BakedLightmap::set_default_texels_per_unit); + ClassDB::bind_method(D_METHOD("get_default_texels_per_unit"), &BakedLightmap::get_default_texels_per_unit); - ClassDB::bind_method(D_METHOD("set_energy", "energy"), &BakedLightmap::set_energy); - ClassDB::bind_method(D_METHOD("get_energy"), &BakedLightmap::get_energy); + ClassDB::bind_method(D_METHOD("set_capture_propagation", "propagation"), &BakedLightmap::set_capture_propagation); + ClassDB::bind_method(D_METHOD("get_capture_propagation"), &BakedLightmap::get_capture_propagation); - ClassDB::bind_method(D_METHOD("set_hdr", "hdr"), &BakedLightmap::set_hdr); - ClassDB::bind_method(D_METHOD("is_hdr"), &BakedLightmap::is_hdr); + ClassDB::bind_method(D_METHOD("set_capture_enabled", "enabled"), &BakedLightmap::set_capture_enabled); + ClassDB::bind_method(D_METHOD("get_capture_enabled"), &BakedLightmap::get_capture_enabled); + ClassDB::bind_method(D_METHOD("set_capture_cell_size", "capture_cell_size"), &BakedLightmap::set_capture_cell_size); + ClassDB::bind_method(D_METHOD("get_capture_cell_size"), &BakedLightmap::get_capture_cell_size); +#ifndef DISABLE_DEPRECATED ClassDB::bind_method(D_METHOD("set_image_path", "image_path"), &BakedLightmap::set_image_path); ClassDB::bind_method(D_METHOD("get_image_path"), &BakedLightmap::get_image_path); +#endif + ClassDB::bind_method(D_METHOD("bake", "from_node", "data_save_path"), &BakedLightmap::bake, DEFVAL(Variant()), DEFVAL("")); + + ADD_PROPERTY(PropertyInfo(Variant::VECTOR3, "extents"), "set_extents", "get_extents"); + + ADD_GROUP("Tweaks", ""); + ADD_PROPERTY(PropertyInfo(Variant::INT, "quality", PROPERTY_HINT_ENUM, "Low,Medium,High,Ultra"), "set_bake_quality", "get_bake_quality"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "bounces", PROPERTY_HINT_RANGE, "0,16,1"), "set_bounces", "get_bounces"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "use_denoiser"), "set_use_denoiser", "is_using_denoiser"); + ADD_PROPERTY(PropertyInfo(Variant::REAL, "bias", PROPERTY_HINT_RANGE, "0.00001,0.1,0.00001,or_greater"), "set_bias", "get_bias"); + ADD_PROPERTY(PropertyInfo(Variant::REAL, "default_texels_per_unit", PROPERTY_HINT_RANGE, "0.0,64.0,0.01,or_greater"), "set_default_texels_per_unit", "get_default_texels_per_unit"); + + ADD_GROUP("Atlas", "atlas_"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "atlas_generate"), "set_generate_atlas", "is_generate_atlas_enabled"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "atlas_max_size"), "set_max_atlas_size", "get_max_atlas_size"); + + ADD_GROUP("Environment", "environment_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "environment_mode", PROPERTY_HINT_ENUM, "Disabled,Scene,Custom Sky,Custom Color"), "set_environment_mode", "get_environment_mode"); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "environment_custom_sky", PROPERTY_HINT_RESOURCE_TYPE, "Sky"), "set_environment_custom_sky", "get_environment_custom_sky"); + ADD_PROPERTY(PropertyInfo(Variant::VECTOR3, "environment_custom_sky_rotation_degrees", PROPERTY_HINT_NONE), "set_environment_custom_sky_rotation_degrees", "get_environment_custom_sky_rotation_degrees"); + 
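// ----------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the rotation property registered
// just above is exposed in degrees for the inspector; bake() turns it into a
// Basis for the environment transform via set_euler(), i.e. roughly:
static Basis sky_rotation_degrees_to_basis(const Vector3 &p_rotation_degrees) {
	Basis b;
	b.set_euler(p_rotation_degrees * Math_PI / 180.0); // degrees -> radians
	return b;
}
// This mirrors environment_xform.set_euler(environment_custom_sky_rotation_degrees
// * Math_PI / 180.0) in the ENVIRONMENT_MODE_CUSTOM_SKY branch of bake().
// ----------------------------------------------------------------------------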
ADD_PROPERTY(PropertyInfo(Variant::COLOR, "environment_custom_color", PROPERTY_HINT_COLOR_NO_ALPHA), "set_environment_custom_color", "get_environment_custom_color"); + ADD_PROPERTY(PropertyInfo(Variant::REAL, "environment_custom_energy", PROPERTY_HINT_RANGE, "0,64,0.01"), "set_environment_custom_energy", "get_environment_custom_energy"); - ClassDB::bind_method(D_METHOD("bake", "from_node", "create_visual_debug"), &BakedLightmap::bake, DEFVAL(Variant()), DEFVAL(false)); - ClassDB::bind_method(D_METHOD("debug_bake"), &BakedLightmap::_debug_bake); - ClassDB::set_method_flags(get_class_static(), _scs_create("debug_bake"), METHOD_FLAGS_DEFAULT | METHOD_FLAG_EDITOR); - - ADD_GROUP("Bake", "bake_"); - ADD_PROPERTY(PropertyInfo(Variant::REAL, "bake_cell_size", PROPERTY_HINT_RANGE, "0.01,64,0.01"), "set_bake_cell_size", "get_bake_cell_size"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "bake_quality", PROPERTY_HINT_ENUM, "Low,Medium,High"), "set_bake_quality", "get_bake_quality"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "bake_mode", PROPERTY_HINT_ENUM, "ConeTrace,RayTrace"), "set_bake_mode", "get_bake_mode"); - ADD_PROPERTY(PropertyInfo(Variant::REAL, "bake_propagation", PROPERTY_HINT_RANGE, "0,1,0.01"), "set_propagation", "get_propagation"); - ADD_PROPERTY(PropertyInfo(Variant::REAL, "bake_energy", PROPERTY_HINT_RANGE, "0,32,0.01"), "set_energy", "get_energy"); - ADD_PROPERTY(PropertyInfo(Variant::BOOL, "bake_hdr"), "set_hdr", "is_hdr"); - ADD_PROPERTY(PropertyInfo(Variant::VECTOR3, "bake_extents"), "set_extents", "get_extents"); - ADD_PROPERTY(PropertyInfo(Variant::REAL, "bake_default_texels_per_unit"), "set_bake_default_texels_per_unit", "get_bake_default_texels_per_unit"); ADD_GROUP("Capture", "capture_"); - ADD_PROPERTY(PropertyInfo(Variant::REAL, "capture_cell_size", PROPERTY_HINT_RANGE, "0.01,64,0.01"), "set_capture_cell_size", "get_capture_cell_size"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "capture_enabled"), "set_capture_enabled", "get_capture_enabled"); + ADD_PROPERTY(PropertyInfo(Variant::REAL, "capture_cell_size", PROPERTY_HINT_RANGE, "0.25,2.0,0.05,or_greater"), "set_capture_cell_size", "get_capture_cell_size"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "capture_quality", PROPERTY_HINT_ENUM, "Low,Medium,High"), "set_capture_quality", "get_capture_quality"); + ADD_PROPERTY(PropertyInfo(Variant::REAL, "capture_propagation", PROPERTY_HINT_RANGE, "0,1,0.01"), "set_capture_propagation", "get_capture_propagation"); + ADD_GROUP("Data", ""); - ADD_PROPERTY(PropertyInfo(Variant::STRING, "image_path", PROPERTY_HINT_DIR), "set_image_path", "get_image_path"); +#ifndef DISABLE_DEPRECATED + ADD_PROPERTY(PropertyInfo(Variant::STRING, "image_path", PROPERTY_HINT_DIR, "", 0), "set_image_path", "get_image_path"); +#endif ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "light_data", PROPERTY_HINT_RESOURCE_TYPE, "BakedLightmapData"), "set_light_data", "get_light_data"); BIND_ENUM_CONSTANT(BAKE_QUALITY_LOW); BIND_ENUM_CONSTANT(BAKE_QUALITY_MEDIUM); BIND_ENUM_CONSTANT(BAKE_QUALITY_HIGH); - BIND_ENUM_CONSTANT(BAKE_MODE_CONE_TRACE); - BIND_ENUM_CONSTANT(BAKE_MODE_RAY_TRACE); + BIND_ENUM_CONSTANT(BAKE_QUALITY_ULTRA); BIND_ENUM_CONSTANT(BAKE_ERROR_OK); BIND_ENUM_CONSTANT(BAKE_ERROR_NO_SAVE_PATH); BIND_ENUM_CONSTANT(BAKE_ERROR_NO_MESHES); BIND_ENUM_CONSTANT(BAKE_ERROR_CANT_CREATE_IMAGE); + BIND_ENUM_CONSTANT(BAKE_ERROR_LIGHTMAP_SIZE); + BIND_ENUM_CONSTANT(BAKE_ERROR_INVALID_MESH); BIND_ENUM_CONSTANT(BAKE_ERROR_USER_ABORTED); + BIND_ENUM_CONSTANT(BAKE_ERROR_NO_LIGHTMAPPER); + + 
BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_DISABLED); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_SCENE); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_CUSTOM_SKY); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_CUSTOM_COLOR); } BakedLightmap::BakedLightmap() { extents = Vector3(10, 10, 10); - bake_default_texels_per_unit = 20; - bake_cell_size = 0.25; - capture_cell_size = 0.5; + default_texels_per_unit = 16.0f; bake_quality = BAKE_QUALITY_MEDIUM; - bake_mode = BAKE_MODE_CONE_TRACE; - energy = 1; - propagation = 1; - hdr = false; - image_path = "."; + capture_quality = BAKE_QUALITY_MEDIUM; + capture_propagation = 1; + capture_enabled = true; + bounces = 3; + image_path = ""; set_disable_scale(true); + capture_cell_size = 0.5; + + environment_mode = ENVIRONMENT_MODE_DISABLED; + environment_custom_color = Color(0.2, 0.7, 1.0); + environment_custom_energy = 1.0; + + use_denoiser = true; + bias = 0.005; + + generate_atlas = true; + max_atlas_size = 4096; } diff --git a/scene/3d/baked_lightmap.h b/scene/3d/baked_lightmap.h index 28f6196dee75..33f9a62a57ca 100644 --- a/scene/3d/baked_lightmap.h +++ b/scene/3d/baked_lightmap.h @@ -31,12 +31,15 @@ #ifndef BAKED_INDIRECT_LIGHT_H #define BAKED_INDIRECT_LIGHT_H +#include "core/local_vector.h" #include "multimesh_instance.h" #include "scene/3d/light.h" +#include "scene/3d/lightmapper.h" #include "scene/3d/visual_instance.h" class BakedLightmapData : public Resource { GDCLASS(BakedLightmapData, Resource); + RES_BASE_EXTENSION("lmbake") RID baked_light; AABB bounds; @@ -47,7 +50,12 @@ class BakedLightmapData : public Resource { struct User { NodePath path; - Ref lightmap; + struct { + Ref single; + Ref layered; + } lightmap; + int lightmap_slice; + Rect2 lightmap_uv_rect; int instance_index; }; @@ -75,10 +83,12 @@ class BakedLightmapData : public Resource { void set_energy(float p_energy); float get_energy() const; - void add_user(const NodePath &p_path, const Ref &p_lightmap, int p_instance = -1); + void add_user(const NodePath &p_path, const Ref &p_lightmap, int p_lightmap_slice, const Rect2 &p_lightmap_uv_rect, int p_instance); int get_user_count() const; NodePath get_user_path(int p_user) const; - Ref get_user_lightmap(int p_user) const; + Ref get_user_lightmap(int p_user) const; + int get_user_lightmap_slice(int p_user) const; + Rect2 get_user_lightmap_uv_rect(int p_user) const; int get_user_instance(int p_user) const; void clear_users(); @@ -94,12 +104,8 @@ class BakedLightmap : public VisualInstance { enum BakeQuality { BAKE_QUALITY_LOW, BAKE_QUALITY_MEDIUM, - BAKE_QUALITY_HIGH - }; - - enum BakeMode { - BAKE_MODE_CONE_TRACE, - BAKE_MODE_RAY_TRACE, + BAKE_QUALITY_HIGH, + BAKE_QUALITY_ULTRA }; enum BakeError { @@ -107,108 +113,154 @@ class BakedLightmap : public VisualInstance { BAKE_ERROR_NO_SAVE_PATH, BAKE_ERROR_NO_MESHES, BAKE_ERROR_CANT_CREATE_IMAGE, - BAKE_ERROR_USER_ABORTED + BAKE_ERROR_LIGHTMAP_SIZE, + BAKE_ERROR_INVALID_MESH, + BAKE_ERROR_USER_ABORTED, + BAKE_ERROR_NO_LIGHTMAPPER }; - typedef void (*BakeBeginFunc)(int); - typedef bool (*BakeStepFunc)(int, const String &); - typedef void (*BakeEndFunc)(); + enum EnvironmentMode { + ENVIRONMENT_MODE_DISABLED, + ENVIRONMENT_MODE_SCENE, + ENVIRONMENT_MODE_CUSTOM_SKY, + ENVIRONMENT_MODE_CUSTOM_COLOR + }; -private: - float bake_cell_size; - float capture_cell_size; - Vector3 extents; - float bake_default_texels_per_unit; - float propagation; - float energy; - BakeQuality bake_quality; - BakeMode bake_mode; - bool hdr; - String image_path; + struct BakeStepUD { + Lightmapper::BakeStepFunc func; + void *ud; + float 
from_percent; + float to_percent; + }; - Ref light_data; + struct LightsFound { + Transform xform; + Light *light; + }; - struct PlotMesh { - Ref override_material; - Vector > instance_materials; + struct MeshesFound { + Transform xform; + NodePath node_path; + int32_t subindex; Ref mesh; - Transform local_xform; - NodePath path; - int instance_idx; + int32_t lightmap_scale; + Vector > overrides; + bool cast_shadows; + bool generate_lightmap; }; - struct PlotLight { - Light *light; - Transform local_xform; - }; +private: + Vector3 extents; + float default_texels_per_unit; + float bias; + BakeQuality bake_quality; + bool generate_atlas; + int max_atlas_size; + bool capture_enabled; + int bounces; + bool use_denoiser; + + EnvironmentMode environment_mode; + Ref environment_custom_sky; + Vector3 environment_custom_sky_rotation_degrees; + Color environment_custom_color; + float environment_custom_energy; + + BakeQuality capture_quality; + float capture_propagation; + float capture_cell_size; - void _find_meshes_and_lights(Node *p_at_node, List &plot_meshes, List &plot_lights); + String image_path; // (Deprecated property) - void _debug_bake(); + Ref light_data; void _assign_lightmaps(); void _clear_lightmaps(); - static bool _bake_time(void *ud, float p_secs, float p_progress); + void _get_material_images(const MeshesFound &p_found_mesh, Lightmapper::MeshData &r_mesh_data, Vector > &r_albedo_textures, Vector > &r_emission_textures); + Ref _get_irradiance_from_sky(Ref p_sky, Vector2i p_size); + Ref _get_irradiance_map(Ref p_env, Vector2i p_size); + void _find_meshes_and_lights(Node *p_at_node, Vector &meshes, Vector &lights); + Vector2i _compute_lightmap_size(const MeshesFound &p_mesh); - struct BakeTimeData { - String text; - int pass; - uint64_t last_step; - }; + static bool _lightmap_bake_step_function(float p_completion, const String &p_text, void *ud, bool p_refresh); protected: static void _bind_methods(); + void _validate_property(PropertyInfo &property) const; void _notification(int p_what); public: - static BakeBeginFunc bake_begin_function; - static BakeStepFunc bake_step_function; - static BakeEndFunc bake_end_function; + static Lightmapper::BakeStepFunc bake_step_function; + static Lightmapper::BakeStepFunc bake_substep_function; void set_light_data(const Ref &p_data); Ref get_light_data() const; - void set_bake_cell_size(float p_cell_size); - float get_bake_cell_size() const; - void set_capture_cell_size(float p_cell_size); float get_capture_cell_size() const; void set_extents(const Vector3 &p_extents); Vector3 get_extents() const; - void set_bake_default_texels_per_unit(const float &p_bake_texels_per_unit); - float get_bake_default_texels_per_unit() const; + void set_default_texels_per_unit(const float &p_extents); + float get_default_texels_per_unit() const; - void set_propagation(float p_propagation); - float get_propagation() const; + void set_capture_propagation(float p_propagation); + float get_capture_propagation() const; - void set_energy(float p_energy); - float get_energy() const; + void set_capture_quality(BakeQuality p_quality); + BakeQuality get_capture_quality() const; void set_bake_quality(BakeQuality p_quality); BakeQuality get_bake_quality() const; - void set_bake_mode(BakeMode p_mode); - BakeMode get_bake_mode() const; + void set_generate_atlas(bool p_enabled); + bool is_generate_atlas_enabled() const; + + void set_max_atlas_size(int p_size); + int get_max_atlas_size() const; - void set_hdr(bool p_enable); - bool is_hdr() const; + void set_capture_enabled(bool 
p_enable); + bool get_capture_enabled() const; void set_image_path(const String &p_path); String get_image_path() const; + void set_environment_mode(EnvironmentMode p_mode); + EnvironmentMode get_environment_mode() const; + + void set_environment_custom_sky(const Ref &p_sky); + Ref get_environment_custom_sky() const; + + void set_environment_custom_sky_rotation_degrees(const Vector3 &p_rotation); + Vector3 get_environment_custom_sky_rotation_degrees() const; + + void set_environment_custom_color(const Color &p_color); + Color get_environment_custom_color() const; + + void set_environment_custom_energy(float p_energy); + float get_environment_custom_energy() const; + + void set_use_denoiser(bool p_enable); + bool is_using_denoiser() const; + + void set_bounces(int p_bounces); + int get_bounces() const; + + void set_bias(float p_bias); + float get_bias() const; + AABB get_aabb() const; PoolVector get_faces(uint32_t p_usage_flags) const; - BakeError bake(Node *p_from_node, bool p_create_visual_debug = false); + BakeError bake(Node *p_from_node, String p_data_save_path = ""); BakedLightmap(); }; VARIANT_ENUM_CAST(BakedLightmap::BakeQuality); -VARIANT_ENUM_CAST(BakedLightmap::BakeMode); VARIANT_ENUM_CAST(BakedLightmap::BakeError); +VARIANT_ENUM_CAST(BakedLightmap::EnvironmentMode); #endif // BAKED_INDIRECT_LIGHT_H diff --git a/scene/3d/lightmapper.cpp b/scene/3d/lightmapper.cpp new file mode 100644 index 000000000000..9e5078ba95e0 --- /dev/null +++ b/scene/3d/lightmapper.cpp @@ -0,0 +1,76 @@ +/*************************************************************************/ +/* lightmapper.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ +/*************************************************************************/ + +#include "lightmapper.h" + +LightmapDenoiser *(*LightmapDenoiser::create_function)() = nullptr; + +Ref LightmapDenoiser::create() { + if (create_function) { + return Ref(create_function()); + } + return Ref(); +} + +LightmapRaycaster *(*LightmapRaycaster::create_function)() = nullptr; + +Ref LightmapRaycaster::create() { + if (create_function) { + return Ref(create_function()); + } + return Ref(); +} + +Lightmapper::CreateFunc Lightmapper::create_custom = nullptr; +Lightmapper::CreateFunc Lightmapper::create_gpu = nullptr; +Lightmapper::CreateFunc Lightmapper::create_cpu = nullptr; + +Ref Lightmapper::create() { + Lightmapper *lm = nullptr; + if (create_custom) { + lm = create_custom(); + } + + if (!lm && create_gpu) { + lm = create_gpu(); + } + + if (!lm && create_cpu) { + lm = create_cpu(); + } + if (!lm) { + return Ref(); + } else { + return Ref(lm); + } +} + +Lightmapper::Lightmapper() { +} diff --git a/scene/3d/lightmapper.h b/scene/3d/lightmapper.h new file mode 100644 index 000000000000..720f95a86460 --- /dev/null +++ b/scene/3d/lightmapper.h @@ -0,0 +1,196 @@ +/*************************************************************************/ +/* lightmapper.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2021 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2021 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef LIGHTMAPPER_H +#define LIGHTMAPPER_H + +#include "scene/resources/mesh.h" + +#if !defined(__aligned) + +#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__)) && !defined(__CYGWIN__) +#define __aligned(...) __declspec(align(__VA_ARGS__)) +#else +#define __aligned(...) 
__attribute__((aligned(__VA_ARGS__))) +#endif + +#endif + +class LightmapDenoiser : public Reference { + GDCLASS(LightmapDenoiser, Reference) +protected: + static LightmapDenoiser *(*create_function)(); + +public: + virtual Ref denoise_image(const Ref &p_image) = 0; + static Ref create(); +}; + +class LightmapRaycaster : public Reference { + GDCLASS(LightmapRaycaster, Reference) +protected: + static LightmapRaycaster *(*create_function)(); + +public: + // compatible with embree3 rays + struct __aligned(16) Ray { + const static unsigned int INVALID_GEOMETRY_ID = ((unsigned int)-1); // from rtcore_common.h + + /*! Default construction does nothing. */ + _FORCE_INLINE_ Ray() : + geomID(INVALID_GEOMETRY_ID) {} + + /*! Constructs a ray from origin, direction, and ray segment. Near + * has to be smaller than far. */ + _FORCE_INLINE_ Ray(const Vector3 &org, + const Vector3 &dir, + float tnear = 0.0f, + float tfar = INFINITY) : + org(org), + tnear(tnear), + dir(dir), + time(0.0f), + tfar(tfar), + mask(-1), + u(0.0), + v(0.0), + primID(INVALID_GEOMETRY_ID), + geomID(INVALID_GEOMETRY_ID), + instID(INVALID_GEOMETRY_ID) {} + + /*! Tests if we hit something. */ + _FORCE_INLINE_ explicit operator bool() const { return geomID != INVALID_GEOMETRY_ID; } + + public: + Vector3 org; //!< Ray origin + tnear + float tnear; //!< Start of ray segment + Vector3 dir; //!< Ray direction + tfar + float time; //!< Time of this ray for motion blur. + float tfar; //!< End of ray segment + unsigned int mask; //!< used to mask out objects during traversal + unsigned int id; //!< ray ID + unsigned int flags; //!< ray flags + + Vector3 normal; //!< Not normalized geometry normal + float u; //!< Barycentric u coordinate of hit + float v; //!< Barycentric v coordinate of hit + unsigned int primID; //!< primitive ID + unsigned int geomID; //!< geometry ID + unsigned int instID; //!< instance ID + }; + + virtual bool intersect(Ray &p_ray) = 0; + + virtual void intersect(Vector &r_rays) = 0; + + virtual void add_mesh(const Vector &p_vertices, const Vector &p_normals, const Vector &p_uv2s, unsigned int p_id) = 0; + virtual void set_mesh_alpha_texture(Ref p_alpha_texture, unsigned int p_id) = 0; + virtual void commit() = 0; + + virtual void set_mesh_filter(const Set &p_mesh_ids) = 0; + virtual void clear_mesh_filter() = 0; + + static Ref create(); +}; + +class Lightmapper : public Reference { + GDCLASS(Lightmapper, Reference) +public: + enum LightType { + LIGHT_TYPE_DIRECTIONAL, + LIGHT_TYPE_OMNI, + LIGHT_TYPE_SPOT + }; + + enum BakeError { + BAKE_ERROR_LIGHTMAP_TOO_SMALL, + BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES, + BAKE_ERROR_NO_MESHES, + BAKE_ERROR_USER_ABORTED, + BAKE_ERROR_NO_RAYCASTER, + BAKE_OK + }; + + enum BakeQuality { + BAKE_QUALITY_LOW, + BAKE_QUALITY_MEDIUM, + BAKE_QUALITY_HIGH, + BAKE_QUALITY_ULTRA, + }; + + typedef Lightmapper *(*CreateFunc)(); + + static CreateFunc create_custom; + static CreateFunc create_gpu; + static CreateFunc create_cpu; + +protected: +public: + typedef bool (*BakeStepFunc)(float, const String &, void *, bool); //progress, step description, userdata, force refresh + + struct MeshData { + struct TextureDef { + RID tex_rid; + Color mul; + Color add; + }; + + //triangle data + Vector points; + Vector uv; + Vector uv2; + Vector normal; + Vector albedo; + Vector emission; + Vector surface_facecounts; + Variant userdata; + }; + + virtual void add_albedo_texture(Ref p_texture) = 0; + virtual void add_emission_texture(Ref p_texture) = 0; + virtual void add_mesh(const MeshData &p_mesh, Vector2i 
p_size) = 0; + virtual void add_directional_light(bool p_bake_direct, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_indirect_multiplier) = 0; + virtual void add_omni_light(bool p_bake_direct, const Vector3 &p_position, const Color &p_color, float p_energy, float p_indirect_multiplier, float p_range, float p_attenuation) = 0; + virtual void add_spot_light(bool p_bake_direct, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_indirect_multiplier, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation) = 0; + virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, bool p_generate_atlas, int p_max_texture_size, const Ref &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_step_userdata = nullptr, BakeStepFunc p_substep_function = nullptr) = 0; + + virtual int get_bake_texture_count() const = 0; + virtual Ref get_bake_texture(int p_index) const = 0; + virtual int get_bake_mesh_count() const = 0; + virtual Variant get_bake_mesh_userdata(int p_index) const = 0; + virtual Rect2 get_bake_mesh_uv_scale(int p_index) const = 0; + virtual int get_bake_mesh_texture_slice(int p_index) const = 0; + + static Ref create(); + + Lightmapper(); +}; + +#endif // LIGHTMAPPER_H diff --git a/scene/3d/visual_instance.cpp b/scene/3d/visual_instance.cpp index a66b18e09867..7bea31af214a 100644 --- a/scene/3d/visual_instance.cpp +++ b/scene/3d/visual_instance.cpp @@ -170,6 +170,23 @@ Ref GeometryInstance::get_material_override() const { return material_override; } +void GeometryInstance::set_generate_lightmap(bool p_enabled) { + generate_lightmap = p_enabled; +} + +bool GeometryInstance::get_generate_lightmap() { + return generate_lightmap; +} + +void GeometryInstance::set_lightmap_scale(LightmapScale p_scale) { + ERR_FAIL_INDEX(p_scale, LIGHTMAP_SCALE_MAX); + lightmap_scale = p_scale; +} + +GeometryInstance::LightmapScale GeometryInstance::get_lightmap_scale() const { + return lightmap_scale; +} + void GeometryInstance::set_lod_min_distance(float p_dist) { lod_min_distance = p_dist; @@ -274,6 +291,12 @@ void GeometryInstance::_bind_methods() { ClassDB::bind_method(D_METHOD("set_cast_shadows_setting", "shadow_casting_setting"), &GeometryInstance::set_cast_shadows_setting); ClassDB::bind_method(D_METHOD("get_cast_shadows_setting"), &GeometryInstance::get_cast_shadows_setting); + ClassDB::bind_method(D_METHOD("set_generate_lightmap", "enabled"), &GeometryInstance::set_generate_lightmap); + ClassDB::bind_method(D_METHOD("get_generate_lightmap"), &GeometryInstance::get_generate_lightmap); + + ClassDB::bind_method(D_METHOD("set_lightmap_scale", "scale"), &GeometryInstance::set_lightmap_scale); + ClassDB::bind_method(D_METHOD("get_lightmap_scale"), &GeometryInstance::get_lightmap_scale); + ClassDB::bind_method(D_METHOD("set_lod_max_hysteresis", "mode"), &GeometryInstance::set_lod_max_hysteresis); ClassDB::bind_method(D_METHOD("get_lod_max_hysteresis"), &GeometryInstance::get_lod_max_hysteresis); @@ -297,7 +320,11 @@ void GeometryInstance::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "material_override", PROPERTY_HINT_RESOURCE_TYPE, "ShaderMaterial,SpatialMaterial"), "set_material_override", "get_material_override"); ADD_PROPERTY(PropertyInfo(Variant::INT, "cast_shadow", PROPERTY_HINT_ENUM, "Off,On,Double-Sided,Shadows Only"), "set_cast_shadows_setting", "get_cast_shadows_setting"); 
ADD_PROPERTY(PropertyInfo(Variant::REAL, "extra_cull_margin", PROPERTY_HINT_RANGE, "0,16384,0.01"), "set_extra_cull_margin", "get_extra_cull_margin"); + + ADD_GROUP("Baked Light", ""); ADD_PROPERTYI(PropertyInfo(Variant::BOOL, "use_in_baked_light"), "set_flag", "get_flag", FLAG_USE_BAKED_LIGHT); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "generate_lightmap"), "set_generate_lightmap", "get_generate_lightmap"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "lightmap_scale", PROPERTY_HINT_ENUM, "1x,2x,4x,8x"), "set_lightmap_scale", "get_lightmap_scale"); ADD_GROUP("LOD", "lod_"); ADD_PROPERTY(PropertyInfo(Variant::INT, "lod_min_distance", PROPERTY_HINT_RANGE, "0,32768,0.01"), "set_lod_min_distance", "get_lod_min_distance"); @@ -307,6 +334,12 @@ void GeometryInstance::_bind_methods() { //ADD_SIGNAL( MethodInfo("visibility_changed")); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_1X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_2X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_4X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_8X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_MAX); + BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_OFF); BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_ON); BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_DOUBLE_SIDED); @@ -329,5 +362,7 @@ GeometryInstance::GeometryInstance() { shadow_casting_setting = SHADOW_CASTING_SETTING_ON; extra_cull_margin = 0; + generate_lightmap = true; + lightmap_scale = LightmapScale::LIGHTMAP_SCALE_1X; //VS::get_singleton()->instance_geometry_set_baked_light_texture_index(get_instance(),0); } diff --git a/scene/3d/visual_instance.h b/scene/3d/visual_instance.h index ce67b4dd6a3a..6c12ce5b2a23 100644 --- a/scene/3d/visual_instance.h +++ b/scene/3d/visual_instance.h @@ -91,6 +91,14 @@ class GeometryInstance : public VisualInstance { FLAG_MAX = VS::INSTANCE_FLAG_MAX, }; + enum LightmapScale { + LIGHTMAP_SCALE_1X, + LIGHTMAP_SCALE_2X, + LIGHTMAP_SCALE_4X, + LIGHTMAP_SCALE_8X, + LIGHTMAP_SCALE_MAX, + }; + enum ShadowCastingSetting { SHADOW_CASTING_SETTING_OFF = VS::SHADOW_CASTING_SETTING_OFF, SHADOW_CASTING_SETTING_ON = VS::SHADOW_CASTING_SETTING_ON, @@ -100,6 +108,8 @@ class GeometryInstance : public VisualInstance { private: bool flags[FLAG_MAX]; + bool generate_lightmap; + LightmapScale lightmap_scale; ShadowCastingSetting shadow_casting_setting; Ref material_override; float lod_min_distance; @@ -120,6 +130,15 @@ class GeometryInstance : public VisualInstance { void set_cast_shadows_setting(ShadowCastingSetting p_shadow_casting_setting); ShadowCastingSetting get_cast_shadows_setting() const; + void set_bake_cast_shadows(bool p_enabled); + bool get_bake_cast_shadows(); + + void set_generate_lightmap(bool p_enabled); + bool get_generate_lightmap(); + + void set_lightmap_scale(LightmapScale p_scale); + LightmapScale get_lightmap_scale() const; + void set_lod_min_distance(float p_dist); float get_lod_min_distance() const; @@ -144,6 +163,7 @@ class GeometryInstance : public VisualInstance { }; VARIANT_ENUM_CAST(GeometryInstance::Flags); +VARIANT_ENUM_CAST(GeometryInstance::LightmapScale); VARIANT_ENUM_CAST(GeometryInstance::ShadowCastingSetting); #endif diff --git a/scene/3d/voxel_light_baker.cpp b/scene/3d/voxel_light_baker.cpp index 48d722ff2ad7..0af8cbf20c18 100644 --- a/scene/3d/voxel_light_baker.cpp +++ b/scene/3d/voxel_light_baker.cpp @@ -718,12 +718,10 @@ void VoxelLightBaker::_init_light_plot(int p_idx, int p_level, int p_x, int p_y, } } -void VoxelLightBaker::begin_bake_light(BakeQuality p_quality, BakeMode p_bake_mode, float p_propagation, float p_energy) { +void 
VoxelLightBaker::begin_bake_light(BakeQuality p_quality, float p_propagation) { _check_init_light(); propagation = p_propagation; bake_quality = p_quality; - bake_mode = p_bake_mode; - energy = p_energy; } void VoxelLightBaker::_check_init_light() { @@ -733,7 +731,6 @@ void VoxelLightBaker::_check_init_light() { leaf_voxel_count = 0; _fixup_plot(0, 0); //pre fixup, so normal, albedo, emission, etc. work for lighting. bake_light.resize(bake_cells.size()); - print_line("bake light size: " + itos(bake_light.size())); //zeromem(bake_light.ptrw(), bake_light.size() * sizeof(Light)); first_leaf = -1; _init_light_plot(0, 0, 0, 0, 0, CHILD_EMPTY); @@ -1289,872 +1286,6 @@ void VoxelLightBaker::_fixup_plot(int p_idx, int p_level) { } } -//make sure any cell (save for the root) has an empty cell previous to it, so it can be interpolated into - -void VoxelLightBaker::_plot_triangle(Vector2 *vertices, Vector3 *positions, Vector3 *normals, LightMap *pixels, int width, int height) { - - int x[3]; - int y[3]; - - for (int j = 0; j < 3; j++) { - - x[j] = vertices[j].x * width; - y[j] = vertices[j].y * height; - //x[j] = CLAMP(x[j], 0, bt.width - 1); - //y[j] = CLAMP(y[j], 0, bt.height - 1); - } - - // sort the points vertically - if (y[1] > y[2]) { - SWAP(x[1], x[2]); - SWAP(y[1], y[2]); - SWAP(positions[1], positions[2]); - SWAP(normals[1], normals[2]); - } - if (y[0] > y[1]) { - SWAP(x[0], x[1]); - SWAP(y[0], y[1]); - SWAP(positions[0], positions[1]); - SWAP(normals[0], normals[1]); - } - if (y[1] > y[2]) { - SWAP(x[1], x[2]); - SWAP(y[1], y[2]); - SWAP(positions[1], positions[2]); - SWAP(normals[1], normals[2]); - } - - double dx_far = double(x[2] - x[0]) / (y[2] - y[0] + 1); - double dx_upper = double(x[1] - x[0]) / (y[1] - y[0] + 1); - double dx_low = double(x[2] - x[1]) / (y[2] - y[1] + 1); - double xf = x[0]; - double xt = x[0] + dx_upper; // if y[0] == y[1], special case - for (int yi = y[0]; yi <= (y[2] > height - 1 ? height - 1 : y[2]); yi++) { - if (yi >= 0) { - for (int xi = (xf > 0 ? int(xf) : 0); xi <= (xt < width ? xt : width - 1); xi++) { - //pixels[int(x + y * width)] = color; - - Vector2 v0 = Vector2(x[1] - x[0], y[1] - y[0]); - Vector2 v1 = Vector2(x[2] - x[0], y[2] - y[0]); - //vertices[2] - vertices[0]; - Vector2 v2 = Vector2(xi - x[0], yi - y[0]); - float d00 = v0.dot(v0); - float d01 = v0.dot(v1); - float d11 = v1.dot(v1); - float d20 = v2.dot(v0); - float d21 = v2.dot(v1); - float denom = (d00 * d11 - d01 * d01); - Vector3 pos; - Vector3 normal; - if (denom == 0) { - pos = positions[0]; - normal = normals[0]; - } else { - float v = (d11 * d20 - d01 * d21) / denom; - float w = (d00 * d21 - d01 * d20) / denom; - float u = 1.0f - v - w; - pos = positions[0] * u + positions[1] * v + positions[2] * w; - normal = normals[0] * u + normals[1] * v + normals[2] * w; - } - - int ofs = yi * width + xi; - pixels[ofs].normal = normal; - pixels[ofs].pos = pos; - } - - for (int xi = (xf < width ? int(xf) : width - 1); xi >= (xt > 0 ? 
xt : 0); xi--) { - //pixels[int(x + y * width)] = color; - Vector2 v0 = Vector2(x[1] - x[0], y[1] - y[0]); - Vector2 v1 = Vector2(x[2] - x[0], y[2] - y[0]); - //vertices[2] - vertices[0]; - Vector2 v2 = Vector2(xi - x[0], yi - y[0]); - float d00 = v0.dot(v0); - float d01 = v0.dot(v1); - float d11 = v1.dot(v1); - float d20 = v2.dot(v0); - float d21 = v2.dot(v1); - float denom = (d00 * d11 - d01 * d01); - Vector3 pos; - Vector3 normal; - if (denom == 0) { - pos = positions[0]; - normal = normals[0]; - } else { - float v = (d11 * d20 - d01 * d21) / denom; - float w = (d00 * d21 - d01 * d20) / denom; - float u = 1.0f - v - w; - pos = positions[0] * u + positions[1] * v + positions[2] * w; - normal = normals[0] * u + normals[1] * v + normals[2] * w; - } - - int ofs = yi * width + xi; - pixels[ofs].normal = normal; - pixels[ofs].pos = pos; - } - } - xf += dx_far; - if (yi < y[1]) - xt += dx_upper; - else - xt += dx_low; - } -} - -void VoxelLightBaker::_sample_baked_octree_filtered_and_anisotropic(const Vector3 &p_posf, const Vector3 &p_direction, float p_level, Vector3 &r_color, float &r_alpha) { - - int size = 1 << (cell_subdiv - 1); - - int clamp_v = size - 1; - //first of all, clamp - Vector3 pos; - pos.x = CLAMP(p_posf.x, 0, clamp_v); - pos.y = CLAMP(p_posf.y, 0, clamp_v); - pos.z = CLAMP(p_posf.z, 0, clamp_v); - - float level = (cell_subdiv - 1) - p_level; - - int target_level; - float level_filter; - if (level <= 0.0) { - level_filter = 0; - target_level = 0; - } else { - target_level = Math::ceil(level); - level_filter = target_level - level; - } - - const Cell *cells = bake_cells.ptr(); - const Light *light = bake_light.ptr(); - - Vector3 color[2][8]; - float alpha[2][8]; - zeromem(alpha, sizeof(float) * 2 * 8); - - //find cell at given level first - - for (int c = 0; c < 2; c++) { - - int current_level = MAX(0, target_level - c); - int level_cell_size = (1 << (cell_subdiv - 1)) >> current_level; - - for (int n = 0; n < 8; n++) { - - int x = int(pos.x); - int y = int(pos.y); - int z = int(pos.z); - - if (n & 1) - x += level_cell_size; - if (n & 2) - y += level_cell_size; - if (n & 4) - z += level_cell_size; - - int ofs_x = 0; - int ofs_y = 0; - int ofs_z = 0; - - x = CLAMP(x, 0, clamp_v); - y = CLAMP(y, 0, clamp_v); - z = CLAMP(z, 0, clamp_v); - - int half = size / 2; - uint32_t cell = 0; - for (int i = 0; i < current_level; i++) { - - const Cell *bc = &cells[cell]; - - int child = 0; - if (x >= ofs_x + half) { - child |= 1; - ofs_x += half; - } - if (y >= ofs_y + half) { - child |= 2; - ofs_y += half; - } - if (z >= ofs_z + half) { - child |= 4; - ofs_z += half; - } - - cell = bc->children[child]; - if (cell == CHILD_EMPTY) - break; - - half >>= 1; - } - - if (cell == CHILD_EMPTY) { - alpha[c][n] = 0; - } else { - alpha[c][n] = cells[cell].alpha; - - for (int i = 0; i < 6; i++) { - //anisotropic read light - float amount = p_direction.dot(aniso_normal[i]); - if (amount < 0) - amount = 0; - color[c][n].x += light[cell].accum[i][0] * amount; - color[c][n].y += light[cell].accum[i][1] * amount; - color[c][n].z += light[cell].accum[i][2] * amount; - } - - color[c][n].x += cells[cell].emission[0]; - color[c][n].y += cells[cell].emission[1]; - color[c][n].z += cells[cell].emission[2]; - } - } - } - - float target_level_size = size >> target_level; - Vector3 pos_fract[2]; - - pos_fract[0].x = Math::fmod(pos.x, target_level_size) / target_level_size; - pos_fract[0].y = Math::fmod(pos.y, target_level_size) / target_level_size; - pos_fract[0].z = Math::fmod(pos.z, target_level_size) / 
target_level_size; - - target_level_size = size >> MAX(0, target_level - 1); - - pos_fract[1].x = Math::fmod(pos.x, target_level_size) / target_level_size; - pos_fract[1].y = Math::fmod(pos.y, target_level_size) / target_level_size; - pos_fract[1].z = Math::fmod(pos.z, target_level_size) / target_level_size; - - float alpha_interp[2]; - Vector3 color_interp[2]; - - for (int i = 0; i < 2; i++) { - - Vector3 color_x00 = color[i][0].linear_interpolate(color[i][1], pos_fract[i].x); - Vector3 color_xy0 = color[i][2].linear_interpolate(color[i][3], pos_fract[i].x); - Vector3 blend_z0 = color_x00.linear_interpolate(color_xy0, pos_fract[i].y); - - Vector3 color_x0z = color[i][4].linear_interpolate(color[i][5], pos_fract[i].x); - Vector3 color_xyz = color[i][6].linear_interpolate(color[i][7], pos_fract[i].x); - Vector3 blend_z1 = color_x0z.linear_interpolate(color_xyz, pos_fract[i].y); - - color_interp[i] = blend_z0.linear_interpolate(blend_z1, pos_fract[i].z); - - float alpha_x00 = Math::lerp(alpha[i][0], alpha[i][1], pos_fract[i].x); - float alpha_xy0 = Math::lerp(alpha[i][2], alpha[i][3], pos_fract[i].x); - float alpha_z0 = Math::lerp(alpha_x00, alpha_xy0, pos_fract[i].y); - - float alpha_x0z = Math::lerp(alpha[i][4], alpha[i][5], pos_fract[i].x); - float alpha_xyz = Math::lerp(alpha[i][6], alpha[i][7], pos_fract[i].x); - float alpha_z1 = Math::lerp(alpha_x0z, alpha_xyz, pos_fract[i].y); - - alpha_interp[i] = Math::lerp(alpha_z0, alpha_z1, pos_fract[i].z); - } - - r_color = color_interp[0].linear_interpolate(color_interp[1], level_filter); - r_alpha = Math::lerp(alpha_interp[0], alpha_interp[1], level_filter); -} - -Vector3 VoxelLightBaker::_voxel_cone_trace(const Vector3 &p_pos, const Vector3 &p_normal, float p_aperture) { - - float bias = 2.5; - float max_distance = (Vector3(1, 1, 1) * (1 << (cell_subdiv - 1))).length(); - - float dist = bias; - float alpha = 0.0; - Vector3 color; - - Vector3 scolor; - float salpha; - - while (dist < max_distance && alpha < 0.95) { - float diameter = MAX(1.0, 2.0 * p_aperture * dist); - _sample_baked_octree_filtered_and_anisotropic(p_pos + dist * p_normal, p_normal, log2(diameter), scolor, salpha); - float a = (1.0 - alpha); - color += scolor * a; - alpha += a * salpha; - dist += diameter * 0.5; - } - - /*if (blend_ambient) { - color.rgb = mix(ambient,color.rgb,min(1.0,alpha/0.95)); - }*/ - - return color; -} - -Vector3 VoxelLightBaker::_compute_pixel_light_at_pos(const Vector3 &p_pos, const Vector3 &p_normal) { - - //find arbitrary tangent and bitangent, then build a matrix - Vector3 v0 = Math::abs(p_normal.z) < 0.999 ? 
Vector3(0, 0, 1) : Vector3(0, 1, 0); - Vector3 tangent = v0.cross(p_normal).normalized(); - Vector3 bitangent = tangent.cross(p_normal).normalized(); - Basis normal_xform = Basis(tangent, bitangent, p_normal).transposed(); - - const Vector3 *cone_dirs = NULL; - const float *cone_weights = NULL; - int cone_dir_count = 0; - float cone_aperture = 0; - - switch (bake_quality) { - case BAKE_QUALITY_LOW: { - //default quality - static const Vector3 dirs[4] = { - Vector3(Math_SQRT12, 0, Math_SQRT12), - Vector3(0, Math_SQRT12, Math_SQRT12), - Vector3(-Math_SQRT12, 0, Math_SQRT12), - Vector3(0, -Math_SQRT12, Math_SQRT12) - }; - - static const float weights[4] = { 0.25, 0.25, 0.25, 0.25 }; - - cone_dirs = dirs; - cone_dir_count = 4; - cone_aperture = 1.0; // tan(angle) 90 degrees - cone_weights = weights; - } break; - case BAKE_QUALITY_MEDIUM: { - //default quality - static const Vector3 dirs[6] = { - Vector3(0, 0, 1), - Vector3(0.866025, 0, 0.5), - Vector3(0.267617, 0.823639, 0.5), - Vector3(-0.700629, 0.509037, 0.5), - Vector3(-0.700629, -0.509037, 0.5), - Vector3(0.267617, -0.823639, 0.5) - }; - static const float weights[6] = { 0.25f, 0.15f, 0.15f, 0.15f, 0.15f, 0.15f }; - // - cone_dirs = dirs; - cone_dir_count = 6; - cone_aperture = 0.577; // tan(angle) 60 degrees - cone_weights = weights; - } break; - case BAKE_QUALITY_HIGH: { - - //high qualily - static const Vector3 dirs[10] = { - Vector3(0.8781648411741658, 0.0, 0.478358141694643), - Vector3(0.5369754325592234, 0.6794204427701518, 0.5000452447267606), - Vector3(-0.19849436573466497, 0.8429904390140635, 0.49996710542041645), - Vector3(-0.7856196499811189, 0.3639120321329737, 0.5003696617825604), - Vector3(-0.7856196499811189, -0.3639120321329737, 0.5003696617825604), - Vector3(-0.19849436573466497, -0.8429904390140635, 0.49996710542041645), - Vector3(0.5369754325592234, -0.6794204427701518, 0.5000452447267606), - Vector3(-0.4451656858129485, 0.0, 0.8954482185892644), - Vector3(0.19124006749743122, 0.39355745585016605, 0.8991883926788214), - Vector3(0.19124006749743122, -0.39355745585016605, 0.8991883926788214), - }; - static const float weights[10] = { 0.08571f, 0.08571f, 0.08571f, 0.08571f, 0.08571f, 0.08571f, 0.08571f, 0.133333f, 0.133333f, 0.13333f }; - cone_dirs = dirs; - cone_dir_count = 10; - cone_aperture = 0.404; // tan(angle) 45 degrees - cone_weights = weights; - } break; - } - - Vector3 accum; - - for (int i = 0; i < cone_dir_count; i++) { - Vector3 dir = normal_xform.xform(cone_dirs[i]).normalized(); //normal may not completely correct when transformed to cell - accum += _voxel_cone_trace(p_pos, dir, cone_aperture) * cone_weights[i]; - } - - return accum; -} - -_ALWAYS_INLINE_ uint32_t xorshift32(uint32_t *state) { - /* Algorithm "xor" from p. 4 of Marsaglia, "Xorshift RNGs" */ - uint32_t x = *state; - x ^= x << 13; - x ^= x >> 17; - x ^= x << 5; - *state = x; - return x; -} - -Vector3 VoxelLightBaker::_compute_ray_trace_at_pos(const Vector3 &p_pos, const Vector3 &p_normal) { - - int samples_per_quality[3] = { 48, 128, 512 }; - - int samples = samples_per_quality[bake_quality]; - - //create a basis in Z - Vector3 v0 = Math::abs(p_normal.z) < 0.999 ? 
Vector3(0, 0, 1) : Vector3(0, 1, 0); - Vector3 tangent = v0.cross(p_normal).normalized(); - Vector3 bitangent = tangent.cross(p_normal).normalized(); - Basis normal_xform = Basis(tangent, bitangent, p_normal).transposed(); - - float bias = 1.5; - int max_level = cell_subdiv - 1; - int size = 1 << max_level; - - Vector3 accum; - float spread = Math::deg2rad(80.0); - - const Light *light = bake_light.ptr(); - const Cell *cells = bake_cells.ptr(); - - uint32_t local_rng_state = rand(); //needs to be fixed again - - for (int i = 0; i < samples; i++) { - - float random_angle1 = (((xorshift32(&local_rng_state) % 65535) / 65535.0) * 2.0 - 1.0) * spread; - Vector3 axis(0, sin(random_angle1), cos(random_angle1)); - float random_angle2 = ((xorshift32(&local_rng_state) % 65535) / 65535.0) * Math_PI * 2.0; - Basis rot(Vector3(0, 0, 1), random_angle2); - axis = rot.xform(axis); - - Vector3 direction = normal_xform.xform(axis).normalized(); - - Vector3 advance = direction * _get_normal_advance(direction); - - Vector3 pos = p_pos /*+ Vector3(0.5, 0.5, 0.5)*/ + advance * bias; - - uint32_t cell = CHILD_EMPTY; - - while (cell == CHILD_EMPTY) { - - int x = int(pos.x); - int y = int(pos.y); - int z = int(pos.z); - - int ofs_x = 0; - int ofs_y = 0; - int ofs_z = 0; - int half = size / 2; - - if (x < 0 || x >= size) - break; - if (y < 0 || y >= size) - break; - if (z < 0 || z >= size) - break; - - //int level_limit = max_level; - - cell = 0; //start from root - for (int j = 0; j < max_level; j++) { - - const Cell *bc = &cells[cell]; - - int child = 0; - if (x >= ofs_x + half) { - child |= 1; - ofs_x += half; - } - if (y >= ofs_y + half) { - child |= 2; - ofs_y += half; - } - if (z >= ofs_z + half) { - child |= 4; - ofs_z += half; - } - - cell = bc->children[child]; - if (unlikely(cell == CHILD_EMPTY)) - break; - - half >>= 1; - } - - pos += advance; - } - - if (unlikely(cell != CHILD_EMPTY)) { - for (int j = 0; j < 6; j++) { - //anisotropic read light - float amount = direction.dot(aniso_normal[j]); - if (amount <= 0) - continue; - accum.x += light[cell].accum[j][0] * amount; - accum.y += light[cell].accum[j][1] * amount; - accum.z += light[cell].accum[j][2] * amount; - } - accum.x += cells[cell].emission[0]; - accum.y += cells[cell].emission[1]; - accum.z += cells[cell].emission[2]; - } - } - - // Make sure we don't reset this thread's RNG state - - return accum / samples; -} - -void VoxelLightBaker::_lightmap_bake_point(uint32_t p_x, LightMap *p_line) { - - LightMap *pixel = &p_line[p_x]; - if (pixel->pos == Vector3()) - return; - switch (bake_mode) { - case BAKE_MODE_CONE_TRACE: { - pixel->light = _compute_pixel_light_at_pos(pixel->pos, pixel->normal) * energy; - } break; - case BAKE_MODE_RAY_TRACE: { - pixel->light = _compute_ray_trace_at_pos(pixel->pos, pixel->normal) * energy; - } break; - } -} - -Error VoxelLightBaker::make_lightmap(const Transform &p_xform, Ref &p_mesh, float default_texels_per_unit, LightMapData &r_lightmap, bool (*p_bake_time_func)(void *, float, float), void *p_bake_time_ud) { - - //transfer light information to a lightmap - Ref mesh = p_mesh; - - //step 1 - create lightmap - int width; - int height; - Vector lightmap; - Transform xform = to_cell_space * p_xform; - if (mesh->get_lightmap_size_hint() == Size2()) { - double area = 0; - double uv_area = 0; - for (int i = 0; i < mesh->get_surface_count(); i++) { - Array arrays = mesh->surface_get_arrays(i); - PoolVector vertices = arrays[Mesh::ARRAY_VERTEX]; - PoolVector uv2 = arrays[Mesh::ARRAY_TEX_UV2]; - PoolVector indices = 
arrays[Mesh::ARRAY_INDEX]; - - ERR_FAIL_COND_V(vertices.size() == 0, ERR_INVALID_PARAMETER); - ERR_FAIL_COND_V(uv2.size() == 0, ERR_INVALID_PARAMETER); - - int vc = vertices.size(); - PoolVector::Read vr = vertices.read(); - PoolVector::Read u2r = uv2.read(); - PoolVector::Read ir; - int ic = 0; - - if (indices.size()) { - ic = indices.size(); - ir = indices.read(); - } - - int faces = ic ? ic / 3 : vc / 3; - for (int j = 0; j < faces; j++) { - Vector3 vertex[3]; - Vector2 uv[3]; - - for (int k = 0; k < 3; k++) { - int idx = ic ? ir[j * 3 + k] : j * 3 + k; - vertex[k] = xform.xform(vr[idx]); - uv[k] = u2r[idx]; - } - - Vector3 p1 = vertex[0]; - Vector3 p2 = vertex[1]; - Vector3 p3 = vertex[2]; - double a = p1.distance_to(p2); - double b = p2.distance_to(p3); - double c = p3.distance_to(p1); - double halfPerimeter = (a + b + c) / 2.0; - area += sqrt(halfPerimeter * (halfPerimeter - a) * (halfPerimeter - b) * (halfPerimeter - c)); - - Vector2 uv_p1 = uv[0]; - Vector2 uv_p2 = uv[1]; - Vector2 uv_p3 = uv[2]; - double uv_a = uv_p1.distance_to(uv_p2); - double uv_b = uv_p2.distance_to(uv_p3); - double uv_c = uv_p3.distance_to(uv_p1); - double uv_halfPerimeter = (uv_a + uv_b + uv_c) / 2.0; - uv_area += sqrt(uv_halfPerimeter * (uv_halfPerimeter - uv_a) * (uv_halfPerimeter - uv_b) * (uv_halfPerimeter - uv_c)); - } - } - - if (uv_area < 0.0001f) { - uv_area = 1.0; - } - - int pixels = (ceil((1.0 / sqrt(uv_area)) * sqrt(area * default_texels_per_unit))); - width = height = CLAMP(pixels, 2, 4096); - } else { - width = mesh->get_lightmap_size_hint().x; - height = mesh->get_lightmap_size_hint().y; - } - - lightmap.resize(width * height); - - //step 2 plot faces to lightmap - for (int i = 0; i < mesh->get_surface_count(); i++) { - Array arrays = mesh->surface_get_arrays(i); - PoolVector vertices = arrays[Mesh::ARRAY_VERTEX]; - PoolVector normals = arrays[Mesh::ARRAY_NORMAL]; - PoolVector uv2 = arrays[Mesh::ARRAY_TEX_UV2]; - PoolVector indices = arrays[Mesh::ARRAY_INDEX]; - - ERR_FAIL_COND_V(vertices.size() == 0, ERR_INVALID_PARAMETER); - ERR_FAIL_COND_V(normals.size() == 0, ERR_INVALID_PARAMETER); - ERR_FAIL_COND_V(uv2.size() == 0, ERR_INVALID_PARAMETER); - - int vc = vertices.size(); - PoolVector::Read vr = vertices.read(); - PoolVector::Read nr = normals.read(); - PoolVector::Read u2r = uv2.read(); - PoolVector::Read ir; - int ic = 0; - - if (indices.size()) { - ic = indices.size(); - ir = indices.read(); - } - - int faces = ic ? ic / 3 : vc / 3; - for (int j = 0; j < faces; j++) { - Vector3 vertex[3]; - Vector3 normal[3]; - Vector2 uv[3]; - - for (int k = 0; k < 3; k++) { - int idx = ic ? ir[j * 3 + k] : j * 3 + k; - vertex[k] = xform.xform(vr[idx]); - normal[k] = xform.basis.xform(nr[idx]).normalized(); - uv[k] = u2r[idx]; - } - - _plot_triangle(uv, vertex, normal, lightmap.ptrw(), width, height); - } - } - - //step 3 perform voxel cone trace on lightmap pixels - { - LightMap *lightmap_ptr = lightmap.ptrw(); - uint64_t begin_time = OS::get_singleton()->get_ticks_usec(); - volatile int lines = 0; - - // make sure our OS-level rng is seeded - - for (int i = 0; i < height; i++) { - - thread_process_array(width, this, &VoxelLightBaker::_lightmap_bake_point, &lightmap_ptr[i * width]); - - lines = MAX(lines, i); //for multithread - if (p_bake_time_func) { - uint64_t elapsed = OS::get_singleton()->get_ticks_usec() - begin_time; - float elapsed_sec = double(elapsed) / 1000000.0; - float remaining = lines < 1 ? 
0 : (elapsed_sec / lines) * (height - lines - 1); - if (p_bake_time_func(p_bake_time_ud, remaining, lines / float(height))) { - return ERR_SKIP; - } - } - } - - if (bake_mode == BAKE_MODE_RAY_TRACE) { - //blur - //gauss kernel, 7 step sigma 2 - static const float gauss_kernel[4] = { 0.214607f, 0.189879f, 0.131514f, 0.071303f }; - //horizontal pass - for (int i = 0; i < height; i++) { - for (int j = 0; j < width; j++) { - if (lightmap_ptr[i * width + j].normal == Vector3()) { - continue; //empty - } - float gauss_sum = gauss_kernel[0]; - Vector3 accum = lightmap_ptr[i * width + j].light * gauss_kernel[0]; - for (int k = 1; k < 4; k++) { - int new_x = j + k; - if (new_x >= width || lightmap_ptr[i * width + new_x].normal == Vector3()) - break; - gauss_sum += gauss_kernel[k]; - accum += lightmap_ptr[i * width + new_x].light * gauss_kernel[k]; - } - for (int k = 1; k < 4; k++) { - int new_x = j - k; - if (new_x < 0 || lightmap_ptr[i * width + new_x].normal == Vector3()) - break; - gauss_sum += gauss_kernel[k]; - accum += lightmap_ptr[i * width + new_x].light * gauss_kernel[k]; - } - - lightmap_ptr[i * width + j].pos = accum /= gauss_sum; - } - } - //vertical pass - for (int i = 0; i < height; i++) { - for (int j = 0; j < width; j++) { - if (lightmap_ptr[i * width + j].normal == Vector3()) - continue; //empty, don't write over it anyway - float gauss_sum = gauss_kernel[0]; - Vector3 accum = lightmap_ptr[i * width + j].pos * gauss_kernel[0]; - for (int k = 1; k < 4; k++) { - int new_y = i + k; - if (new_y >= height || lightmap_ptr[new_y * width + j].normal == Vector3()) - break; - gauss_sum += gauss_kernel[k]; - accum += lightmap_ptr[new_y * width + j].pos * gauss_kernel[k]; - } - for (int k = 1; k < 4; k++) { - int new_y = i - k; - if (new_y < 0 || lightmap_ptr[new_y * width + j].normal == Vector3()) - break; - gauss_sum += gauss_kernel[k]; - accum += lightmap_ptr[new_y * width + j].pos * gauss_kernel[k]; - } - - lightmap_ptr[i * width + j].light = accum /= gauss_sum; - } - } - } - - //add directional light (do this after blur) - { - const Cell *cells = bake_cells.ptr(); - const Light *light = bake_light.ptr(); -#ifdef _OPENMP -#pragma omp parallel -#endif - for (int i = 0; i < height; i++) { -#ifdef _OPENMP -#pragma omp parallel for schedule(dynamic, 1) -#endif - for (int j = 0; j < width; j++) { - - //if (i == 125 && j == 280) { - - LightMap *pixel = &lightmap_ptr[i * width + j]; - if (pixel->pos == Vector3()) - continue; //unused, skipe - - int x = int(pixel->pos.x) - 1; - int y = int(pixel->pos.y) - 1; - int z = int(pixel->pos.z) - 1; - Color accum; - int size = 1 << (cell_subdiv - 1); - - int found = 0; - - for (int k = 0; k < 8; k++) { - - int ofs_x = x; - int ofs_y = y; - int ofs_z = z; - - if (k & 1) - ofs_x++; - if (k & 2) - ofs_y++; - if (k & 4) - ofs_z++; - - if (x < 0 || x >= size) - continue; - if (y < 0 || y >= size) - continue; - if (z < 0 || z >= size) - continue; - - uint32_t cell = _find_cell_at_pos(cells, ofs_x, ofs_y, ofs_z); - - if (cell == CHILD_EMPTY) - continue; - for (int l = 0; l < 6; l++) { - float s = pixel->normal.dot(aniso_normal[l]); - if (s < 0) - s = 0; - accum.r += light[cell].direct_accum[l][0] * s; - accum.g += light[cell].direct_accum[l][1] * s; - accum.b += light[cell].direct_accum[l][2] * s; - } - found++; - } - if (found) { - accum /= found; - pixel->light.x += accum.r; - pixel->light.y += accum.g; - pixel->light.z += accum.b; - } - } - } - } - - { - //fill gaps with neighbour vertices to avoid filter fades to black on edges - - for (int i = 0; i < height; 
i++) { - for (int j = 0; j < width; j++) { - if (lightmap_ptr[i * width + j].normal != Vector3()) { - continue; //filled, skip - } - - //this can't be made separatable.. - - int closest_i = -1, closest_j = 1; - float closest_dist = 1e20; - - const int margin = 3; - for (int y = i - margin; y <= i + margin; y++) { - for (int x = j - margin; x <= j + margin; x++) { - - if (x == j && y == i) - continue; - if (x < 0 || x >= width) - continue; - if (y < 0 || y >= height) - continue; - if (lightmap_ptr[y * width + x].normal == Vector3()) - continue; //also ensures that blitted stuff is not reused - - float dist = Vector2(i - y, j - x).length(); - if (dist > closest_dist) - continue; - - closest_dist = dist; - closest_i = y; - closest_j = x; - } - } - - if (closest_i != -1) { - lightmap_ptr[i * width + j].light = lightmap_ptr[closest_i * width + closest_j].light; - } - } - } - } - - { - //fill the lightmap data - r_lightmap.width = width; - r_lightmap.height = height; - r_lightmap.light.resize(lightmap.size() * 3); - PoolVector::Write w = r_lightmap.light.write(); - for (int i = 0; i < lightmap.size(); i++) { - w[i * 3 + 0] = lightmap[i].light.x; - w[i * 3 + 1] = lightmap[i].light.y; - w[i * 3 + 2] = lightmap[i].light.z; - } - } - -#if 0 // Enable for debugging. - { - PoolVector img; - int ls = lightmap.size(); - img.resize(ls * 3); - { - PoolVector::Write w = img.write(); - for (int i = 0; i < ls; i++) { - w[i * 3 + 0] = CLAMP(lightmap_ptr[i].light.x * 255, 0, 255); - w[i * 3 + 1] = CLAMP(lightmap_ptr[i].light.y * 255, 0, 255); - w[i * 3 + 2] = CLAMP(lightmap_ptr[i].light.z * 255, 0, 255); - //w[i * 3 + 0] = CLAMP(lightmap_ptr[i].normal.x * 255, 0, 255); - //w[i * 3 + 1] = CLAMP(lightmap_ptr[i].normal.y * 255, 0, 255); - //w[i * 3 + 2] = CLAMP(lightmap_ptr[i].normal.z * 255, 0, 255); - //w[i * 3 + 0] = CLAMP(lightmap_ptr[i].pos.x / (1 << (cell_subdiv - 1)) * 255, 0, 255); - //w[i * 3 + 1] = CLAMP(lightmap_ptr[i].pos.y / (1 << (cell_subdiv - 1)) * 255, 0, 255); - //w[i * 3 + 2] = CLAMP(lightmap_ptr[i].pos.z / (1 << (cell_subdiv - 1)) * 255, 0, 255); - } - } - - Ref image; - image.instance(); - image->create(width, height, false, Image::FORMAT_RGB8, img); - - String name = p_mesh->get_name(); - if (name == "") { - name = "Mesh" + itos(p_mesh->get_instance_id()); - } - image->save_png(name + ".png"); - } -#endif - } - - return OK; -} - void VoxelLightBaker::begin_bake(int p_subdiv, const AABB &p_bounds) { original_bounds = p_bounds; @@ -2482,5 +1613,4 @@ VoxelLightBaker::VoxelLightBaker() { color_scan_cell_width = 4; bake_texture_size = 128; propagation = 0.85; - energy = 1.0; } diff --git a/scene/3d/voxel_light_baker.h b/scene/3d/voxel_light_baker.h index 135fae2ac819..9f7258caeb31 100644 --- a/scene/3d/voxel_light_baker.h +++ b/scene/3d/voxel_light_baker.h @@ -128,10 +128,8 @@ class VoxelLightBaker { int bake_texture_size; float cell_size; float propagation; - float energy; BakeQuality bake_quality; - BakeMode bake_mode; int max_original_cells; @@ -147,25 +145,10 @@ class VoxelLightBaker { uint32_t _find_cell_at_pos(const Cell *cells, int x, int y, int z); - struct LightMap { - Vector3 light; - Vector3 pos; - Vector3 normal; - }; - - void _plot_triangle(Vector2 *vertices, Vector3 *positions, Vector3 *normals, LightMap *pixels, int width, int height); - - _FORCE_INLINE_ void _sample_baked_octree_filtered_and_anisotropic(const Vector3 &p_posf, const Vector3 &p_direction, float p_level, Vector3 &r_color, float &r_alpha); - _FORCE_INLINE_ Vector3 _voxel_cone_trace(const Vector3 &p_pos, const Vector3 
&p_normal, float p_aperture); - _FORCE_INLINE_ Vector3 _compute_pixel_light_at_pos(const Vector3 &p_pos, const Vector3 &p_normal); - _FORCE_INLINE_ Vector3 _compute_ray_trace_at_pos(const Vector3 &p_pos, const Vector3 &p_normal); - - void _lightmap_bake_point(uint32_t p_x, LightMap *p_line); - public: void begin_bake(int p_subdiv, const AABB &p_bounds); void plot_mesh(const Transform &p_xform, Ref &p_mesh, const Vector > &p_materials, const Ref &p_override_material); - void begin_bake_light(BakeQuality p_quality = BAKE_QUALITY_MEDIUM, BakeMode p_bake_mode = BAKE_MODE_CONE_TRACE, float p_propagation = 0.85, float p_energy = 1); + void begin_bake_light(BakeQuality p_quality = BAKE_QUALITY_MEDIUM, float p_propagation = 0.85); void plot_light_directional(const Vector3 &p_direction, const Color &p_color, float p_energy, float p_indirect_energy, bool p_direct); void plot_light_omni(const Vector3 &p_pos, const Color &p_color, float p_energy, float p_indirect_energy, float p_radius, float p_attenutation, bool p_direct); void plot_light_spot(const Vector3 &p_pos, const Vector3 &p_axis, const Color &p_color, float p_energy, float p_indirect_energy, float p_radius, float p_attenutation, float p_spot_angle, float p_spot_attenuation, bool p_direct); @@ -177,8 +160,6 @@ class VoxelLightBaker { PoolVector light; }; - Error make_lightmap(const Transform &p_xform, Ref &p_mesh, float default_texels_per_unit, LightMapData &r_lightmap, bool (*p_bake_time_func)(void *, float, float) = NULL, void *p_bake_time_ud = NULL); - PoolVector create_gi_probe_data(); Ref create_debug_multimesh(DebugMode p_mode = DEBUG_ALBEDO); PoolVector create_capture_octree(int p_subdiv); diff --git a/scene/resources/mesh.cpp b/scene/resources/mesh.cpp index fb0595e58d26..f997929ed3f5 100644 --- a/scene/resources/mesh.cpp +++ b/scene/resources/mesh.cpp @@ -30,6 +30,8 @@ #include "mesh.h" +#include "core/crypto/crypto_core.h" +#include "core/local_vector.h" #include "core/pair.h" #include "scene/resources/concave_polygon_shape.h" #include "scene/resources/convex_polygon_shape.h" @@ -1108,18 +1110,35 @@ struct ArrayMeshLightmapSurface { }; Error ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texel_size) { + int *cache_data = nullptr; + unsigned int cache_size = 0; + bool use_cache = false; // Don't use cache + return lightmap_unwrap_cached(cache_data, cache_size, use_cache, p_base_transform, p_texel_size); +} + +Error ArrayMesh::lightmap_unwrap_cached(int *&r_cache_data, unsigned int &r_cache_size, bool &r_used_cache, const Transform &p_base_transform, float p_texel_size) { ERR_FAIL_COND_V(!array_mesh_lightmap_unwrap_callback, ERR_UNCONFIGURED); ERR_FAIL_COND_V_MSG(blend_shapes.size() != 0, ERR_UNAVAILABLE, "Can't unwrap mesh with blend shapes."); - Vector vertices; - Vector normals; - Vector indices; - Vector face_materials; - Vector uv; - Vector > uv_index; + LocalVector vertices; + LocalVector normals; + LocalVector indices; + LocalVector face_materials; + LocalVector uv; + LocalVector > uv_indices; + + Vector lightmap_surfaces; + + // Keep only the scale + Basis basis = p_base_transform.get_basis(); + Vector3 scale = Vector3(basis.get_axis(0).length(), basis.get_axis(1).length(), basis.get_axis(2).length()); + + Transform transform; + transform.scale(scale); + + Basis normal_basis = transform.basis.inverse().transposed(); - Vector surfaces; for (int i = 0; i < get_surface_count(); i++) { ArrayMeshLightmapSurface s; s.primitive = surface_get_primitive_type(i); @@ -1143,30 +1162,36 @@ Error 
ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texe vertices.resize((vertex_ofs + vc) * 3); normals.resize((vertex_ofs + vc) * 3); - uv_index.resize(vertex_ofs + vc); + uv_indices.resize(vertex_ofs + vc); for (int j = 0; j < vc; j++) { - Vector3 v = p_base_transform.xform(r[j]); - Vector3 n = p_base_transform.basis.xform(rn[j]).normalized(); + Vector3 v = transform.xform(r[j]); + Vector3 n = normal_basis.xform(rn[j]).normalized(); - vertices.write[(j + vertex_ofs) * 3 + 0] = v.x; - vertices.write[(j + vertex_ofs) * 3 + 1] = v.y; - vertices.write[(j + vertex_ofs) * 3 + 2] = v.z; - normals.write[(j + vertex_ofs) * 3 + 0] = n.x; - normals.write[(j + vertex_ofs) * 3 + 1] = n.y; - normals.write[(j + vertex_ofs) * 3 + 2] = n.z; - uv_index.write[j + vertex_ofs] = Pair(i, j); + vertices[(j + vertex_ofs) * 3 + 0] = v.x; + vertices[(j + vertex_ofs) * 3 + 1] = v.y; + vertices[(j + vertex_ofs) * 3 + 2] = v.z; + normals[(j + vertex_ofs) * 3 + 0] = n.x; + normals[(j + vertex_ofs) * 3 + 1] = n.y; + normals[(j + vertex_ofs) * 3 + 2] = n.z; + uv_indices[j + vertex_ofs] = Pair(i, j); } PoolVector rindices = arrays[Mesh::ARRAY_INDEX]; int ic = rindices.size(); + float eps = 1.19209290e-7F; // Taken from xatlas.h if (ic == 0) { for (int j = 0; j < vc / 3; j++) { - if (Face3(r[j * 3 + 0], r[j * 3 + 1], r[j * 3 + 2]).is_degenerate()) + Vector3 p0 = transform.xform(r[j * 3 + 0]); + Vector3 p1 = transform.xform(r[j * 3 + 1]); + Vector3 p2 = transform.xform(r[j * 3 + 2]); + + if ((p0 - p1).length_squared() < eps || (p1 - p2).length_squared() < eps || (p2 - p0).length_squared() < eps) { continue; + } indices.push_back(vertex_ofs + j * 3 + 0); indices.push_back(vertex_ofs + j * 3 + 1); @@ -1178,8 +1203,14 @@ Error ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texe PoolVector::Read ri = rindices.read(); for (int j = 0; j < ic / 3; j++) { - if (Face3(r[ri[j * 3 + 0]], r[ri[j * 3 + 1]], r[ri[j * 3 + 2]]).is_degenerate()) + Vector3 p0 = transform.xform(r[ri[j * 3 + 0]]); + Vector3 p1 = transform.xform(r[ri[j * 3 + 1]]); + Vector3 p2 = transform.xform(r[ri[j * 3 + 2]]); + + if ((p0 - p1).length_squared() < eps || (p1 - p2).length_squared() < eps || (p2 - p0).length_squared() < eps) { continue; + } + indices.push_back(vertex_ofs + ri[j * 3 + 0]); indices.push_back(vertex_ofs + ri[j * 3 + 1]); indices.push_back(vertex_ofs + ri[j * 3 + 2]); @@ -1187,7 +1218,49 @@ Error ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texe } } - surfaces.push_back(s); + lightmap_surfaces.push_back(s); + } + + CryptoCore::MD5Context ctx; + ctx.start(); + + ctx.update((unsigned char *)&p_texel_size, sizeof(float)); + ctx.update((unsigned char *)indices.ptr(), sizeof(int) * indices.size()); + ctx.update((unsigned char *)face_materials.ptr(), sizeof(int) * face_materials.size()); + ctx.update((unsigned char *)vertices.ptr(), sizeof(float) * vertices.size()); + ctx.update((unsigned char *)normals.ptr(), sizeof(float) * normals.size()); + + unsigned char hash[16]; + ctx.finish(hash); + + bool cached = false; + unsigned int cache_idx = 0; + + if (r_used_cache && r_cache_data) { + //Check if hash is in cache data + + int *cache_data = r_cache_data; + int n_entries = cache_data[0]; + unsigned int r_idx = 1; + for (int i = 0; i < n_entries; ++i) { + if (memcmp(&cache_data[r_idx], hash, 16) == 0) { + cached = true; + cache_idx = r_idx; + break; + } + + r_idx += 4; // hash + r_idx += 2; // size hint + + int vertex_count = cache_data[r_idx]; + r_idx += 1; // vertex count + r_idx += 
vertex_count; // vertex + r_idx += vertex_count * 2; // uvs + + int index_count = cache_data[r_idx]; + r_idx += 1; // index count + r_idx += index_count; // indices + } } //unwrap @@ -1200,10 +1273,86 @@ Error ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texe int size_x; int size_y; - bool ok = array_mesh_lightmap_unwrap_callback(p_texel_size, vertices.ptr(), normals.ptr(), vertices.size() / 3, indices.ptr(), face_materials.ptr(), indices.size(), &gen_uvs, &gen_vertices, &gen_vertex_count, &gen_indices, &gen_index_count, &size_x, &size_y); + if (r_used_cache && cached) { + int *cache_data = r_cache_data; + + // Return cache data pointer to the caller + r_cache_data = &cache_data[cache_idx]; + + cache_idx += 4; + + // Load size + size_x = ((int *)cache_data)[cache_idx]; + size_y = ((int *)cache_data)[cache_idx + 1]; + cache_idx += 2; + + // Load vertices + gen_vertex_count = cache_data[cache_idx]; + cache_idx++; + gen_vertices = &cache_data[cache_idx]; + cache_idx += gen_vertex_count; + + // Load UVs + gen_uvs = (float *)&cache_data[cache_idx]; + cache_idx += gen_vertex_count * 2; + + // Load indices + gen_index_count = cache_data[cache_idx]; + cache_idx++; + gen_indices = &cache_data[cache_idx]; + + // Return cache data size to the caller + r_cache_size = sizeof(int) * (4 + 2 + 1 + gen_vertex_count + (gen_vertex_count * 2) + 1 + gen_index_count); // hash + size hint + vertex_count + vertices + uvs + index_count + indices + r_used_cache = true; + } + + if (!cached) { + bool ok = array_mesh_lightmap_unwrap_callback(p_texel_size, vertices.ptr(), normals.ptr(), vertices.size() / 3, indices.ptr(), face_materials.ptr(), indices.size(), &gen_uvs, &gen_vertices, &gen_vertex_count, &gen_indices, &gen_index_count, &size_x, &size_y); - if (!ok) { - return ERR_CANT_CREATE; + if (!ok) { + return ERR_CANT_CREATE; + } + + if (r_used_cache) { + unsigned int new_cache_size = 4 + 2 + 1 + gen_vertex_count + (gen_vertex_count * 2) + 1 + gen_index_count; // hash + size hint + vertex_count + vertices + uvs + index_count + indices + new_cache_size *= sizeof(int); + int *new_cache_data = (int *)memalloc(new_cache_size); + unsigned int new_cache_idx = 0; + + // hash + memcpy(&new_cache_data[new_cache_idx], hash, 16); + new_cache_idx += 4; + + // size hint + new_cache_data[new_cache_idx] = size_x; + new_cache_data[new_cache_idx + 1] = size_y; + new_cache_idx += 2; + + // vertex count + new_cache_data[new_cache_idx] = gen_vertex_count; + new_cache_idx++; + + // vertices + memcpy(&new_cache_data[new_cache_idx], gen_vertices, sizeof(int) * gen_vertex_count); + new_cache_idx += gen_vertex_count; + + // uvs + memcpy(&new_cache_data[new_cache_idx], gen_uvs, sizeof(float) * gen_vertex_count * 2); + new_cache_idx += gen_vertex_count * 2; + + // index count + new_cache_data[new_cache_idx] = gen_index_count; + new_cache_idx++; + + // indices + memcpy(&new_cache_data[new_cache_idx], gen_indices, sizeof(int) * gen_index_count); + new_cache_idx += gen_index_count; + + // Return cache data to the caller + r_cache_data = new_cache_data; + r_cache_size = new_cache_size; + r_used_cache = false; + } } //remove surfaces @@ -1212,13 +1361,13 @@ Error ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texe } //create surfacetools for each surface.. 
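// Explanatory sketch (not part of the original patch; illustrative only): the unwrap
// cache read and written above is a flat int buffer, with cache_data[0] holding the
// number of entries. Each entry can be viewed as the following layout (this struct is
// hypothetical and never declared in the engine):
//
//   struct UnwrapCacheEntry {
//       int32_t hash[4];                    // MD5 of texel size, indices, face materials, vertices, normals
//       int32_t size_x, size_y;             // lightmap size hint
//       int32_t vertex_count;
//       int32_t vertices[/*vertex_count*/]; // indices into uv_indices (the source vertex stream)
//       float   uvs[/*vertex_count * 2*/];  // generated UV2s, stored in the same int buffer
//       int32_t index_count;
//       int32_t indices[/*index_count*/];
//   };
//
// The rebuild below routes every unwrapped triangle back to its source surface through
// uv_indices and re-emits it with SurfaceTool, copying the original attributes and
// adding the generated UV2.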
- Vector > surfaces_tools; + LocalVector > surfaces_tools; - for (int i = 0; i < surfaces.size(); i++) { + for (int i = 0; i < lightmap_surfaces.size(); i++) { Ref st; st.instance(); st->begin(Mesh::PRIMITIVE_TRIANGLES); - st->set_material(surfaces[i].material); + st->set_material(lightmap_surfaces[i].material); surfaces_tools.push_back(st); //stay there } @@ -1226,61 +1375,62 @@ Error ArrayMesh::lightmap_unwrap(const Transform &p_base_transform, float p_texe //go through all indices for (int i = 0; i < gen_index_count; i += 3) { - ERR_FAIL_INDEX_V(gen_vertices[gen_indices[i + 0]], uv_index.size(), ERR_BUG); - ERR_FAIL_INDEX_V(gen_vertices[gen_indices[i + 1]], uv_index.size(), ERR_BUG); - ERR_FAIL_INDEX_V(gen_vertices[gen_indices[i + 2]], uv_index.size(), ERR_BUG); + ERR_FAIL_INDEX_V(gen_vertices[gen_indices[i + 0]], (int)uv_indices.size(), ERR_BUG); + ERR_FAIL_INDEX_V(gen_vertices[gen_indices[i + 1]], (int)uv_indices.size(), ERR_BUG); + ERR_FAIL_INDEX_V(gen_vertices[gen_indices[i + 2]], (int)uv_indices.size(), ERR_BUG); - ERR_FAIL_COND_V(uv_index[gen_vertices[gen_indices[i + 0]]].first != uv_index[gen_vertices[gen_indices[i + 1]]].first || uv_index[gen_vertices[gen_indices[i + 0]]].first != uv_index[gen_vertices[gen_indices[i + 2]]].first, ERR_BUG); + ERR_FAIL_COND_V(uv_indices[gen_vertices[gen_indices[i + 0]]].first != uv_indices[gen_vertices[gen_indices[i + 1]]].first || uv_indices[gen_vertices[gen_indices[i + 0]]].first != uv_indices[gen_vertices[gen_indices[i + 2]]].first, ERR_BUG); - int surface = uv_index[gen_vertices[gen_indices[i + 0]]].first; + int surface = uv_indices[gen_vertices[gen_indices[i + 0]]].first; for (int j = 0; j < 3; j++) { - SurfaceTool::Vertex v = surfaces[surface].vertices[uv_index[gen_vertices[gen_indices[i + j]]].second]; + SurfaceTool::Vertex v = lightmap_surfaces[surface].vertices[uv_indices[gen_vertices[gen_indices[i + j]]].second]; - if (surfaces[surface].format & ARRAY_FORMAT_COLOR) { - surfaces_tools.write[surface]->add_color(v.color); + if (lightmap_surfaces[surface].format & ARRAY_FORMAT_COLOR) { + surfaces_tools[surface]->add_color(v.color); } - if (surfaces[surface].format & ARRAY_FORMAT_TEX_UV) { - surfaces_tools.write[surface]->add_uv(v.uv); + if (lightmap_surfaces[surface].format & ARRAY_FORMAT_TEX_UV) { + surfaces_tools[surface]->add_uv(v.uv); } - if (surfaces[surface].format & ARRAY_FORMAT_NORMAL) { - surfaces_tools.write[surface]->add_normal(v.normal); + if (lightmap_surfaces[surface].format & ARRAY_FORMAT_NORMAL) { + surfaces_tools[surface]->add_normal(v.normal); } - if (surfaces[surface].format & ARRAY_FORMAT_TANGENT) { + if (lightmap_surfaces[surface].format & ARRAY_FORMAT_TANGENT) { Plane t; t.normal = v.tangent; t.d = v.binormal.dot(v.normal.cross(v.tangent)) < 0 ? 
-1 : 1; - surfaces_tools.write[surface]->add_tangent(t); + surfaces_tools[surface]->add_tangent(t); } - if (surfaces[surface].format & ARRAY_FORMAT_BONES) { - surfaces_tools.write[surface]->add_bones(v.bones); + if (lightmap_surfaces[surface].format & ARRAY_FORMAT_BONES) { + surfaces_tools[surface]->add_bones(v.bones); } - if (surfaces[surface].format & ARRAY_FORMAT_WEIGHTS) { - surfaces_tools.write[surface]->add_weights(v.weights); + if (lightmap_surfaces[surface].format & ARRAY_FORMAT_WEIGHTS) { + surfaces_tools[surface]->add_weights(v.weights); } Vector2 uv2(gen_uvs[gen_indices[i + j] * 2 + 0], gen_uvs[gen_indices[i + j] * 2 + 1]); - surfaces_tools.write[surface]->add_uv2(uv2); + surfaces_tools[surface]->add_uv2(uv2); - surfaces_tools.write[surface]->add_vertex(v.vertex); + surfaces_tools[surface]->add_vertex(v.vertex); } } - //free stuff - ::free(gen_vertices); - ::free(gen_indices); - ::free(gen_uvs); - //generate surfaces - - for (int i = 0; i < surfaces_tools.size(); i++) { - surfaces_tools.write[i]->index(); - surfaces_tools.write[i]->commit(Ref((ArrayMesh *)this), surfaces[i].format); + for (unsigned int i = 0; i < surfaces_tools.size(); i++) { + surfaces_tools[i]->index(); + surfaces_tools[i]->commit(Ref((ArrayMesh *)this), lightmap_surfaces[i].format); } set_lightmap_size_hint(Size2(size_x, size_y)); + if (!cached) { + //free stuff + ::free(gen_vertices); + ::free(gen_indices); + ::free(gen_uvs); + } + return OK; } diff --git a/scene/resources/mesh.h b/scene/resources/mesh.h index f4f86ff4598b..3c4cbb570e51 100644 --- a/scene/resources/mesh.h +++ b/scene/resources/mesh.h @@ -231,6 +231,7 @@ class ArrayMesh : public Mesh { void regen_normalmaps(); Error lightmap_unwrap(const Transform &p_base_transform = Transform(), float p_texel_size = 0.05); + Error lightmap_unwrap_cached(int *&r_cache_data, unsigned int &r_cache_size, bool &r_used_cache, const Transform &p_base_transform = Transform(), float p_texel_size = 0.05); virtual void reload_from_file(); diff --git a/scene/resources/sky.cpp b/scene/resources/sky.cpp index a724b435d6c6..5bd96bac0e60 100644 --- a/scene/resources/sky.cpp +++ b/scene/resources/sky.cpp @@ -390,6 +390,10 @@ ProceduralSky::TextureSize ProceduralSky::get_texture_size() const { return texture_size; } +Ref ProceduralSky::get_panorama() const { + return panorama; +} + RID ProceduralSky::get_rid() const { return sky; } @@ -414,9 +418,9 @@ void ProceduralSky::_update_sky() { } } else { - Ref image = _generate_sky(); - VS::get_singleton()->texture_allocate(texture, image->get_width(), image->get_height(), 0, Image::FORMAT_RGBE9995, VS::TEXTURE_TYPE_2D, VS::TEXTURE_FLAG_FILTER | VS::TEXTURE_FLAG_REPEAT); - VS::get_singleton()->texture_set_data(texture, image); + panorama = _generate_sky(); + VS::get_singleton()->texture_allocate(texture, panorama->get_width(), panorama->get_height(), 0, Image::FORMAT_RGBE9995, VS::TEXTURE_TYPE_2D, VS::TEXTURE_FLAG_FILTER | VS::TEXTURE_FLAG_REPEAT); + VS::get_singleton()->texture_set_data(texture, panorama); _radiance_changed(); } } @@ -432,8 +436,9 @@ void ProceduralSky::_queue_update() { void ProceduralSky::_thread_done(const Ref &p_image) { - VS::get_singleton()->texture_allocate(texture, p_image->get_width(), p_image->get_height(), 0, Image::FORMAT_RGBE9995, VS::TEXTURE_TYPE_2D, VS::TEXTURE_FLAG_FILTER | VS::TEXTURE_FLAG_REPEAT); - VS::get_singleton()->texture_set_data(texture, p_image); + panorama = p_image; + VS::get_singleton()->texture_allocate(texture, panorama->get_width(), panorama->get_height(), 0, Image::FORMAT_RGBE9995, 
VS::TEXTURE_TYPE_2D, VS::TEXTURE_FLAG_FILTER | VS::TEXTURE_FLAG_REPEAT); + VS::get_singleton()->texture_set_data(texture, panorama); _radiance_changed(); Thread::wait_to_finish(sky_thread); memdelete(sky_thread); diff --git a/scene/resources/sky.h b/scene/resources/sky.h index 5d763cbdbb91..503b23976f12 100644 --- a/scene/resources/sky.h +++ b/scene/resources/sky.h @@ -122,6 +122,7 @@ class ProceduralSky : public Sky { RID sky; RID texture; + Ref panorama; bool update_queued; bool regen_queued; @@ -189,6 +190,8 @@ class ProceduralSky : public Sky { void set_texture_size(TextureSize p_size); TextureSize get_texture_size() const; + Ref get_panorama() const; + virtual RID get_rid() const; ProceduralSky(bool p_desaturate = false); diff --git a/scene/resources/texture.cpp b/scene/resources/texture.cpp index 0e94864b8f76..03e5a676d936 100644 --- a/scene/resources/texture.cpp +++ b/scene/resources/texture.cpp @@ -2250,6 +2250,143 @@ Image::Format TextureLayered::get_format() const { return format; } +Error TextureLayered::load(const String &p_path) { + + Error error; + FileAccess *f = FileAccess::open(p_path, FileAccess::READ, &error); + ERR_FAIL_COND_V(error, error); + + uint8_t header[5] = { 0, 0, 0, 0, 0 }; + f->get_buffer(header, 4); + + if (header[0] == 'G' && header[1] == 'D' && header[2] == '3' && header[3] == 'T') { + if (!Object::cast_to(this)) { + f->close(); + memdelete(f); + ERR_FAIL_V(ERR_INVALID_DATA); + } + } else if (header[0] == 'G' && header[1] == 'D' && header[2] == 'A' && header[3] == 'T') { + if (!Object::cast_to(this)) { + f->close(); + memdelete(f); + ERR_FAIL_V(ERR_INVALID_DATA); + } + } else { + + f->close(); + memdelete(f); + ERR_FAIL_V_MSG(ERR_INVALID_DATA, "Unrecognized layered texture file format: " + String((const char *)header)); + } + + int tw = f->get_32(); + int th = f->get_32(); + int td = f->get_32(); + int flags = f->get_32(); //texture flags! 
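// Explanatory note (not part of the original patch; reconstructed from the reads in this
// function): the layered texture file starts with a 4-byte magic ('GD3T' for Texture3D,
// 'GDAT' for TextureArray), followed by int32 width, height, depth, flags, format and
// compression (0 = lossless/PNG, 1 = VRAM, 2 = uncompressed). Each of the `depth` layers
// then stores either a mipmap count plus size-prefixed PNG blobs (COMPRESS_LOSSLESS) or
// raw pixel data sized by Image::get_image_data_size(), as parsed below.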
+ Image::Format format = Image::Format(f->get_32()); + uint32_t compression = f->get_32(); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed + + create(tw, th, td, format, flags); + + for (int layer = 0; layer < td; layer++) { + + Ref image; + image.instance(); + + if (compression == COMPRESS_LOSSLESS) { + //look for a PNG file inside + + int mipmaps = f->get_32(); + Vector > mipmap_images; + + for (int i = 0; i < mipmaps; i++) { + uint32_t size = f->get_32(); + + PoolVector pv; + pv.resize(size); + { + PoolVector::Write w = pv.write(); + f->get_buffer(w.ptr(), size); + } + + Ref img = Image::lossless_unpacker(pv); + + if (img.is_null() || img->empty() || format != img->get_format()) { + f->close(); + memdelete(f); + ERR_FAIL_V(ERR_FILE_CORRUPT); + } + + mipmap_images.push_back(img); + } + + if (mipmap_images.size() == 1) { + + image = mipmap_images[0]; + + } else { + int total_size = Image::get_image_data_size(tw, th, format, true); + PoolVector img_data; + img_data.resize(total_size); + + { + PoolVector::Write w = img_data.write(); + + int ofs = 0; + for (int i = 0; i < mipmap_images.size(); i++) { + + PoolVector id = mipmap_images[i]->get_data(); + int len = id.size(); + PoolVector::Read r = id.read(); + copymem(&w[ofs], r.ptr(), len); + ofs += len; + } + } + + image->create(tw, th, true, format, img_data); + if (image->empty()) { + f->close(); + memdelete(f); + ERR_FAIL_V(ERR_FILE_CORRUPT); + } + } + + } else { + + //look for regular format + bool mipmaps = (flags & Texture::FLAG_MIPMAPS); + int total_size = Image::get_image_data_size(tw, th, format, mipmaps); + + PoolVector img_data; + img_data.resize(total_size); + + { + PoolVector::Write w = img_data.write(); + int bytes = f->get_buffer(w.ptr(), total_size); + if (bytes != total_size) { + f->close(); + memdelete(f); + ERR_FAIL_V(ERR_FILE_CORRUPT); + } + } + + image->create(tw, th, mipmaps, format, img_data); + } + + set_layer_data(image, layer); + } + + memdelete(f); + + path_to_file = p_path; + _change_notify(); + return OK; +} + +String TextureLayered::get_load_path() const { + + return path_to_file; +} + uint32_t TextureLayered::get_width() const { return width; } @@ -2262,6 +2399,20 @@ uint32_t TextureLayered::get_depth() const { return depth; } +void TextureLayered::reload_from_file() { + + String path = get_path(); + if (!path.is_resource_file()) + return; + + path = ResourceLoader::path_remap(path); //remap for translation + path = ResourceLoader::import_remap(path); //remap for import + if (!path.is_resource_file()) + return; + + load(path); +} + void TextureLayered::_set_data(const Dictionary &p_data) { ERR_FAIL_COND(!p_data.has("width")); ERR_FAIL_COND(!p_data.has("height")); @@ -2410,139 +2561,11 @@ RES ResourceFormatLoaderTextureLayered::load(const String &p_path, const String ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture extension."); } - FileAccess *f = FileAccess::open(p_path, FileAccess::READ); - ERR_FAIL_COND_V_MSG(!f, RES(), "Cannot open file '" + p_path + "'."); - - uint8_t header[5] = { 0, 0, 0, 0, 0 }; - f->get_buffer(header, 4); - - if (header[0] == 'G' && header[1] == 'D' && header[2] == '3' && header[3] == 'T') { - if (tex3d.is_null()) { - f->close(); - memdelete(f); - ERR_FAIL_COND_V(tex3d.is_null(), RES()) - } - } else if (header[0] == 'G' && header[1] == 'D' && header[2] == 'A' && header[3] == 'T') { - if (texarr.is_null()) { - f->close(); - memdelete(f); - ERR_FAIL_COND_V(texarr.is_null(), RES()) - } - } else { - - f->close(); - memdelete(f); - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture 
file format '" + String((const char *)header) + "'."); - } - - int tw = f->get_32(); - int th = f->get_32(); - int td = f->get_32(); - int flags = f->get_32(); //texture flags! - Image::Format format = Image::Format(f->get_32()); - uint32_t compression = f->get_32(); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed - - lt->create(tw, th, td, format, flags); - - for (int layer = 0; layer < td; layer++) { - - Ref image; - image.instance(); - - if (compression == COMPRESSION_LOSSLESS) { - //look for a PNG file inside - - int mipmaps = f->get_32(); - Vector > mipmap_images; - - for (int i = 0; i < mipmaps; i++) { - uint32_t size = f->get_32(); - - PoolVector pv; - pv.resize(size); - { - PoolVector::Write w = pv.write(); - f->get_buffer(w.ptr(), size); - } - - Ref img = Image::lossless_unpacker(pv); - - if (img.is_null() || img->empty() || format != img->get_format()) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - - mipmap_images.push_back(img); - } - - if (mipmap_images.size() == 1) { - - image = mipmap_images[0]; - - } else { - int total_size = Image::get_image_data_size(tw, th, format, true); - PoolVector img_data; - img_data.resize(total_size); - - { - PoolVector::Write w = img_data.write(); - - int ofs = 0; - for (int i = 0; i < mipmap_images.size(); i++) { - - PoolVector id = mipmap_images[i]->get_data(); - int len = id.size(); - PoolVector::Read r = id.read(); - copymem(&w[ofs], r.ptr(), len); - ofs += len; - } - } - - image->create(tw, th, true, format, img_data); - if (image->empty()) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - } - - } else { - - //look for regular format - bool mipmaps = (flags & Texture::FLAG_MIPMAPS); - int total_size = Image::get_image_data_size(tw, th, format, mipmaps); - - PoolVector img_data; - img_data.resize(total_size); - - { - PoolVector::Write w = img_data.write(); - int bytes = f->get_buffer(w.ptr(), total_size); - if (bytes != total_size) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - } - - image->create(tw, th, mipmaps, format, img_data); - } - - lt->set_layer_data(image, layer); - } - + Error err = lt->load(p_path); if (r_error) *r_error = OK; + if (err != OK) + return RES(); return lt; } diff --git a/scene/resources/texture.h b/scene/resources/texture.h index add9572dffda..9ee35ee13eb8 100644 --- a/scene/resources/texture.h +++ b/scene/resources/texture.h @@ -477,7 +477,14 @@ class TextureLayered : public Resource { FLAGS_DEFAULT = FLAG_FILTER, }; + enum CompressMode { + COMPRESS_LOSSLESS, + COMPRESS_VIDEO_RAM, + COMPRESS_UNCOMPRESSED + }; + private: + String path_to_file; bool is_3d; RID texture; Image::Format format; @@ -487,6 +494,8 @@ class TextureLayered : public Resource { int height; int depth; + virtual void reload_from_file(); + void _set_data(const Dictionary &p_data); Dictionary _get_data() const; @@ -498,6 +507,9 @@ class TextureLayered : public Resource { uint32_t get_flags() const; Image::Format get_format() const; + Error load(const String &p_path); + String get_load_path() const; + uint32_t get_width() const; uint32_t get_height() const; uint32_t get_depth() const; @@ -536,12 +548,6 @@ class TextureArray : public TextureLayered { class ResourceFormatLoaderTextureLayered : public ResourceFormatLoader { public: - enum Compression { - COMPRESSION_LOSSLESS, - COMPRESSION_VRAM, - COMPRESSION_UNCOMPRESSED - }; - virtual RES load(const String &p_path, 
const String &p_original_path = "", Error *r_error = NULL); virtual void get_recognized_extensions(List *p_extensions) const; virtual bool handles_type(const String &p_type) const; diff --git a/servers/visual/rasterizer.h b/servers/visual/rasterizer.h index bf0bd1b4a4c5..894b782f5b38 100644 --- a/servers/visual/rasterizer.h +++ b/servers/visual/rasterizer.h @@ -120,6 +120,8 @@ class RasterizerScene { InstanceBase *lightmap_capture; RID lightmap; Vector lightmap_capture_data; //in a array (12 values) to avoid wasting space if unused. Alpha is unused, but needed to send to shader + int lightmap_slice; + Rect2 lightmap_uv_rect; virtual void base_removed() = 0; virtual void base_changed(bool p_aabb, bool p_materials) = 0; @@ -135,6 +137,8 @@ class RasterizerScene { baked_light = false; redraw_if_visible = false; lightmap_capture = NULL; + lightmap_slice = -1; + lightmap_uv_rect = Rect2(0, 0, 1, 1); } }; diff --git a/servers/visual/visual_server_raster.h b/servers/visual/visual_server_raster.h index daae705e1427..cf6d6f7d7298 100644 --- a/servers/visual/visual_server_raster.h +++ b/servers/visual/visual_server_raster.h @@ -548,7 +548,7 @@ class VisualServerRaster : public VisualServer { BIND3(instance_set_blend_shape_weight, RID, int, float) BIND3(instance_set_surface_material, RID, int, RID) BIND2(instance_set_visible, RID, bool) - BIND3(instance_set_use_lightmap, RID, RID, RID) + BIND5(instance_set_use_lightmap, RID, RID, RID, int, const Rect2 &) BIND2(instance_set_custom_aabb, RID, AABB) diff --git a/servers/visual/visual_server_scene.cpp b/servers/visual/visual_server_scene.cpp index d62ee8791e9a..7c4b201400ab 100644 --- a/servers/visual/visual_server_scene.cpp +++ b/servers/visual/visual_server_scene.cpp @@ -484,7 +484,7 @@ void VisualServerScene::instance_set_base(RID p_instance, RID p_base) { InstanceLightmapCaptureData *lightmap_capture = static_cast(instance->base_data); //erase dependencies, since no longer a lightmap while (lightmap_capture->users.front()) { - instance_set_use_lightmap(lightmap_capture->users.front()->get()->self, RID(), RID()); + instance_set_use_lightmap(lightmap_capture->users.front()->get()->self, RID(), RID(), -1, Rect2(0, 0, 1, 1)); } } break; case VS::INSTANCE_GI_PROBE: { @@ -805,7 +805,7 @@ inline bool is_geometry_instance(VisualServer::InstanceType p_type) { return p_type == VS::INSTANCE_MESH || p_type == VS::INSTANCE_MULTIMESH || p_type == VS::INSTANCE_PARTICLES || p_type == VS::INSTANCE_IMMEDIATE; } -void VisualServerScene::instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap) { +void VisualServerScene::instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap, int p_lightmap_slice, const Rect2 &p_lightmap_uv_rect) { Instance *instance = instance_owner.get(p_instance); ERR_FAIL_COND(!instance); @@ -814,6 +814,8 @@ void VisualServerScene::instance_set_use_lightmap(RID p_instance, RID p_lightmap InstanceLightmapCaptureData *lightmap_capture = static_cast(((Instance *)instance->lightmap_capture)->base_data); lightmap_capture->users.erase(instance); instance->lightmap = RID(); + instance->lightmap_slice = -1; + instance->lightmap_uv_rect = Rect2(0, 0, 1, 1); instance->lightmap_capture = NULL; } @@ -826,6 +828,8 @@ void VisualServerScene::instance_set_use_lightmap(RID p_instance, RID p_lightmap InstanceLightmapCaptureData *lightmap_capture = static_cast(((Instance *)instance->lightmap_capture)->base_data); lightmap_capture->users.insert(instance); instance->lightmap = p_lightmap; + instance->lightmap_slice = 
p_lightmap_slice; + instance->lightmap_uv_rect = p_lightmap_uv_rect; } } @@ -3618,7 +3622,7 @@ bool VisualServerScene::free(RID p_rid) { Instance *instance = instance_owner.get(p_rid); - instance_set_use_lightmap(p_rid, RID(), RID()); + instance_set_use_lightmap(p_rid, RID(), RID(), -1, Rect2(0, 0, 1, 1)); instance_set_scenario(p_rid, RID()); instance_set_base(p_rid, RID()); instance_geometry_set_material_override(p_rid, RID()); diff --git a/servers/visual/visual_server_scene.h b/servers/visual/visual_server_scene.h index c228e2731a3c..a33e9a4e97da 100644 --- a/servers/visual/visual_server_scene.h +++ b/servers/visual/visual_server_scene.h @@ -517,7 +517,7 @@ class VisualServerScene { virtual void instance_set_blend_shape_weight(RID p_instance, int p_shape, float p_weight); virtual void instance_set_surface_material(RID p_instance, int p_surface, RID p_material); virtual void instance_set_visible(RID p_instance, bool p_visible); - virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap); + virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap, int p_lightmap_slice, const Rect2 &p_lightmap_uv_rect); virtual void instance_set_custom_aabb(RID p_instance, AABB p_aabb); diff --git a/servers/visual/visual_server_wrap_mt.h b/servers/visual/visual_server_wrap_mt.h index 70052f2bd0ac..0f24d7908f70 100644 --- a/servers/visual/visual_server_wrap_mt.h +++ b/servers/visual/visual_server_wrap_mt.h @@ -470,7 +470,7 @@ class VisualServerWrapMT : public VisualServer { FUNC3(instance_set_blend_shape_weight, RID, int, float) FUNC3(instance_set_surface_material, RID, int, RID) FUNC2(instance_set_visible, RID, bool) - FUNC3(instance_set_use_lightmap, RID, RID, RID) + FUNC5(instance_set_use_lightmap, RID, RID, RID, int, const Rect2 &) FUNC2(instance_set_custom_aabb, RID, AABB) diff --git a/servers/visual_server.cpp b/servers/visual_server.cpp index c14a2d9abc7f..7d6dda10bf02 100644 --- a/servers/visual_server.cpp +++ b/servers/visual_server.cpp @@ -1938,7 +1938,7 @@ void VisualServer::_bind_methods() { ClassDB::bind_method(D_METHOD("instance_set_blend_shape_weight", "instance", "shape", "weight"), &VisualServer::instance_set_blend_shape_weight); ClassDB::bind_method(D_METHOD("instance_set_surface_material", "instance", "surface", "material"), &VisualServer::instance_set_surface_material); ClassDB::bind_method(D_METHOD("instance_set_visible", "instance", "visible"), &VisualServer::instance_set_visible); - ClassDB::bind_method(D_METHOD("instance_set_use_lightmap", "instance", "lightmap_instance", "lightmap"), &VisualServer::instance_set_use_lightmap); + ClassDB::bind_method(D_METHOD("instance_set_use_lightmap", "instance", "lightmap_instance", "lightmap", "lightmap_slice", "lightmap_uv_rect"), &VisualServer::instance_set_use_lightmap, DEFVAL(-1), DEFVAL(Rect2(0, 0, 1, 1))); ClassDB::bind_method(D_METHOD("instance_set_custom_aabb", "instance", "aabb"), &VisualServer::instance_set_custom_aabb); ClassDB::bind_method(D_METHOD("instance_attach_skeleton", "instance", "skeleton"), &VisualServer::instance_attach_skeleton); ClassDB::bind_method(D_METHOD("instance_set_exterior", "instance", "enabled"), &VisualServer::instance_set_exterior); diff --git a/servers/visual_server.h b/servers/visual_server.h index 48e73e043809..dee3b8749ef5 100644 --- a/servers/visual_server.h +++ b/servers/visual_server.h @@ -843,7 +843,7 @@ class VisualServer : public Object { virtual void instance_set_surface_material(RID p_instance, int p_surface, RID p_material) 
= 0; virtual void instance_set_visible(RID p_instance, bool p_visible) = 0; - virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap) = 0; + virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap, int p_lightmap_slice, const Rect2 &p_lightmap_uv_rect) = 0; virtual void instance_set_custom_aabb(RID p_instance, AABB aabb) = 0; diff --git a/thirdparty/README.md b/thirdparty/README.md index 418c5cbeb5b7..9ff0bf6e6f9d 100644 --- a/thirdparty/README.md +++ b/thirdparty/README.md @@ -39,6 +39,25 @@ Files extracted from upstream source: - all .cpp, .h, and .txt files in ConvectionKernels/ +## embree + +- Upstream: https://github.com/embree/embree +- Version: 3.12.1 (69bd4c272f1ed608494f233ecfff3feec516880b, 2020) +- License: Apache 2.0 + +Files extracted from upstream: + +- All cpp files listed in `modules/raytrace/godot_update_embree.py` +- All header files in the directories listed in `modules/raytrace/godot_update_embree.py` + +The `modules/raytrace/godot_update_embree.py`script can be used to pull the +relevant files from the latest Embree release and apply some automatic changes. + +Some minor changes have been made in order to fix build errors. +They are marked with `// -- GODOT start --` and `// -- GODOT end --` +comments. Apply the patches in the `patches/` folder when syncing on newer upstream +commits. + ## enet @@ -334,6 +353,10 @@ Collection of single-file libraries used in Godot components. * Version: git (2f625846a775501fb69456567409a8b12f10ea25, 2012) * License: BSD-3-Clause * Modifications: use `const char*` instead of `char*` for input string +- `stb_rect_pack.h` + * Upstream: https://github.com/nothings/stb + * Version: 1.00 + * License: Public Domain (Unlicense) or MIT - `stb_vorbis.c` * Upstream: https://github.com/nothings/stb * Version: 1.20 (314d0a6f9af5af27e585336eecea333e95c5a2d8, 2020) @@ -360,6 +383,37 @@ Files extracted from the upstream source: - LICENSE.txt +## oidn + +- Upstream: https://github.com/OpenImageDenoise/oidn +- Version: 1.1.0 (c58c5216db05ceef4cde5a096862f2eeffd14c06, 2019) +- License: Apache 2.0 + +Files extracted from upstream source: + +common/* (except tasking.* and CMakeLists.txt) +core/* +include/OpenImageDenoise/* (except version.h.in) +LICENSE.txt +mkl-dnn/include/* +mkl-dnn/src/* (except CMakeLists.txt) +weights/rtlightmap_hdr.tza +scripts/resource_to_cpp.py + +Modified files: +Modifications are marked with `// -- GODOT start --` and `// -- GODOT end --`. +Patch files are provided in `oidn/patches/`. 
+ +core/autoencoder.cpp +core/autoencoder.h +core/common.h +core/device.cpp +core/device.h +core/transfer_function.cpp + +scripts/resource_to_cpp.py (used in modules/denoise/resource_to_cpp.py) + + ## opus - Upstream: https://opus-codec.org diff --git a/thirdparty/embree/common/algorithms/parallel_any_of.h b/thirdparty/embree/common/algorithms/parallel_any_of.h new file mode 100644 index 000000000000..01f1f80f6c5d --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_any_of.h @@ -0,0 +1,55 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "parallel_reduce.h" + +namespace embree +{ + + template + __forceinline bool parallel_any_of (Index first, Index last, UnaryPredicate pred) + { + bool ret = false; + +#if defined(TASKING_TBB) +#if TBB_INTERFACE_VERSION >= 12002 + tbb::task_group_context context; + tbb::parallel_for(tbb::blocked_range{first, last}, [&ret,pred,&context](const tbb::blocked_range& r) { + if (context.is_group_execution_cancelled()) return; + for (size_t i = r.begin(); i != r.end(); ++i) { + if (pred(i)) { + ret = true; + context.cancel_group_execution(); + } + } + }); +#else + tbb::parallel_for(tbb::blocked_range{first, last}, [&ret,pred](const tbb::blocked_range& r) { + if (tbb::task::self().is_cancelled()) return; + for (size_t i = r.begin(); i != r.end(); ++i) { + if (pred(i)) { + ret = true; + tbb::task::self().cancel_group_execution(); + } + } + }); +#endif +#else + ret = parallel_reduce (first, last, false, [pred](const range& r)->bool { + bool localret = false; + for (auto i=r.begin(); i() + ); +#endif + + return ret; + } + +} // end namespace diff --git a/thirdparty/embree/common/algorithms/parallel_filter.cpp b/thirdparty/embree/common/algorithms/parallel_filter.cpp new file mode 100644 index 000000000000..acddc0ff8168 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_filter.cpp @@ -0,0 +1,56 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_filter.h" +#include "../sys/regression.h" +#include + +namespace embree +{ + struct parallel_filter_regression_test : public RegressionTest + { + parallel_filter_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + auto pred = [&]( uint32_t v ) { return (v & 0x3) == 0; }; + + for (size_t N=10; N<1000000; N=size_t(2.1*N)) + { + size_t N0 = rand() % N; + + /* initialize array with random numbers */ + std::vector src(N); + std::map m; + for (size_t i=0; i + inline Index sequential_filter( Ty* data, const Index first, const Index last, const Predicate& predicate) + { + Index j = first; + for (Index i=first; i + inline Index parallel_filter( Ty* data, const Index begin, const Index end, const Index minStepSize, const Predicate& predicate) + { + /* sequential fallback */ + if (end-begin <= minStepSize) + return sequential_filter(data,begin,end,predicate); + + /* calculate number of tasks to use */ + enum { MAX_TASKS = 64 }; + const Index numThreads = TaskScheduler::threadCount(); + const Index numBlocks = (end-begin+minStepSize-1)/minStepSize; + const Index taskCount = min(numThreads,numBlocks,(Index)MAX_TASKS); + + /* filter blocks */ + Index nused[MAX_TASKS]; + Index nfree[MAX_TASKS]; + parallel_for(taskCount, [&](const Index taskIndex) + { + const Index i0 = begin+(taskIndex+0)*(end-begin)/taskCount; + const Index i1 = begin+(taskIndex+1)*(end-begin)/taskCount; + const Index i2 = 
sequential_filter(data,i0,i1,predicate); + nused[taskIndex] = i2-i0; + nfree[taskIndex] = i1-i2; + }); + + /* calculate offsets */ + Index sused=0; + Index sfree=0; + Index pfree[MAX_TASKS]; + for (Index i=0; i0; i--) + { + if (k0 > r1) break; + Index k1 = k0+nused[i]; + Index src = begin+(i+0)*(end-begin)/taskCount+nused[i]; + for (Index i=max(r0,k0); i= begin && dst < end); + assert(isrc >= begin && isrc < end); + data[dst++] = data[isrc]; + } + k0 = k1; + } + }); + + return begin+sused; + } +} diff --git a/thirdparty/embree/common/algorithms/parallel_for.cpp b/thirdparty/embree/common/algorithms/parallel_for.cpp new file mode 100644 index 000000000000..ef070ebc4d07 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_for.cpp @@ -0,0 +1,48 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_for.h" +#include "../sys/regression.h" + +namespace embree +{ + struct parallel_for_regression_test : public RegressionTest + { + parallel_for_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + + const size_t M = 10; + for (size_t N=10; N<10000000; N=size_t(2.1*N)) + { + /* sequentially calculate sum of squares */ + size_t sum0 = 0; + for (size_t i=0; i sum1(0); + parallel_for( size_t(0), size_t(N), size_t(1024), [&](const range& r) + { + size_t s = 0; + for (size_t i=r.begin(); i + __forceinline void parallel_for( const Index N, const Func& func) + { +#if defined(TASKING_INTERNAL) + if (N) { + TaskScheduler::spawn(Index(0),N,Index(1),[&] (const range& r) { + assert(r.size() == 1); + func(r.begin()); + }); + if (!TaskScheduler::wait()) + throw std::runtime_error("task cancelled"); + } + +#elif defined(TASKING_TBB) + #if TBB_INTERFACE_VERSION >= 12002 + tbb::task_group_context context; + tbb::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + },context); + if (context.is_group_execution_cancelled()) + throw std::runtime_error("task cancelled"); + #else + tbb::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + }); + if (tbb::task::self().is_cancelled()) + throw std::runtime_error("task cancelled"); + #endif + +#elif defined(TASKING_PPL) + concurrency::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + }); +#else +# error "no tasking system enabled" +#endif + } + + /* parallel for with range and granulatity */ + template + __forceinline void parallel_for( const Index first, const Index last, const Index minStepSize, const Func& func) + { + assert(first <= last); +#if defined(TASKING_INTERNAL) + TaskScheduler::spawn(first,last,minStepSize,func); + if (!TaskScheduler::wait()) + throw std::runtime_error("task cancelled"); + +#elif defined(TASKING_TBB) + #if TBB_INTERFACE_VERSION >= 12002 + tbb::task_group_context context; + tbb::parallel_for(tbb::blocked_range(first,last,minStepSize),[&](const tbb::blocked_range& r) { + func(range(r.begin(),r.end())); + },context); + if (context.is_group_execution_cancelled()) + throw std::runtime_error("task cancelled"); + #else + tbb::parallel_for(tbb::blocked_range(first,last,minStepSize),[&](const tbb::blocked_range& r) { + func(range(r.begin(),r.end())); + }); + if (tbb::task::self().is_cancelled()) + throw std::runtime_error("task cancelled"); + #endif + +#elif defined(TASKING_PPL) + concurrency::parallel_for(first, last, Index(1) /*minStepSize*/, [&](Index i) { + func(range(i,i+1)); + }); + +#else +# error "no tasking system enabled" +#endif + } + + /* parallel for with range */ + 
template + __forceinline void parallel_for( const Index first, const Index last, const Func& func) + { + assert(first <= last); + parallel_for(first,last,(Index)1,func); + } + +#if defined(TASKING_TBB) && (TBB_INTERFACE_VERSION > 4001) + + template + __forceinline void parallel_for_static( const Index N, const Func& func) + { + #if TBB_INTERFACE_VERSION >= 12002 + tbb::task_group_context context; + tbb::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + },tbb::simple_partitioner(),context); + if (context.is_group_execution_cancelled()) + throw std::runtime_error("task cancelled"); + #else + tbb::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + },tbb::simple_partitioner()); + if (tbb::task::self().is_cancelled()) + throw std::runtime_error("task cancelled"); + #endif + } + + typedef tbb::affinity_partitioner affinity_partitioner; + + template + __forceinline void parallel_for_affinity( const Index N, const Func& func, tbb::affinity_partitioner& ap) + { + #if TBB_INTERFACE_VERSION >= 12002 + tbb::task_group_context context; + tbb::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + },ap,context); + if (context.is_group_execution_cancelled()) + throw std::runtime_error("task cancelled"); + #else + tbb::parallel_for(Index(0),N,Index(1),[&](Index i) { + func(i); + },ap); + if (tbb::task::self().is_cancelled()) + throw std::runtime_error("task cancelled"); + #endif + } + +#else + + template + __forceinline void parallel_for_static( const Index N, const Func& func) + { + parallel_for(N,func); + } + + struct affinity_partitioner { + }; + + template + __forceinline void parallel_for_affinity( const Index N, const Func& func, affinity_partitioner& ap) + { + parallel_for(N,func); + } + +#endif +} diff --git a/thirdparty/embree/common/algorithms/parallel_for_for.cpp b/thirdparty/embree/common/algorithms/parallel_for_for.cpp new file mode 100644 index 000000000000..0337611b350a --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_for_for.cpp @@ -0,0 +1,63 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_for_for.h" +#include "../sys/regression.h" + +namespace embree +{ + struct parallel_for_for_regression_test : public RegressionTest + { + parallel_for_for_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + + /* create vector with random numbers */ + size_t sum0 = 0; + size_t K = 0; + const size_t M = 1000; + std::vector* > array2(M); + for (size_t i=0; i(N); + for (size_t j=0; j> verify_k(K); + for (size_t i=0; i sum1(0); + parallel_for_for( array2, size_t(1), [&](std::vector* v, const range& r, size_t k) -> size_t + { + size_t s = 0; + for (size_t i=r.begin(); i + __forceinline void sequential_for_for( ArrayArray& array2, const size_t minStepSize, const Func& func ) + { + size_t k=0; + for (size_t i=0; i!=array2.size(); ++i) { + const size_t N = array2[i]->size(); + if (N) func(array2[i],range(0,N),k); + k+=N; + } + } + + class ParallelForForState + { + public: + + enum { MAX_TASKS = 64 }; + + __forceinline ParallelForForState () + : taskCount(0) {} + + template + __forceinline ParallelForForState (ArrayArray& array2, const size_t minStepSize) { + init(array2,minStepSize); + } + + template + __forceinline void init ( ArrayArray& array2, const size_t minStepSize ) + { + /* first calculate total number of elements */ + size_t N = 0; + for (size_t i=0; isize() : 0; + } + this->N = N; + + /* calculate number of tasks to 
use */ + const size_t numThreads = TaskScheduler::threadCount(); + const size_t numBlocks = (N+minStepSize-1)/minStepSize; + taskCount = max(size_t(1),min(numThreads,numBlocks,size_t(ParallelForForState::MAX_TASKS))); + + /* calculate start (i,j) for each task */ + size_t taskIndex = 0; + i0[taskIndex] = 0; + j0[taskIndex] = 0; + size_t k0 = (++taskIndex)*N/taskCount; + for (size_t i=0, k=0; taskIndex < taskCount; i++) + { + assert(isize() : 0; + while (j= k0 && taskIndex < taskCount) { + assert(taskIndex + __forceinline void parallel_for_for( ArrayArray& array2, const size_t minStepSize, const Func& func ) + { + ParallelForForState state(array2,minStepSize); + + parallel_for(state.taskCount, [&](const size_t taskIndex) + { + /* calculate range */ + const size_t k0 = (taskIndex+0)*state.size()/state.taskCount; + const size_t k1 = (taskIndex+1)*state.size()/state.taskCount; + size_t i0 = state.i0[taskIndex]; + size_t j0 = state.j0[taskIndex]; + + /* iterate over arrays */ + size_t k=k0; + for (size_t i=i0; ksize() : 0; + const size_t r0 = j0, r1 = min(N,r0+k1-k); + if (r1 > r0) func(array2[i],range(r0,r1),k); + k+=r1-r0; j0 = 0; + } + }); + } + + template + __forceinline void parallel_for_for( ArrayArray& array2, const Func& func ) + { + parallel_for_for(array2,1,func); + } + + template + __forceinline Value parallel_for_for_reduce( ArrayArray& array2, const size_t minStepSize, const Value& identity, const Func& func, const Reduction& reduction ) + { + ParallelForForState state(array2,minStepSize); + Value temp[ParallelForForState::MAX_TASKS]; + + for (size_t i=0; isize() : 0; + const size_t r0 = j0, r1 = min(N,r0+k1-k); + if (r1 > r0) temp[taskIndex] = reduction(temp[taskIndex],func(array2[i],range(r0,r1),k)); + k+=r1-r0; j0 = 0; + } + }); + + Value ret = identity; + for (size_t i=0; i + __forceinline Value parallel_for_for_reduce( ArrayArray& array2, const Value& identity, const Func& func, const Reduction& reduction) + { + return parallel_for_for_reduce(array2,1,identity,func,reduction); + } +} diff --git a/thirdparty/embree/common/algorithms/parallel_for_for_prefix_sum.cpp b/thirdparty/embree/common/algorithms/parallel_for_for_prefix_sum.cpp new file mode 100644 index 000000000000..0169d8e48189 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_for_for_prefix_sum.cpp @@ -0,0 +1,85 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_for_for_prefix_sum.h" +#include "../sys/regression.h" + +namespace embree +{ + struct parallel_for_for_prefix_sum_regression_test : public RegressionTest + { + parallel_for_for_prefix_sum_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + + /* create vector with random numbers */ + const size_t M = 10; + std::vector> flattened; + typedef std::vector* > ArrayArray; + ArrayArray array2(M); + size_t K = 0; + for (size_t i=0; i(N); + for (size_t j=0; j> verify_k(K); + for (size_t i=0; i state(array2,size_t(1)); + + /* dry run only counts */ + size_t S = parallel_for_for_prefix_sum0( state, array2, size_t(0), [&](std::vector* v, const range& r, size_t k, size_t i) -> size_t + { + size_t s = 0; + for (size_t i=r.begin(); i* v, const range& r, size_t k, size_t i, const size_t base) -> size_t + { + size_t s = 0; + for (size_t i=r.begin(); i + struct ParallelForForPrefixSumState : public ParallelForForState + { + __forceinline ParallelForForPrefixSumState () {} + + template + __forceinline 
ParallelForForPrefixSumState (ArrayArray& array2, const size_t minStepSize) + : ParallelForForState(array2,minStepSize) {} + + ParallelPrefixSumState prefix_state; + }; + + template + __forceinline Value parallel_for_for_prefix_sum0( ParallelForForPrefixSumState& state, ArrayArray& array2, Index minStepSize, + const Value& identity, const Func& func, const Reduction& reduction) + { + /* calculate number of tasks to use */ + const size_t taskCount = state.taskCount; + /* perform parallel prefix sum */ + parallel_for(taskCount, [&](const size_t taskIndex) + { + const size_t k0 = (taskIndex+0)*state.size()/taskCount; + const size_t k1 = (taskIndex+1)*state.size()/taskCount; + size_t i0 = state.i0[taskIndex]; + size_t j0 = state.j0[taskIndex]; + + /* iterate over arrays */ + size_t k=k0; + Value N=identity; + for (size_t i=i0; ksize() : 0; + const size_t r0 = j0, r1 = min(size,r0+k1-k); + if (r1 > r0) N = reduction(N, func(array2[i],range((Index)r0,(Index)r1),(Index)k,(Index)i)); + k+=r1-r0; j0 = 0; + } + state.prefix_state.counts[taskIndex] = N; + }); + + /* calculate prefix sum */ + Value sum=identity; + for (size_t i=0; i + __forceinline Value parallel_for_for_prefix_sum1( ParallelForForPrefixSumState& state, ArrayArray& array2, Index minStepSize, + const Value& identity, const Func& func, const Reduction& reduction) + { + /* calculate number of tasks to use */ + const size_t taskCount = state.taskCount; + /* perform parallel prefix sum */ + parallel_for(taskCount, [&](const size_t taskIndex) + { + const size_t k0 = (taskIndex+0)*state.size()/taskCount; + const size_t k1 = (taskIndex+1)*state.size()/taskCount; + size_t i0 = state.i0[taskIndex]; + size_t j0 = state.j0[taskIndex]; + + /* iterate over arrays */ + size_t k=k0; + Value N=identity; + for (size_t i=i0; ksize() : 0; + const size_t r0 = j0, r1 = min(size,r0+k1-k); + if (r1 > r0) N = reduction(N, func(array2[i],range((Index)r0,(Index)r1),(Index)k,(Index)i,reduction(state.prefix_state.sums[taskIndex],N))); + k+=r1-r0; j0 = 0; + } + state.prefix_state.counts[taskIndex] = N; + }); + + /* calculate prefix sum */ + Value sum=identity; + for (size_t i=0; i + __forceinline Value parallel_for_for_prefix_sum0( ParallelForForPrefixSumState& state, ArrayArray& array2, + const Value& identity, const Func& func, const Reduction& reduction) + { + return parallel_for_for_prefix_sum0(state,array2,size_t(1),identity,func,reduction); + } + + template + __forceinline Value parallel_for_for_prefix_sum1( ParallelForForPrefixSumState& state, ArrayArray& array2, + const Value& identity, const Func& func, const Reduction& reduction) + { + return parallel_for_for_prefix_sum1(state,array2,size_t(1),identity,func,reduction); + } +} diff --git a/thirdparty/embree/common/algorithms/parallel_map.cpp b/thirdparty/embree/common/algorithms/parallel_map.cpp new file mode 100644 index 000000000000..09dc303f8146 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_map.cpp @@ -0,0 +1,47 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_map.h" +#include "../sys/regression.h" + +namespace embree +{ + struct parallel_map_regression_test : public RegressionTest + { + parallel_map_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + + /* create key/value vectors with random numbers */ + const size_t N = 10000; + std::vector keys(N); + std::vector vals(N); + for (size_t i=0; i map; + map.init(keys,vals); + + /* check that 
all keys are properly mapped */ + for (size_t i=0; i + class parallel_map + { + /* key/value pair to build the map */ + struct KeyValue + { + __forceinline KeyValue () {} + + __forceinline KeyValue (const Key key, const Val val) + : key(key), val(val) {} + + __forceinline operator Key() const { + return key; + } + + public: + Key key; + Val val; + }; + + public: + + /*! parallel map constructors */ + parallel_map () {} + + /*! construction from pair of vectors */ + template + parallel_map (const KeyVector& keys, const ValVector& values) { init(keys,values); } + + /*! initialized the parallel map from a vector with keys and values */ + template + void init(const KeyVector& keys, const ValVector& values) + { + /* reserve sufficient space for all data */ + assert(keys.size() == values.size()); + vec.resize(keys.size()); + + /* generate key/value pairs */ + parallel_for( size_t(0), keys.size(), size_t(4*4096), [&](const range& r) { + for (size_t i=r.begin(); i temp(keys.size()); + radix_sort(vec.data(),temp.data(),keys.size()); + } + + /*! Returns a pointer to the value associated with the specified key. The pointer will be nullptr of the key is not contained in the map. */ + __forceinline const Val* lookup(const Key& key) const + { + typename std::vector::const_iterator i = std::lower_bound(vec.begin(), vec.end(), key); + if (i == vec.end()) return nullptr; + if (i->key != key) return nullptr; + return &i->val; + } + + /*! If the key is in the map, the function returns the value associated with the key, otherwise it returns the default value. */ + __forceinline Val lookup(const Key& key, const Val& def) const + { + typename std::vector::const_iterator i = std::lower_bound(vec.begin(), vec.end(), key); + if (i == vec.end()) return def; + if (i->key != key) return def; + return i->val; + } + + /*! 
clears all state */ + void clear() { + vec.clear(); + } + + private: + std::vector vec; //!< vector containing sorted elements + }; +} diff --git a/thirdparty/embree/common/algorithms/parallel_partition.cpp b/thirdparty/embree/common/algorithms/parallel_partition.cpp new file mode 100644 index 000000000000..eb20c4465d2c --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_partition.cpp @@ -0,0 +1,53 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_partition.h" +#include "../sys/regression.h" + +namespace embree +{ + struct parallel_partition_regression_test : public RegressionTest + { + parallel_partition_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + + for (size_t i=0; i<100; i++) + { + /* create random permutation */ + size_t N = std::rand() % 1000000; + std::vector array(N); + for (unsigned i=0; i= split; + } + + return passed; + } + }; + + parallel_partition_regression_test parallel_partition_regression("parallel_partition_regression_test"); +} diff --git a/thirdparty/embree/common/algorithms/parallel_partition.h b/thirdparty/embree/common/algorithms/parallel_partition.h new file mode 100644 index 000000000000..3b3ad7c8541f --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_partition.h @@ -0,0 +1,283 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "parallel_for.h" +#include "../math/range.h" + +namespace embree +{ + /* serial partitioning */ + template + __forceinline size_t serial_partitioning(T* array, + const size_t begin, + const size_t end, + V& leftReduction, + V& rightReduction, + const IsLeft& is_left, + const Reduction_T& reduction_t) + { + T* l = array + begin; + T* r = array + end - 1; + + while(1) + { + /* *l < pivot */ + while (likely(l <= r && is_left(*l) )) + { + //prefetchw(l+4); // FIXME: enable? + reduction_t(leftReduction,*l); + ++l; + } + /* *r >= pivot) */ + while (likely(l <= r && !is_left(*r))) + { + //prefetchw(r-4); FIXME: enable? 
+ reduction_t(rightReduction,*r); + --r; + } + if (r + class __aligned(64) parallel_partition_task + { + ALIGNED_CLASS_(64); + private: + + static const size_t MAX_TASKS = 64; + + T* array; + size_t N; + const IsLeft& is_left; + const Reduction_T& reduction_t; + const Reduction_V& reduction_v; + const Vi& identity; + + size_t numTasks; + __aligned(64) size_t counter_start[MAX_TASKS+1]; + __aligned(64) size_t counter_left[MAX_TASKS+1]; + __aligned(64) range leftMisplacedRanges[MAX_TASKS]; + __aligned(64) range rightMisplacedRanges[MAX_TASKS]; + __aligned(64) V leftReductions[MAX_TASKS]; + __aligned(64) V rightReductions[MAX_TASKS]; + + public: + + __forceinline parallel_partition_task(T* array, + const size_t N, + const Vi& identity, + const IsLeft& is_left, + const Reduction_T& reduction_t, + const Reduction_V& reduction_v, + const size_t BLOCK_SIZE) + + : array(array), N(N), is_left(is_left), reduction_t(reduction_t), reduction_v(reduction_v), identity(identity), + numTasks(min((N+BLOCK_SIZE-1)/BLOCK_SIZE,min(TaskScheduler::threadCount(),MAX_TASKS))) {} + + __forceinline const range* findStartRange(size_t& index, const range* const r, const size_t numRanges) + { + size_t i = 0; + while(index >= (size_t)r[i].size()) + { + assert(i < numRanges); + index -= (size_t)r[i].size(); + i++; + } + return &r[i]; + } + + __forceinline void swapItemsInMisplacedRanges(const size_t numLeftMisplacedRanges, + const size_t numRightMisplacedRanges, + const size_t startID, + const size_t endID) + { + size_t leftLocalIndex = startID; + size_t rightLocalIndex = startID; + const range* l_range = findStartRange(leftLocalIndex,leftMisplacedRanges,numLeftMisplacedRanges); + const range* r_range = findStartRange(rightLocalIndex,rightMisplacedRanges,numRightMisplacedRanges); + + size_t l_left = l_range->size() - leftLocalIndex; + size_t r_left = r_range->size() - rightLocalIndex; + T *__restrict__ l = &array[l_range->begin() + leftLocalIndex]; + T *__restrict__ r = &array[r_range->begin() + rightLocalIndex]; + size_t size = endID - startID; + size_t items = min(size,min(l_left,r_left)); + + while (size) + { + if (unlikely(l_left == 0)) + { + l_range++; + l_left = l_range->size(); + l = &array[l_range->begin()]; + items = min(size,min(l_left,r_left)); + } + + if (unlikely(r_left == 0)) + { + r_range++; + r_left = r_range->size(); + r = &array[r_range->begin()]; + items = min(size,min(l_left,r_left)); + } + + size -= items; + l_left -= items; + r_left -= items; + + while(items) { + items--; + xchg(*l++,*r++); + } + } + } + + __forceinline size_t partition(V& leftReduction, V& rightReduction) + { + /* partition the individual ranges for each task */ + parallel_for(numTasks,[&] (const size_t taskID) { + const size_t startID = (taskID+0)*N/numTasks; + const size_t endID = (taskID+1)*N/numTasks; + V local_left(identity); + V local_right(identity); + const size_t mid = serial_partitioning(array,startID,endID,local_left,local_right,is_left,reduction_t); + counter_start[taskID] = startID; + counter_left [taskID] = mid-startID; + leftReductions[taskID] = local_left; + rightReductions[taskID] = local_right; + }); + counter_start[numTasks] = N; + counter_left[numTasks] = 0; + + /* finalize the reductions */ + for (size_t i=0; i globalLeft (0,mid); + const range globalRight(mid,N); + + /* calculate all left and right ranges that are on the wrong global side */ + size_t numMisplacedRangesLeft = 0; + size_t numMisplacedRangesRight = 0; + size_t numMisplacedItemsLeft = 0; + size_t numMisplacedItemsRight = 0; + + for (size_t i=0; i 
left_range (counter_start[i], counter_start[i] + counter_left[i]); + const range right_range(counter_start[i] + counter_left[i], counter_start[i+1]); + const range left_misplaced = globalLeft. intersect(right_range); + const range right_misplaced = globalRight.intersect(left_range); + + if (!left_misplaced.empty()) + { + numMisplacedItemsLeft += left_misplaced.size(); + leftMisplacedRanges[numMisplacedRangesLeft++] = left_misplaced; + } + + if (!right_misplaced.empty()) + { + numMisplacedItemsRight += right_misplaced.size(); + rightMisplacedRanges[numMisplacedRangesRight++] = right_misplaced; + } + } + assert( numMisplacedItemsLeft == numMisplacedItemsRight ); + + /* if no items are misplaced we are done */ + if (numMisplacedItemsLeft == 0) + return mid; + + /* otherwise we copy the items to the right place in parallel */ + parallel_for(numTasks,[&] (const size_t taskID) { + const size_t startID = (taskID+0)*numMisplacedItemsLeft/numTasks; + const size_t endID = (taskID+1)*numMisplacedItemsLeft/numTasks; + swapItemsInMisplacedRanges(numMisplacedRangesLeft,numMisplacedRangesRight,startID,endID); + }); + + return mid; + } + }; + + template + __noinline size_t parallel_partitioning(T* array, + const size_t begin, + const size_t end, + const Vi &identity, + V &leftReduction, + V &rightReduction, + const IsLeft& is_left, + const Reduction_T& reduction_t, + const Reduction_V& reduction_v, + size_t BLOCK_SIZE = 128) + { + /* fall back to single threaded partitioning for small N */ + if (unlikely(end-begin < BLOCK_SIZE)) + return serial_partitioning(array,begin,end,leftReduction,rightReduction,is_left,reduction_t); + + /* otherwise use parallel code */ + else { + typedef parallel_partition_task partition_task; + std::unique_ptr p(new partition_task(&array[begin],end-begin,identity,is_left,reduction_t,reduction_v,BLOCK_SIZE)); + return begin+p->partition(leftReduction,rightReduction); + } + } + + template + __noinline size_t parallel_partitioning(T* array, + const size_t begin, + const size_t end, + const Vi &identity, + V &leftReduction, + V &rightReduction, + const IsLeft& is_left, + const Reduction_T& reduction_t, + const Reduction_V& reduction_v, + size_t BLOCK_SIZE, + size_t PARALLEL_THRESHOLD) + { + /* fall back to single threaded partitioning for small N */ + if (unlikely(end-begin < PARALLEL_THRESHOLD)) + return serial_partitioning(array,begin,end,leftReduction,rightReduction,is_left,reduction_t); + + /* otherwise use parallel code */ + else { + typedef parallel_partition_task partition_task; + std::unique_ptr p(new partition_task(&array[begin],end-begin,identity,is_left,reduction_t,reduction_v,BLOCK_SIZE)); + return begin+p->partition(leftReduction,rightReduction); + } + } + + + template + inline size_t parallel_partitioning(T* array, + const size_t begin, + const size_t end, + const IsLeft& is_left, + size_t BLOCK_SIZE = 128) + { + size_t leftReduction = 0; + size_t rightReduction = 0; + return parallel_partitioning( + array,begin,end,0,leftReduction,rightReduction,is_left, + [] (size_t& t,const T& ref) { }, + [] (size_t& t0,size_t& t1) { }, + BLOCK_SIZE); + } + +} diff --git a/thirdparty/embree/common/algorithms/parallel_prefix_sum.cpp b/thirdparty/embree/common/algorithms/parallel_prefix_sum.cpp new file mode 100644 index 000000000000..685952c3dce5 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_prefix_sum.cpp @@ -0,0 +1,48 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_prefix_sum.h" +#include 
"../sys/regression.h" + +namespace embree +{ + struct parallel_prefix_sum_regression_test : public RegressionTest + { + parallel_prefix_sum_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + const size_t M = 10; + + for (size_t N=10; N<10000000; N=size_t(2.1*N)) + { + /* initialize array with random numbers */ + uint32_t sum0 = 0; + std::vector src(N); + for (size_t i=0; i dst(N); + for (auto& v : dst) v = 0; + + for (size_t i=0; i()); + passed &= (sum0 == sum1); + } + + /* check if prefix sum is correct */ + for (size_t i=0, sum=0; i + struct ParallelPrefixSumState + { + enum { MAX_TASKS = 64 }; + Value counts[MAX_TASKS]; + Value sums [MAX_TASKS]; + }; + + template + __forceinline Value parallel_prefix_sum( ParallelPrefixSumState& state, Index first, Index last, Index minStepSize, const Value& identity, const Func& func, const Reduction& reduction) + { + /* calculate number of tasks to use */ + const size_t numThreads = TaskScheduler::threadCount(); + const size_t numBlocks = (last-first+minStepSize-1)/minStepSize; + const size_t taskCount = min(numThreads,numBlocks,size_t(ParallelPrefixSumState::MAX_TASKS)); + + /* perform parallel prefix sum */ + parallel_for(taskCount, [&](const size_t taskIndex) + { + const size_t i0 = first+(taskIndex+0)*(last-first)/taskCount; + const size_t i1 = first+(taskIndex+1)*(last-first)/taskCount; + state.counts[taskIndex] = func(range(i0,i1),state.sums[taskIndex]); + }); + + /* calculate prefix sum */ + Value sum=identity; + for (size_t i=0; i + __forceinline Value parallel_prefix_sum(const SrcArray& src, DstArray& dst, size_t N, const Value& identity, const Add& add, const size_t SINGLE_THREAD_THRESHOLD = 4096) + { + /* perform single threaded prefix operation for small N */ + if (N < SINGLE_THREAD_THRESHOLD) + { + Value sum=identity; + for (size_t i=0; i state; + + /* initial run just sets up start values for subtasks */ + parallel_prefix_sum( state, size_t(0), size_t(N), size_t(1024), identity, [&](const range& r, const Value& sum) -> Value { + + Value s = identity; + for (size_t i=r.begin(); i& r, const Value& sum) -> Value { + + Value s = identity; + for (size_t i=r.begin(); i& r) -> size_t + { + size_t s = 0; + for (size_t i=r.begin(); i + __forceinline Value sequential_reduce( const Index first, const Index last, const Value& identity, const Func& func, const Reduction& reduction ) + { + return func(range(first,last)); + } + + template + __forceinline Value sequential_reduce( const Index first, const Index last, const Index minStepSize, const Value& identity, const Func& func, const Reduction& reduction ) + { + return func(range(first,last)); + } + + template + __noinline Value parallel_reduce_internal( Index taskCount, const Index first, const Index last, const Index minStepSize, const Value& identity, const Func& func, const Reduction& reduction ) + { + const Index maxTasks = 512; + const Index threadCount = (Index) TaskScheduler::threadCount(); + taskCount = min(taskCount,threadCount,maxTasks); + + /* parallel invokation of all tasks */ + dynamic_large_stack_array(Value,values,taskCount,8192); // consumes at most 8192 bytes on the stack + parallel_for(taskCount, [&](const Index taskIndex) { + const Index k0 = first+(taskIndex+0)*(last-first)/taskCount; + const Index k1 = first+(taskIndex+1)*(last-first)/taskCount; + values[taskIndex] = func(range(k0,k1)); + }); + + /* perform reduction over all tasks */ + Value v = identity; + for (Index i=0; i + __forceinline 
Value parallel_reduce( const Index first, const Index last, const Index minStepSize, const Value& identity, const Func& func, const Reduction& reduction ) + { +#if defined(TASKING_INTERNAL) + + /* fast path for small number of iterations */ + Index taskCount = (last-first+minStepSize-1)/minStepSize; + if (likely(taskCount == 1)) { + return func(range(first,last)); + } + return parallel_reduce_internal(taskCount,first,last,minStepSize,identity,func,reduction); + +#elif defined(TASKING_TBB) + #if TBB_INTERFACE_VERSION >= 12002 + tbb::task_group_context context; + const Value v = tbb::parallel_reduce(tbb::blocked_range(first,last,minStepSize),identity, + [&](const tbb::blocked_range& r, const Value& start) { return reduction(start,func(range(r.begin(),r.end()))); }, + reduction,context); + if (context.is_group_execution_cancelled()) + throw std::runtime_error("task cancelled"); + return v; + #else + const Value v = tbb::parallel_reduce(tbb::blocked_range(first,last,minStepSize),identity, + [&](const tbb::blocked_range& r, const Value& start) { return reduction(start,func(range(r.begin(),r.end()))); }, + reduction); + if (tbb::task::self().is_cancelled()) + throw std::runtime_error("task cancelled"); + return v; + #endif +#else // TASKING_PPL + struct AlignedValue + { + char storage[__alignof(Value)+sizeof(Value)]; + static uintptr_t alignUp(uintptr_t p, size_t a) { return p + (~(p - 1) % a); }; + Value* getValuePtr() { return reinterpret_cast(alignUp(uintptr_t(storage), __alignof(Value))); } + const Value* getValuePtr() const { return reinterpret_cast(alignUp(uintptr_t(storage), __alignof(Value))); } + AlignedValue(const Value& v) { new(getValuePtr()) Value(v); } + AlignedValue(const AlignedValue& v) { new(getValuePtr()) Value(*v.getValuePtr()); } + AlignedValue(const AlignedValue&& v) { new(getValuePtr()) Value(*v.getValuePtr()); }; + AlignedValue& operator = (const AlignedValue& v) { *getValuePtr() = *v.getValuePtr(); return *this; }; + AlignedValue& operator = (const AlignedValue&& v) { *getValuePtr() = *v.getValuePtr(); return *this; }; + operator Value() const { return *getValuePtr(); } + }; + + struct Iterator_Index + { + Index v; + typedef std::forward_iterator_tag iterator_category; + typedef AlignedValue value_type; + typedef Index difference_type; + typedef Index distance_type; + typedef AlignedValue* pointer; + typedef AlignedValue& reference; + __forceinline Iterator_Index() {} + __forceinline Iterator_Index(Index v) : v(v) {} + __forceinline bool operator== (Iterator_Index other) { return v == other.v; } + __forceinline bool operator!= (Iterator_Index other) { return v != other.v; } + __forceinline Iterator_Index operator++() { return Iterator_Index(++v); } + __forceinline Iterator_Index operator++(int) { return Iterator_Index(v++); } + }; + + auto range_reduction = [&](Iterator_Index begin, Iterator_Index end, const AlignedValue& start) { + assert(begin.v < end.v); + return reduction(start, func(range(begin.v, end.v))); + }; + const Value v = concurrency::parallel_reduce(Iterator_Index(first), Iterator_Index(last), AlignedValue(identity), range_reduction, reduction); + return v; +#endif + } + + template + __forceinline Value parallel_reduce( const Index first, const Index last, const Index minStepSize, const Index parallel_threshold, const Value& identity, const Func& func, const Reduction& reduction ) + { + if (likely(last-first < parallel_threshold)) { + return func(range(first,last)); + } else { + return parallel_reduce(first,last,minStepSize,identity,func,reduction); + } + } 
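  // Editor's note: a minimal usage sketch of the threshold overload defined just above
  // (illustrative only; an assumption by the editor, not part of the vendored Embree
  // sources nor of this patch). It sums the squares of 0..N-1 and falls back to the
  // serial path whenever fewer than `parallel_threshold` items are requested.
  // It assumes "parallel_reduce.h" is included and the embree namespace is available.
  inline size_t example_sum_of_squares(size_t N)
  {
    return parallel_reduce(
        size_t(0), N,
        size_t(1024),  /* minStepSize: smallest per-task range */
        size_t(4096),  /* parallel_threshold: below this, run serially */
        size_t(0),     /* identity element of the reduction */
        [](const range<size_t>& r) {   // func: reduce one sub-range to a partial sum
          size_t s = 0;
          for (size_t i = r.begin(); i != r.end(); ++i) s += i * i;
          return s;
        },
        [](size_t a, size_t b) { return a + b; });  // reduction: combine partial sums
  }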
+ + template + __forceinline Value parallel_reduce( const range range, const Index minStepSize, const Index parallel_threshold, const Value& identity, const Func& func, const Reduction& reduction ) + { + return parallel_reduce(range.begin(),range.end(),minStepSize,parallel_threshold,identity,func,reduction); + } + + template + __forceinline Value parallel_reduce( const Index first, const Index last, const Value& identity, const Func& func, const Reduction& reduction ) + { + auto funcr = [&] ( const range r ) { + Value v = identity; + for (Index i=r.begin(); i + __forceinline Value parallel_reduce( const range range, const Value& identity, const Func& func, const Reduction& reduction ) + { + return parallel_reduce(range.begin(),range.end(),Index(1),identity,func,reduction); + } +} diff --git a/thirdparty/embree/common/algorithms/parallel_set.cpp b/thirdparty/embree/common/algorithms/parallel_set.cpp new file mode 100644 index 000000000000..20b639c1c9b5 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_set.cpp @@ -0,0 +1,43 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_set.h" +#include "../sys/regression.h" + +namespace embree +{ + struct parallel_set_regression_test : public RegressionTest + { + parallel_set_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + + /* create vector with random numbers */ + const size_t N = 10000; + std::vector unsorted(N); + for (size_t i=0; i sorted; + sorted.init(unsorted); + + /* check that all elements are in the set */ + for (size_t i=0; i + class parallel_set + { + public: + + /*! default constructor for the parallel set */ + parallel_set () {} + + /*! construction from vector */ + template + parallel_set (const Vector& in) { init(in); } + + /*! initialized the parallel set from a vector */ + template + void init(const Vector& in) + { + /* copy data to internal vector */ + vec.resize(in.size()); + parallel_for( size_t(0), in.size(), size_t(4*4096), [&](const range& r) { + for (size_t i=r.begin(); i temp(in.size()); + radix_sort(vec.data(),temp.data(),vec.size()); + } + + /*! tests if some element is in the set */ + __forceinline bool lookup(const T& elt) const { + return std::binary_search(vec.begin(), vec.end(), elt); + } + + /*! 
clears all state */ + void clear() { + vec.clear(); + } + + private: + std::vector vec; //!< vector containing sorted elements + }; +} diff --git a/thirdparty/embree/common/algorithms/parallel_sort.cpp b/thirdparty/embree/common/algorithms/parallel_sort.cpp new file mode 100644 index 000000000000..5e7ec79ac191 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_sort.cpp @@ -0,0 +1,50 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "parallel_sort.h" +#include "../sys/regression.h" + +namespace embree +{ + template + struct RadixSortRegressionTest : public RegressionTest + { + RadixSortRegressionTest(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + const size_t M = 10; + + for (size_t N=10; N<1000000; N=size_t(2.1*N)) + { + std::vector src(N); memset(src.data(),0,N*sizeof(Key)); + std::vector tmp(N); memset(tmp.data(),0,N*sizeof(Key)); + for (size_t i=0; i(src.data(),tmp.data(),N); + } + + /* calculate checksum */ + Key sum1 = 0; for (size_t i=0; i test_u32("RadixSortRegressionTestU32"); + RadixSortRegressionTest test_u64("RadixSortRegressionTestU64"); +} diff --git a/thirdparty/embree/common/algorithms/parallel_sort.h b/thirdparty/embree/common/algorithms/parallel_sort.h new file mode 100644 index 000000000000..5a3382079375 --- /dev/null +++ b/thirdparty/embree/common/algorithms/parallel_sort.h @@ -0,0 +1,454 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../simd/simd.h" +#include "parallel_for.h" +#include + +namespace embree +{ + template + __forceinline void insertionsort_ascending(T *__restrict__ array, const size_t length) + { + for(size_t i = 1;i 0 && v < array[j-1]) + { + array[j] = array[j-1]; + --j; + } + array[j] = v; + } + } + + template + __forceinline void insertionsort_decending(T *__restrict__ array, const size_t length) + { + for(size_t i = 1;i 0 && v > array[j-1]) + { + array[j] = array[j-1]; + --j; + } + array[j] = v; + } + } + + template + void quicksort_ascending(T *__restrict__ t, + const ssize_t begin, + const ssize_t end) + { + if (likely(begin < end)) + { + const T pivotvalue = t[begin]; + ssize_t left = begin - 1; + ssize_t right = end + 1; + + while(1) + { + while (t[--right] > pivotvalue); + while (t[++left] < pivotvalue); + + if (left >= right) break; + + const T temp = t[right]; + t[right] = t[left]; + t[left] = temp; + } + + const int pivot = right; + quicksort_ascending(t, begin, pivot); + quicksort_ascending(t, pivot + 1, end); + } + } + + template + void quicksort_decending(T *__restrict__ t, + const ssize_t begin, + const ssize_t end) + { + if (likely(begin < end)) + { + const T pivotvalue = t[begin]; + ssize_t left = begin - 1; + ssize_t right = end + 1; + + while(1) + { + while (t[--right] < pivotvalue); + while (t[++left] > pivotvalue); + + if (left >= right) break; + + const T temp = t[right]; + t[right] = t[left]; + t[left] = temp; + } + + const int pivot = right; + quicksort_decending(t, begin, pivot); + quicksort_decending(t, pivot + 1, end); + } + } + + + template + void quicksort_insertionsort_ascending(T *__restrict__ t, + const ssize_t begin, + const ssize_t end) + { + if (likely(begin < end)) + { + const ssize_t size = end-begin+1; + if (likely(size <= THRESHOLD)) + { + insertionsort_ascending(&t[begin],size); + } + else + { + const T pivotvalue = t[begin]; + ssize_t left = begin - 1; + ssize_t right = end + 1; + + while(1) + { + while 
(t[--right] > pivotvalue); + while (t[++left] < pivotvalue); + + if (left >= right) break; + + const T temp = t[right]; + t[right] = t[left]; + t[left] = temp; + } + + const ssize_t pivot = right; + quicksort_insertionsort_ascending(t, begin, pivot); + quicksort_insertionsort_ascending(t, pivot + 1, end); + } + } + } + + + template + void quicksort_insertionsort_decending(T *__restrict__ t, + const ssize_t begin, + const ssize_t end) + { + if (likely(begin < end)) + { + const ssize_t size = end-begin+1; + if (likely(size <= THRESHOLD)) + { + insertionsort_decending(&t[begin],size); + } + else + { + + const T pivotvalue = t[begin]; + ssize_t left = begin - 1; + ssize_t right = end + 1; + + while(1) + { + while (t[--right] < pivotvalue); + while (t[++left] > pivotvalue); + + if (left >= right) break; + + const T temp = t[right]; + t[right] = t[left]; + t[left] = temp; + } + + const ssize_t pivot = right; + quicksort_insertionsort_decending(t, begin, pivot); + quicksort_insertionsort_decending(t, pivot + 1, end); + } + } + } + + template + static void radixsort32(T* const morton, const size_t num, const unsigned int shift = 3*8) + { + static const unsigned int BITS = 8; + static const unsigned int BUCKETS = (1 << BITS); + static const unsigned int CMP_SORT_THRESHOLD = 16; + + __aligned(64) unsigned int count[BUCKETS]; + + /* clear buckets */ + for (size_t i=0;i> shift) & (BUCKETS-1)]++; + + /* prefix sums */ + __aligned(64) unsigned int head[BUCKETS]; + __aligned(64) unsigned int tail[BUCKETS]; + + head[0] = 0; + for (size_t i=1; i> shift) & (BUCKETS-1); + if (b == i) break; + std::swap(v,morton[head[b]++]); + } + assert((unsigned(v) >> shift & (BUCKETS-1)) == i); + morton[head[i]++] = v; + } + } + if (shift == 0) return; + + size_t offset = 0; + for (size_t i=0;i> shift) & (BUCKETS-1)) == i); + + if (unlikely(count[i] < CMP_SORT_THRESHOLD)) + insertionsort_ascending(morton + offset, count[i]); + else + radixsort32(morton + offset, count[i], shift-BITS); + + for (size_t j=offset;j + class ParallelRadixSort + { + static const size_t MAX_TASKS = 64; + static const size_t BITS = 8; + static const size_t BUCKETS = (1 << BITS); + typedef unsigned int TyRadixCount[BUCKETS]; + + template + static bool compare(const T& v0, const T& v1) { + return (Key)v0 < (Key)v1; + } + + private: + ParallelRadixSort (const ParallelRadixSort& other) DELETED; // do not implement + ParallelRadixSort& operator= (const ParallelRadixSort& other) DELETED; // do not implement + + + public: + ParallelRadixSort (Ty* const src, Ty* const tmp, const size_t N) + : radixCount(nullptr), src(src), tmp(tmp), N(N) {} + + void sort(const size_t blockSize) + { + assert(blockSize > 0); + + /* perform single threaded sort for small N */ + if (N<=blockSize) // handles also special case of 0! 
+ { + /* do inplace sort inside destination array */ + std::sort(src,src+N,compare); + } + + /* perform parallel sort for large N */ + else + { + const size_t numThreads = min((N+blockSize-1)/blockSize,TaskScheduler::threadCount(),size_t(MAX_TASKS)); + tbbRadixSort(numThreads); + } + } + + ~ParallelRadixSort() + { + alignedFree(radixCount); + radixCount = nullptr; + } + + private: + + void tbbRadixIteration0(const Key shift, + const Ty* __restrict const src, + Ty* __restrict const dst, + const size_t threadIndex, const size_t threadCount) + { + const size_t startID = (threadIndex+0)*N/threadCount; + const size_t endID = (threadIndex+1)*N/threadCount; + + /* mask to extract some number of bits */ + const Key mask = BUCKETS-1; + + /* count how many items go into the buckets */ + for (size_t i=0; i> (size_t)shift) & (size_t)mask; +#else + const Key index = ((Key)src[i] >> shift) & mask; +#endif + count[index]++; + } + } + + void tbbRadixIteration1(const Key shift, + const Ty* __restrict const src, + Ty* __restrict const dst, + const size_t threadIndex, const size_t threadCount) + { + const size_t startID = (threadIndex+0)*N/threadCount; + const size_t endID = (threadIndex+1)*N/threadCount; + + /* mask to extract some number of bits */ + const Key mask = BUCKETS-1; + + /* calculate total number of items for each bucket */ + __aligned(64) unsigned int total[BUCKETS]; + /* + for (size_t i=0; i> (size_t)shift) & (size_t)mask; +#else + const size_t index = ((Key)src[i] >> shift) & mask; +#endif + dst[offset[index]++] = elt; + } + } + + void tbbRadixIteration(const Key shift, const bool last, + const Ty* __restrict src, Ty* __restrict dst, + const size_t numTasks) + { + affinity_partitioner ap; + parallel_for_affinity(numTasks,[&] (size_t taskIndex) { tbbRadixIteration0(shift,src,dst,taskIndex,numTasks); },ap); + parallel_for_affinity(numTasks,[&] (size_t taskIndex) { tbbRadixIteration1(shift,src,dst,taskIndex,numTasks); },ap); + } + + void tbbRadixSort(const size_t numTasks) + { + radixCount = (TyRadixCount*) alignedMalloc(MAX_TASKS*sizeof(TyRadixCount),64); + + if (sizeof(Key) == sizeof(uint32_t)) { + tbbRadixIteration(0*BITS,0,src,tmp,numTasks); + tbbRadixIteration(1*BITS,0,tmp,src,numTasks); + tbbRadixIteration(2*BITS,0,src,tmp,numTasks); + tbbRadixIteration(3*BITS,1,tmp,src,numTasks); + } + else if (sizeof(Key) == sizeof(uint64_t)) + { + tbbRadixIteration(0*BITS,0,src,tmp,numTasks); + tbbRadixIteration(1*BITS,0,tmp,src,numTasks); + tbbRadixIteration(2*BITS,0,src,tmp,numTasks); + tbbRadixIteration(3*BITS,0,tmp,src,numTasks); + tbbRadixIteration(4*BITS,0,src,tmp,numTasks); + tbbRadixIteration(5*BITS,0,tmp,src,numTasks); + tbbRadixIteration(6*BITS,0,src,tmp,numTasks); + tbbRadixIteration(7*BITS,1,tmp,src,numTasks); + } + } + + private: + TyRadixCount* radixCount; + Ty* const src; + Ty* const tmp; + const size_t N; + }; + + template + void radix_sort(Ty* const src, Ty* const tmp, const size_t N, const size_t blockSize = 8192) + { + ParallelRadixSort(src,tmp,N).sort(blockSize); + } + + template + void radix_sort(Ty* const src, Ty* const tmp, const size_t N, const size_t blockSize = 8192) + { + ParallelRadixSort(src,tmp,N).sort(blockSize); + } + + template + void radix_sort_u32(Ty* const src, Ty* const tmp, const size_t N, const size_t blockSize = 8192) { + radix_sort(src,tmp,N,blockSize); + } + + template + void radix_sort_u64(Ty* const src, Ty* const tmp, const size_t N, const size_t blockSize = 8192) { + radix_sort(src,tmp,N,blockSize); + } +} diff --git 
a/thirdparty/embree/common/lexers/parsestream.h b/thirdparty/embree/common/lexers/parsestream.h new file mode 100644 index 000000000000..db46dc114fb2 --- /dev/null +++ b/thirdparty/embree/common/lexers/parsestream.h @@ -0,0 +1,101 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "stringstream.h" +#include "../sys/filename.h" +#include "../math/vec2.h" +#include "../math/vec3.h" +#include "../math/col3.h" +#include "../math/color.h" + +namespace embree +{ + /*! helper class for simple command line parsing */ + class ParseStream : public Stream + { + public: + ParseStream (const Ref >& cin) : cin(cin) {} + + ParseStream (const Ref >& cin, const std::string& seps = "\n\t\r ", + const std::string& endl = "", bool multiLine = false) + : cin(new StringStream(cin,seps,endl,multiLine)) {} + + public: + ParseLocation location() { return cin->loc(); } + std::string next() { return cin->get(); } + + void force(const std::string& next) { + std::string token = getString(); + if (token != next) + THROW_RUNTIME_ERROR("token \""+next+"\" expected but token \""+token+"\" found"); + } + + std::string getString() { + return get(); + } + + FileName getFileName() { + return FileName(get()); + } + + int getInt () { + return atoi(get().c_str()); + } + + Vec2i getVec2i() { + int x = atoi(get().c_str()); + int y = atoi(get().c_str()); + return Vec2i(x,y); + } + + Vec3ia getVec3ia() { + int x = atoi(get().c_str()); + int y = atoi(get().c_str()); + int z = atoi(get().c_str()); + return Vec3ia(x,y,z); + } + + float getFloat() { + return (float)atof(get().c_str()); + } + + Vec2f getVec2f() { + float x = (float)atof(get().c_str()); + float y = (float)atof(get().c_str()); + return Vec2f(x,y); + } + + Vec3f getVec3f() { + float x = (float)atof(get().c_str()); + float y = (float)atof(get().c_str()); + float z = (float)atof(get().c_str()); + return Vec3f(x,y,z); + } + + Vec3fa getVec3fa() { + float x = (float)atof(get().c_str()); + float y = (float)atof(get().c_str()); + float z = (float)atof(get().c_str()); + return Vec3fa(x,y,z); + } + + Col3f getCol3f() { + float x = (float)atof(get().c_str()); + float y = (float)atof(get().c_str()); + float z = (float)atof(get().c_str()); + return Col3f(x,y,z); + } + + Color getColor() { + float r = (float)atof(get().c_str()); + float g = (float)atof(get().c_str()); + float b = (float)atof(get().c_str()); + return Color(r,g,b); + } + + private: + Ref > cin; + }; +} diff --git a/thirdparty/embree/common/lexers/stream.h b/thirdparty/embree/common/lexers/stream.h new file mode 100644 index 000000000000..3f75677e6874 --- /dev/null +++ b/thirdparty/embree/common/lexers/stream.h @@ -0,0 +1,215 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../sys/ref.h" +#include "../sys/filename.h" +#include "../sys/string.h" + +#include +#include +#include +#include + +namespace embree +{ + /*! 
stores the location of a stream element in the source */ + class ParseLocation + { + public: + ParseLocation () : lineNumber(-1), colNumber(-1) {} + ParseLocation (std::shared_ptr fileName, ssize_t lineNumber, ssize_t colNumber, ssize_t /*charNumber*/) + : fileName(fileName), lineNumber(lineNumber), colNumber(colNumber) {} + + std::string str() const + { + std::string str = "unknown"; + if (fileName) str = *fileName; + if (lineNumber >= 0) str += " line " + toString(lineNumber); + if (lineNumber >= 0 && colNumber >= 0) str += " character " + toString(colNumber); + return str; + } + + private: + std::shared_ptr fileName; /// name of the file (or stream) the token is from + ssize_t lineNumber; /// the line number the token is from + ssize_t colNumber; /// the character number in the current line + }; + + /*! a stream class templated over the stream elements */ + template class Stream : public RefCount + { + enum { BUF_SIZE = 1024 }; + + private: + virtual T next() = 0; + virtual ParseLocation location() = 0; + __forceinline std::pair nextHelper() { + ParseLocation l = location(); + T v = next(); + return std::pair(v,l); + } + __forceinline void push_back(const std::pair& v) { + if (past+future == BUF_SIZE) pop_front(); + size_t end = (start+past+future++)%BUF_SIZE; + buffer[end] = v; + } + __forceinline void pop_front() { + if (past == 0) THROW_RUNTIME_ERROR("stream buffer empty"); + start = (start+1)%BUF_SIZE; past--; + } + public: + Stream () : start(0), past(0), future(0), buffer(BUF_SIZE) {} + virtual ~Stream() {} + + public: + + const ParseLocation& loc() { + if (future == 0) push_back(nextHelper()); + return buffer[(start+past)%BUF_SIZE].second; + } + T get() { + if (future == 0) push_back(nextHelper()); + T t = buffer[(start+past)%BUF_SIZE].first; + past++; future--; + return t; + } + const T& peek() { + if (future == 0) push_back(nextHelper()); + return buffer[(start+past)%BUF_SIZE].first; + } + const T& unget(size_t n = 1) { + if (past < n) THROW_RUNTIME_ERROR ("cannot unget that many items"); + past -= n; future += n; + return peek(); + } + void drop() { + if (future == 0) push_back(nextHelper()); + past++; future--; + } + private: + size_t start,past,future; + std::vector > buffer; + }; + + /*! warps an iostream stream */ + class StdStream : public Stream + { + public: + StdStream (std::istream& cin, const std::string& name = "std::stream") + : cin(cin), lineNumber(1), colNumber(0), charNumber(0), name(std::shared_ptr(new std::string(name))) {} + ~StdStream() {} + ParseLocation location() { + return ParseLocation(name,lineNumber,colNumber,charNumber); + } + int next() { + int c = cin.get(); + if (c == '\n') { lineNumber++; colNumber = 0; } else if (c != '\r') colNumber++; + charNumber++; + return c; + } + private: + std::istream& cin; + ssize_t lineNumber; /// the line number the token is from + ssize_t colNumber; /// the character number in the current line + ssize_t charNumber; /// the character in the file + std::shared_ptr name; /// name of buffer + }; + + /*! 
creates a stream from a file */ + class FileStream : public Stream + { + public: + + FileStream (FILE* file, const std::string& name = "file") + : file(file), lineNumber(1), colNumber(0), charNumber(0), name(std::shared_ptr(new std::string(name))) {} + + FileStream (const FileName& fileName) + : lineNumber(1), colNumber(0), charNumber(0), name(std::shared_ptr(new std::string(fileName.str()))) + { + file = fopen(fileName.c_str(),"r"); + if (file == nullptr) THROW_RUNTIME_ERROR("cannot open file " + fileName.str()); + } + ~FileStream() { if (file) fclose(file); } + + public: + ParseLocation location() { + return ParseLocation(name,lineNumber,colNumber,charNumber); + } + + int next() { + int c = fgetc(file); + if (c == '\n') { lineNumber++; colNumber = 0; } else if (c != '\r') colNumber++; + charNumber++; + return c; + } + + private: + FILE* file; + ssize_t lineNumber; /// the line number the token is from + ssize_t colNumber; /// the character number in the current line + ssize_t charNumber; /// the character in the file + std::shared_ptr name; /// name of buffer + }; + + /*! creates a stream from a string */ + class StrStream : public Stream + { + public: + + StrStream (const char* str) + : str(str), lineNumber(1), colNumber(0), charNumber(0) {} + + public: + ParseLocation location() { + return ParseLocation(std::shared_ptr(),lineNumber,colNumber,charNumber); + } + + int next() { + int c = str[charNumber]; + if (c == 0) return EOF; + if (c == '\n') { lineNumber++; colNumber = 0; } else if (c != '\r') colNumber++; + charNumber++; + return c; + } + + private: + const char* str; + ssize_t lineNumber; /// the line number the token is from + ssize_t colNumber; /// the character number in the current line + ssize_t charNumber; /// the character in the file + }; + + /*! 
creates a character stream from a command line */ + class CommandLineStream : public Stream + { + public: + CommandLineStream (int argc, char** argv, const std::string& name = "command line") + : i(0), j(0), charNumber(0), name(std::shared_ptr(new std::string(name))) + { + if (argc > 0) { + for (size_t i=0; argv[0][i] && i<1024; i++) charNumber++; + charNumber++; + } + for (ssize_t k=1; k args; + ssize_t charNumber; /// the character in the file + std::shared_ptr name; /// name of buffer + }; +} diff --git a/thirdparty/embree/common/lexers/streamfilters.h b/thirdparty/embree/common/lexers/streamfilters.h new file mode 100644 index 000000000000..25580a77b8ed --- /dev/null +++ b/thirdparty/embree/common/lexers/streamfilters.h @@ -0,0 +1,39 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "stream.h" + +namespace embree +{ + /* removes all line comments from a stream */ + class LineCommentFilter : public Stream + { + public: + LineCommentFilter (const FileName& fileName, const std::string& lineComment) + : cin(new FileStream(fileName)), lineComment(lineComment) {} + LineCommentFilter (Ref > cin, const std::string& lineComment) + : cin(cin), lineComment(lineComment) {} + + ParseLocation location() { return cin->loc(); } + + int next() + { + /* look if the line comment starts here */ + for (size_t j=0; jpeek() != lineComment[j]) { cin->unget(j); goto not_found; } + cin->get(); + } + /* eat all characters until the end of the line (or file) */ + while (cin->peek() != '\n' && cin->peek() != EOF) cin->get(); + + not_found: + return cin->get(); + } + + private: + Ref > cin; + std::string lineComment; + }; +} diff --git a/thirdparty/embree/common/lexers/stringstream.cpp b/thirdparty/embree/common/lexers/stringstream.cpp new file mode 100644 index 000000000000..7e7b9faef800 --- /dev/null +++ b/thirdparty/embree/common/lexers/stringstream.cpp @@ -0,0 +1,48 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "stringstream.h" + +namespace embree +{ + static const std::string stringChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 _.,+-=:/*\\"; + + /* creates map for fast categorization of characters */ + static void createCharMap(bool map[256], const std::string& chrs) { + for (size_t i=0; i<256; i++) map[i] = false; + for (size_t i=0; i >& cin, const std::string& seps, const std::string& endl, bool multiLine) + : cin(cin), endl(endl), multiLine(multiLine) + { + createCharMap(isSepMap,seps); + createCharMap(isValidCharMap,stringChars); + } + + std::string StringStream::next() + { + /* skip separators */ + while (cin->peek() != EOF) { + if (endl != "" && cin->peek() == '\n') { cin->drop(); return endl; } + if (multiLine && cin->peek() == '\\') { + cin->drop(); + if (cin->peek() == '\n') { cin->drop(); continue; } + cin->unget(); + } + if (!isSeparator(cin->peek())) break; + cin->drop(); + } + + /* parse everything until the next separator */ + std::vector str; str.reserve(64); + while (cin->peek() != EOF && !isSeparator(cin->peek())) { + int c = cin->get(); + if (!isValidChar(c)) throw std::runtime_error("invalid character "+std::string(1,c)+" in input"); + str.push_back((char)c); + } + str.push_back(0); + return std::string(str.data()); + } +} diff --git a/thirdparty/embree/common/lexers/stringstream.h b/thirdparty/embree/common/lexers/stringstream.h new file mode 100644 index 000000000000..e6dbd4aecc24 --- /dev/null +++ b/thirdparty/embree/common/lexers/stringstream.h @@ 
-0,0 +1,29 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "stream.h" + +namespace embree +{ + /*! simple tokenizer that produces a string stream */ + class StringStream : public Stream + { + public: + StringStream(const Ref >& cin, const std::string& seps = "\n\t\r ", + const std::string& endl = "", bool multiLine = false); + public: + ParseLocation location() { return cin->loc(); } + std::string next(); + private: + __forceinline bool isSeparator(unsigned int c) const { return c<256 && isSepMap[c]; } + __forceinline bool isValidChar(unsigned int c) const { return c<256 && isValidCharMap[c]; } + private: + Ref > cin; /*! source character stream */ + bool isSepMap[256]; /*! map for fast classification of separators */ + bool isValidCharMap[256]; /*! map for valid characters */ + std::string endl; /*! the token of the end of line */ + bool multiLine; /*! whether to parse lines wrapped with \ */ + }; +} diff --git a/thirdparty/embree/common/lexers/tokenstream.cpp b/thirdparty/embree/common/lexers/tokenstream.cpp new file mode 100644 index 000000000000..d05be65862f8 --- /dev/null +++ b/thirdparty/embree/common/lexers/tokenstream.cpp @@ -0,0 +1,181 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "tokenstream.h" +#include "../math/math.h" + +namespace embree +{ + /* shorthands for common sets of characters */ + const std::string TokenStream::alpha = "abcdefghijklmnopqrstuvwxyz"; + const std::string TokenStream::ALPHA = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + const std::string TokenStream::numbers = "0123456789"; + const std::string TokenStream::separators = "\n\t\r "; + const std::string TokenStream::stringChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 _.,+-=:/*\\"; + + /* creates map for fast categorization of characters */ + static void createCharMap(bool map[256], const std::string& chrs) { + for (size_t i=0; i<256; i++) map[i] = false; + for (size_t i=0; i >& cin, //< stream to read from + const std::string& alpha, //< valid characters for identifiers + const std::string& seps, //< characters that act as separators + const std::vector& symbols) //< symbols + : cin(cin), symbols(symbols) + { + createCharMap(isAlphaMap,alpha); + createCharMap(isSepMap,seps); + createCharMap(isStringCharMap,stringChars); + } + + bool TokenStream::decDigits(std::string& str_o) + { + bool ok = false; + std::string str; + if (cin->peek() == '+' || cin->peek() == '-') str += (char)cin->get(); + while (isDigit(cin->peek())) { ok = true; str += (char)cin->get(); } + if (ok) str_o += str; + else cin->unget(str.size()); + return ok; + } + + bool TokenStream::decDigits1(std::string& str_o) + { + bool ok = false; + std::string str; + while (isDigit(cin->peek())) { ok = true; str += (char)cin->get(); } + if (ok) str_o += str; else cin->unget(str.size()); + return ok; + } + + bool TokenStream::trySymbol(const std::string& symbol) + { + size_t pos = 0; + while (pos < symbol.size()) { + if (symbol[pos] != cin->peek()) { cin->unget(pos); return false; } + cin->drop(); pos++; + } + return true; + } + + bool TokenStream::trySymbols(Token& token, const ParseLocation& loc) + { + for (size_t i=0; ipeek() == '.') { + str += (char)cin->get(); + decDigits(str); + if (cin->peek() == 'e' || cin->peek() == 'E') { + str += (char)cin->get(); + if (decDigits(str)) ok = true; // 1.[2]E2 + } + else ok = true; // 1.[2] + } + else if (cin->peek() == 'e' || cin->peek() == 'E') { + str += (char)cin->get(); + if 
(decDigits(str)) ok = true; // 1E2 + } + } + else + { + if (cin->peek() == '.') { + str += (char)cin->get(); + if (decDigits(str)) { + if (cin->peek() == 'e' || cin->peek() == 'E') { + str += (char)cin->get(); + if (decDigits(str)) ok = true; // .3E2 + } + else ok = true; // .3 + } + } + } + if (ok) { + token = Token((float)atof(str.c_str()),loc); + } + else cin->unget(str.size()); + return ok; + } + + bool TokenStream::tryInt(Token& token, const ParseLocation& loc) { + std::string str; + if (decDigits(str)) { + token = Token(atoi(str.c_str()),loc); + return true; + } + return false; + } + + bool TokenStream::tryString(Token& token, const ParseLocation& loc) + { + std::string str; + if (cin->peek() != '\"') return false; + cin->drop(); + while (cin->peek() != '\"') { + const int c = cin->get(); + if (!isStringChar(c)) THROW_RUNTIME_ERROR("invalid string character "+std::string(1,c)+" at "+loc.str()); + str += (char)c; + } + cin->drop(); + token = Token(str,Token::TY_STRING,loc); + return true; + } + + bool TokenStream::tryIdentifier(Token& token, const ParseLocation& loc) + { + std::string str; + if (!isAlpha(cin->peek())) return false; + str += (char)cin->get(); + while (isAlphaNum(cin->peek())) str += (char)cin->get(); + token = Token(str,Token::TY_IDENTIFIER,loc); + return true; + } + + void TokenStream::skipSeparators() + { + /* skip separators */ + while (cin->peek() != EOF && isSeparator(cin->peek())) + cin->drop(); + } + + Token TokenStream::next() + { + Token token; + skipSeparators(); + ParseLocation loc = cin->loc(); + if (trySymbols (token,loc)) return token; /**< try to parse a symbol */ + if (tryFloat (token,loc)) return token; /**< try to parse float */ + if (tryInt (token,loc)) return token; /**< try to parse integer */ + if (tryString (token,loc)) return token; /**< try to parse string */ + if (tryIdentifier(token,loc)) return token; /**< try to parse identifier */ + if (cin->peek() == EOF ) return Token(loc); /**< return EOF token */ + return Token((char)cin->get(),loc); /**< return invalid character token */ + } +} diff --git a/thirdparty/embree/common/lexers/tokenstream.h b/thirdparty/embree/common/lexers/tokenstream.h new file mode 100644 index 000000000000..72a7b4f2f335 --- /dev/null +++ b/thirdparty/embree/common/lexers/tokenstream.h @@ -0,0 +1,164 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "stream.h" +#include +#include + +namespace embree +{ + /*! 
token class */ + class Token + { + public: + + enum Type { TY_EOF, TY_CHAR, TY_INT, TY_FLOAT, TY_IDENTIFIER, TY_STRING, TY_SYMBOL }; + + Token ( const ParseLocation& loc = ParseLocation()) : ty(TY_EOF ), loc(loc) {} + Token (char c, const ParseLocation& loc = ParseLocation()) : ty(TY_CHAR ), c(c), loc(loc) {} + Token (int i, const ParseLocation& loc = ParseLocation()) : ty(TY_INT ), i(i), loc(loc) {} + Token (float f,const ParseLocation& loc = ParseLocation()) : ty(TY_FLOAT), f(f), loc(loc) {} + Token (std::string str, Type ty, const ParseLocation& loc = ParseLocation()) : ty(ty), str(str), loc(loc) {} + + static Token Eof() { return Token(); } + static Token Sym(std::string str) { return Token(str,TY_SYMBOL); } + static Token Str(std::string str) { return Token(str,TY_STRING); } + static Token Id (std::string str) { return Token(str,TY_IDENTIFIER); } + + char Char() const { + if (ty == TY_CHAR) return c; + THROW_RUNTIME_ERROR(loc.str()+": character expected"); + } + + int Int() const { + if (ty == TY_INT) return i; + THROW_RUNTIME_ERROR(loc.str()+": integer expected"); + } + + float Float(bool cast = true) const { + if (ty == TY_FLOAT) return f; + if (ty == TY_INT && cast) return (float)i; + THROW_RUNTIME_ERROR(loc.str()+": float expected"); + } + + std::string Identifier() const { + if (ty == TY_IDENTIFIER) return str; + THROW_RUNTIME_ERROR(loc.str()+": identifier expected"); + } + + std::string String() const { + if (ty == TY_STRING) return str; + THROW_RUNTIME_ERROR(loc.str()+": string expected"); + } + + std::string Symbol() const { + if (ty == TY_SYMBOL) return str; + THROW_RUNTIME_ERROR(loc.str()+": symbol expected"); + } + + const ParseLocation& Location() const { return loc; } + + friend bool operator==(const Token& a, const Token& b) + { + if (a.ty != b.ty) return false; + if (a.ty == TY_CHAR) return a.c == b.c; + if (a.ty == TY_INT) return a.i == b.i; + if (a.ty == TY_FLOAT) return a.f == b.f; + if (a.ty == TY_IDENTIFIER) return a.str == b.str; + if (a.ty == TY_STRING) return a.str == b.str; + if (a.ty == TY_SYMBOL) return a.str == b.str; + return true; + } + + friend bool operator!=(const Token& a, const Token& b) { + return !(a == b); + } + + friend bool operator <( const Token& a, const Token& b ) { + if (a.ty != b.ty) return (int)a.ty < (int)b.ty; + if (a.ty == TY_CHAR) return a.c < b.c; + if (a.ty == TY_INT) return a.i < b.i; + if (a.ty == TY_FLOAT) return a.f < b.f; + if (a.ty == TY_IDENTIFIER) return a.str < b.str; + if (a.ty == TY_STRING) return a.str < b.str; + if (a.ty == TY_SYMBOL) return a.str < b.str; + return false; + } + + friend std::ostream& operator<<(std::ostream& cout, const Token& t) + { + if (t.ty == TY_EOF) return cout << "eof"; + if (t.ty == TY_CHAR) return cout << "Char(" << t.c << ")"; + if (t.ty == TY_INT) return cout << "Int(" << t.i << ")"; + if (t.ty == TY_FLOAT) return cout << "Float(" << t.f << ")"; + if (t.ty == TY_IDENTIFIER) return cout << "Id(" << t.str << ")"; + if (t.ty == TY_STRING) return cout << "String(" << t.str << ")"; + if (t.ty == TY_SYMBOL) return cout << "Symbol(" << t.str << ")"; + return cout << "unknown"; + } + + private: + Type ty; //< the type of the token + union { + char c; //< data for char tokens + int i; //< data for int tokens + float f; //< data for float tokens + }; + std::string str; //< data for string and identifier tokens + ParseLocation loc; //< the location the token is from + }; + + /*! 
build full tokenizer that takes list of valid characters and keywords */ + class TokenStream : public Stream + { + public: + + /*! shorthands for common sets of characters */ + static const std::string alpha; + static const std::string ALPHA; + static const std::string numbers; + static const std::string separators; + static const std::string stringChars; + + public: + TokenStream(const Ref >& cin, + const std::string& alpha, //< valid characters for identifiers + const std::string& seps, //< characters that act as separators + const std::vector& symbols = std::vector()); //< symbols + public: + ParseLocation location() { return cin->loc(); } + Token next(); + bool trySymbol(const std::string& symbol); + + private: + void skipSeparators(); + bool decDigits(std::string& str); + bool decDigits1(std::string& str); + bool trySymbols(Token& token, const ParseLocation& loc); + bool tryFloat(Token& token, const ParseLocation& loc); + bool tryInt(Token& token, const ParseLocation& loc); + bool tryString(Token& token, const ParseLocation& loc); + bool tryIdentifier(Token& token, const ParseLocation& loc); + + Ref > cin; + bool isSepMap[256]; + bool isAlphaMap[256]; + bool isStringCharMap[256]; + std::vector symbols; + + /*! checks if a character is a separator */ + __forceinline bool isSeparator(unsigned int c) const { return c<256 && isSepMap[c]; } + + /*! checks if a character is a number */ + __forceinline bool isDigit(unsigned int c) const { return c >= '0' && c <= '9'; } + + /*! checks if a character is valid inside a string */ + __forceinline bool isStringChar(unsigned int c) const { return c<256 && isStringCharMap[c]; } + + /*! checks if a character is legal for an identifier */ + __forceinline bool isAlpha(unsigned int c) const { return c<256 && isAlphaMap[c]; } + __forceinline bool isAlphaNum(unsigned int c) const { return isAlpha(c) || isDigit(c); } + }; +} diff --git a/thirdparty/embree/common/math/affinespace.h b/thirdparty/embree/common/math/affinespace.h new file mode 100644 index 000000000000..32452fbe72b5 --- /dev/null +++ b/thirdparty/embree/common/math/affinespace.h @@ -0,0 +1,361 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "linearspace2.h" +#include "linearspace3.h" +#include "quaternion.h" +#include "bbox.h" +#include "vec4.h" + +namespace embree +{ + #define VectorT typename L::Vector + #define ScalarT typename L::Vector::Scalar + + //////////////////////////////////////////////////////////////////////////////// + // Affine Space + //////////////////////////////////////////////////////////////////////////////// + + template + struct AffineSpaceT + { + L l; /*< linear part of affine space */ + VectorT p; /*< affine part of affine space */ + + //////////////////////////////////////////////////////////////////////////////// + // Constructors, Assignment, Cast, Copy Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline AffineSpaceT ( ) { } + __forceinline AffineSpaceT ( const AffineSpaceT& other ) { l = other.l; p = other.p; } + __forceinline AffineSpaceT ( const L & other ) { l = other ; p = VectorT(zero); } + __forceinline AffineSpaceT& operator=( const AffineSpaceT& other ) { l = other.l; p = other.p; return *this; } + + __forceinline AffineSpaceT( const VectorT& vx, const VectorT& vy, const VectorT& vz, const VectorT& p ) : l(vx,vy,vz), p(p) {} + __forceinline AffineSpaceT( const L& l, const VectorT& p ) : l(l), p(p) {} + + template __forceinline 
AffineSpaceT( const AffineSpaceT& s ) : l(s.l), p(s.p) {} + + //////////////////////////////////////////////////////////////////////////////// + // Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline AffineSpaceT( ZeroTy ) : l(zero), p(zero) {} + __forceinline AffineSpaceT( OneTy ) : l(one), p(zero) {} + + /*! return matrix for scaling */ + static __forceinline AffineSpaceT scale(const VectorT& s) { return L::scale(s); } + + /*! return matrix for translation */ + static __forceinline AffineSpaceT translate(const VectorT& p) { return AffineSpaceT(one,p); } + + /*! return matrix for rotation, only in 2D */ + static __forceinline AffineSpaceT rotate(const ScalarT& r) { return L::rotate(r); } + + /*! return matrix for rotation around arbitrary point (2D) or axis (3D) */ + static __forceinline AffineSpaceT rotate(const VectorT& u, const ScalarT& r) { return L::rotate(u,r); } + + /*! return matrix for rotation around arbitrary axis and point, only in 3D */ + static __forceinline AffineSpaceT rotate(const VectorT& p, const VectorT& u, const ScalarT& r) { return translate(+p) * rotate(u,r) * translate(-p); } + + /*! return matrix for looking at given point, only in 3D */ + static __forceinline AffineSpaceT lookat(const VectorT& eye, const VectorT& point, const VectorT& up) { + VectorT Z = normalize(point-eye); + VectorT U = normalize(cross(up,Z)); + VectorT V = normalize(cross(Z,U)); + return AffineSpaceT(L(U,V,Z),eye); + } + + }; + + // template specialization to get correct identity matrix for type AffineSpace3fa + template<> + __forceinline AffineSpaceT::AffineSpaceT( OneTy ) : l(one), p(0.f, 0.f, 0.f, 1.f) {} + + //////////////////////////////////////////////////////////////////////////////// + // Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline AffineSpaceT operator -( const AffineSpaceT& a ) { return AffineSpaceT(-a.l,-a.p); } + template __forceinline AffineSpaceT operator +( const AffineSpaceT& a ) { return AffineSpaceT(+a.l,+a.p); } + template __forceinline AffineSpaceT rcp( const AffineSpaceT& a ) { L il = rcp(a.l); return AffineSpaceT(il,-(il*a.p)); } + + //////////////////////////////////////////////////////////////////////////////// + // Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline const AffineSpaceT operator +( const AffineSpaceT& a, const AffineSpaceT& b ) { return AffineSpaceT(a.l+b.l,a.p+b.p); } + template __forceinline const AffineSpaceT operator -( const AffineSpaceT& a, const AffineSpaceT& b ) { return AffineSpaceT(a.l-b.l,a.p-b.p); } + + template __forceinline const AffineSpaceT operator *( const ScalarT & a, const AffineSpaceT& b ) { return AffineSpaceT(a*b.l,a*b.p); } + template __forceinline const AffineSpaceT operator *( const AffineSpaceT& a, const AffineSpaceT& b ) { return AffineSpaceT(a.l*b.l,a.l*b.p+a.p); } + template __forceinline const AffineSpaceT operator /( const AffineSpaceT& a, const AffineSpaceT& b ) { return a * rcp(b); } + template __forceinline const AffineSpaceT operator /( const AffineSpaceT& a, const ScalarT & b ) { return a * rcp(b); } + + template __forceinline AffineSpaceT& operator *=( AffineSpaceT& a, const AffineSpaceT& b ) { return a = a * b; } + template __forceinline AffineSpaceT& operator *=( AffineSpaceT& a, const ScalarT & b ) { return a = a * b; } + template __forceinline AffineSpaceT& operator /=( AffineSpaceT& a, const 
AffineSpaceT& b ) { return a = a / b; } + template __forceinline AffineSpaceT& operator /=( AffineSpaceT& a, const ScalarT & b ) { return a = a / b; } + + template __forceinline VectorT xfmPoint (const AffineSpaceT& m, const VectorT& p) { return madd(VectorT(p.x),m.l.vx,madd(VectorT(p.y),m.l.vy,madd(VectorT(p.z),m.l.vz,m.p))); } + template __forceinline VectorT xfmVector(const AffineSpaceT& m, const VectorT& v) { return xfmVector(m.l,v); } + template __forceinline VectorT xfmNormal(const AffineSpaceT& m, const VectorT& n) { return xfmNormal(m.l,n); } + + __forceinline const BBox xfmBounds(const AffineSpaceT >& m, const BBox& b) + { + BBox3fa dst = empty; + const Vec3fa p0(b.lower.x,b.lower.y,b.lower.z); dst.extend(xfmPoint(m,p0)); + const Vec3fa p1(b.lower.x,b.lower.y,b.upper.z); dst.extend(xfmPoint(m,p1)); + const Vec3fa p2(b.lower.x,b.upper.y,b.lower.z); dst.extend(xfmPoint(m,p2)); + const Vec3fa p3(b.lower.x,b.upper.y,b.upper.z); dst.extend(xfmPoint(m,p3)); + const Vec3fa p4(b.upper.x,b.lower.y,b.lower.z); dst.extend(xfmPoint(m,p4)); + const Vec3fa p5(b.upper.x,b.lower.y,b.upper.z); dst.extend(xfmPoint(m,p5)); + const Vec3fa p6(b.upper.x,b.upper.y,b.lower.z); dst.extend(xfmPoint(m,p6)); + const Vec3fa p7(b.upper.x,b.upper.y,b.upper.z); dst.extend(xfmPoint(m,p7)); + return dst; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline bool operator ==( const AffineSpaceT& a, const AffineSpaceT& b ) { return a.l == b.l && a.p == b.p; } + template __forceinline bool operator !=( const AffineSpaceT& a, const AffineSpaceT& b ) { return a.l != b.l || a.p != b.p; } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline AffineSpaceT select ( const typename L::Vector::Scalar::Bool& s, const AffineSpaceT& t, const AffineSpaceT& f ) { + return AffineSpaceT(select(s,t.l,f.l),select(s,t.p,f.p)); + } + + //////////////////////////////////////////////////////////////////////////////// + // Output Operators + //////////////////////////////////////////////////////////////////////////////// + + template static embree_ostream operator<<(embree_ostream cout, const AffineSpaceT& m) { + return cout << "{ l = " << m.l << ", p = " << m.p << " }"; + } + + //////////////////////////////////////////////////////////////////////////////// + // Template Instantiations + //////////////////////////////////////////////////////////////////////////////// + + typedef AffineSpaceT AffineSpace2f; + typedef AffineSpaceT AffineSpace3f; + typedef AffineSpaceT AffineSpace3fa; + typedef AffineSpaceT AffineSpace3fx; + typedef AffineSpaceT AffineSpace3ff; + typedef AffineSpaceT OrthonormalSpace3f; + + template using AffineSpace3vf = AffineSpaceT>>>; + typedef AffineSpaceT>>> AffineSpace3vf4; + typedef AffineSpaceT>>> AffineSpace3vf8; + typedef AffineSpaceT>>> AffineSpace3vf16; + + template using AffineSpace3vff = AffineSpaceT>>>; + typedef AffineSpaceT>>> AffineSpace3vfa4; + typedef AffineSpaceT>>> AffineSpace3vfa8; + typedef AffineSpaceT>>> AffineSpace3vfa16; + + ////////////////////////////////////////////////////////////////////////////// + /// Interpolation + ////////////////////////////////////////////////////////////////////////////// + template + __forceinline AffineSpaceT lerp(const AffineSpaceT& M0, + const 
AffineSpaceT& M1, + const R& t) + { + return AffineSpaceT(lerp(M0.l,M1.l,t),lerp(M0.p,M1.p,t)); + } + + // slerp interprets the 16 floats of the matrix M = D * R * S as components of + // three matrizes (D, R, S) that are interpolated individually. + template __forceinline AffineSpaceT>> + slerp(const AffineSpaceT>>& M0, + const AffineSpaceT>>& M1, + const T& t) + { + QuaternionT q0(M0.p.w, M0.l.vx.w, M0.l.vy.w, M0.l.vz.w); + QuaternionT q1(M1.p.w, M1.l.vx.w, M1.l.vy.w, M1.l.vz.w); + QuaternionT q = slerp(q0, q1, t); + + AffineSpaceT>> S = lerp(M0, M1, t); + AffineSpaceT>> D(one); + D.p.x = S.l.vx.y; + D.p.y = S.l.vx.z; + D.p.z = S.l.vy.z; + S.l.vx.y = 0; + S.l.vx.z = 0; + S.l.vy.z = 0; + + AffineSpaceT>> R = LinearSpace3>(q); + return D * R * S; + } + + // this is a specialized version for Vec3fa because that does + // not play along nicely with the other templated Vec3/Vec4 types + __forceinline AffineSpace3fa slerp(const AffineSpace3ff& M0, + const AffineSpace3ff& M1, + const float& t) + { + Quaternion3f q0(M0.p.w, M0.l.vx.w, M0.l.vy.w, M0.l.vz.w); + Quaternion3f q1(M1.p.w, M1.l.vx.w, M1.l.vy.w, M1.l.vz.w); + Quaternion3f q = slerp(q0, q1, t); + + AffineSpace3fa S = lerp(M0, M1, t); + AffineSpace3fa D(one); + D.p.x = S.l.vx.y; + D.p.y = S.l.vx.z; + D.p.z = S.l.vy.z; + S.l.vx.y = 0; + S.l.vx.z = 0; + S.l.vy.z = 0; + + AffineSpace3fa R = LinearSpace3fa(q); + return D * R * S; + } + + __forceinline AffineSpace3fa quaternionDecompositionToAffineSpace(const AffineSpace3ff& qd) + { + // compute affine transform from quaternion decomposition + Quaternion3f q(qd.p.w, qd.l.vx.w, qd.l.vy.w, qd.l.vz.w); + AffineSpace3fa M = qd; + AffineSpace3fa D(one); + D.p.x = M.l.vx.y; + D.p.y = M.l.vx.z; + D.p.z = M.l.vy.z; + M.l.vx.y = 0; + M.l.vx.z = 0; + M.l.vy.z = 0; + AffineSpace3fa R = LinearSpace3fa(q); + return D * R * M; + } + + __forceinline void quaternionDecomposition(const AffineSpace3ff& qd, Vec3fa& T, Quaternion3f& q, AffineSpace3fa& S) + { + q = Quaternion3f(qd.p.w, qd.l.vx.w, qd.l.vy.w, qd.l.vz.w); + S = qd; + T.x = qd.l.vx.y; + T.y = qd.l.vx.z; + T.z = qd.l.vy.z; + S.l.vx.y = 0; + S.l.vx.z = 0; + S.l.vy.z = 0; + } + + __forceinline AffineSpace3fx quaternionDecomposition(Vec3fa const& T, Quaternion3f const& q, AffineSpace3fa const& S) + { + AffineSpace3ff M = S; + M.l.vx.w = q.i; + M.l.vy.w = q.j; + M.l.vz.w = q.k; + M.p.w = q.r; + M.l.vx.y = T.x; + M.l.vx.z = T.y; + M.l.vy.z = T.z; + return M; + } + + struct __aligned(16) QuaternionDecomposition + { + float scale_x = 1.f; + float scale_y = 1.f; + float scale_z = 1.f; + float skew_xy = 0.f; + float skew_xz = 0.f; + float skew_yz = 0.f; + float shift_x = 0.f; + float shift_y = 0.f; + float shift_z = 0.f; + float quaternion_r = 1.f; + float quaternion_i = 0.f; + float quaternion_j = 0.f; + float quaternion_k = 0.f; + float translation_x = 0.f; + float translation_y = 0.f; + float translation_z = 0.f; + }; + + __forceinline QuaternionDecomposition quaternionDecomposition(AffineSpace3ff const& M) + { + QuaternionDecomposition qd; + qd.scale_x = M.l.vx.x; + qd.scale_y = M.l.vy.y; + qd.scale_z = M.l.vz.z; + qd.shift_x = M.p.x; + qd.shift_y = M.p.y; + qd.shift_z = M.p.z; + qd.translation_x = M.l.vx.y; + qd.translation_y = M.l.vx.z; + qd.translation_z = M.l.vy.z; + qd.skew_xy = M.l.vy.x; + qd.skew_xz = M.l.vz.x; + qd.skew_yz = M.l.vz.y; + qd.quaternion_r = M.p.w; + qd.quaternion_i = M.l.vx.w; + qd.quaternion_j = M.l.vy.w; + qd.quaternion_k = M.l.vz.w; + return qd; + } + + //////////////////////////////////////////////////////////////////////////////// 
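// ---------------------------------------------------------------------------
// Illustrative usage sketch (editorial example, not part of the upstream
// sources): round trip of the quaternion encoding described above. The w
// components of the matrix columns carry the rotation quaternion and
// vx.y/vx.z/vy.z carry the translation; interpolatedTransform is a
// hypothetical helper built only from functions defined in this header.
// ---------------------------------------------------------------------------
static embree::AffineSpace3fa interpolatedTransform(const embree::Vec3fa& T0, const embree::Quaternion3f& q0,
                                                    const embree::Vec3fa& T1, const embree::Quaternion3f& q1,
                                                    float t)
{
  const embree::AffineSpace3fa S(embree::one);                 // no scale or skew in this sketch
  const auto M0 = embree::quaternionDecomposition(T0, q0, S);  // pack translation + quaternion
  const auto M1 = embree::quaternionDecomposition(T1, q1, S);
  // slerp() interpolates scale/skew/translation linearly and the rotation
  // spherically, then recomposes D * R * S into a plain affine transform
  // that can be used directly with xfmPoint()/xfmVector().
  return embree::slerp(M0, M1, t);
}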
+ /* + * ! Template Specialization for 2D: return matrix for rotation around point + * (rotation around arbitrarty vector is not meaningful in 2D) + */ + template<> __forceinline + AffineSpace2f AffineSpace2f::rotate(const Vec2f& p, const float& r) { + return translate(+p)*AffineSpace2f(LinearSpace2f::rotate(r))*translate(-p); + } + + //////////////////////////////////////////////////////////////////////////////// + // Similarity Transform + // + // checks, if M is a similarity transformation, i.e if there exists a factor D + // such that for all x,y: distance(Mx, My) = D * distance(x, y) + //////////////////////////////////////////////////////////////////////////////// + __forceinline bool similarityTransform(const AffineSpace3fa& M, float* D) + { + if (D) *D = 0.f; + if (abs(dot(M.l.vx, M.l.vy)) > 1e-5f) return false; + if (abs(dot(M.l.vx, M.l.vz)) > 1e-5f) return false; + if (abs(dot(M.l.vy, M.l.vz)) > 1e-5f) return false; + + const float D_x = dot(M.l.vx, M.l.vx); + const float D_y = dot(M.l.vy, M.l.vy); + const float D_z = dot(M.l.vz, M.l.vz); + + if (abs(D_x - D_y) > 1e-5f || + abs(D_x - D_z) > 1e-5f || + abs(D_y - D_z) > 1e-5f) + return false; + + if (D) *D = sqrtf(D_x); + return true; + } + + __forceinline void AffineSpace3fa_store_unaligned(const AffineSpace3fa &source, AffineSpace3fa* ptr) + { + Vec3fa::storeu(&ptr->l.vx, source.l.vx); + Vec3fa::storeu(&ptr->l.vy, source.l.vy); + Vec3fa::storeu(&ptr->l.vz, source.l.vz); + Vec3fa::storeu(&ptr->p, source.p); + } + + __forceinline AffineSpace3fa AffineSpace3fa_load_unaligned(AffineSpace3fa* ptr) + { + AffineSpace3fa space; + space.l.vx = Vec3fa::loadu(&ptr->l.vx); + space.l.vy = Vec3fa::loadu(&ptr->l.vy); + space.l.vz = Vec3fa::loadu(&ptr->l.vz); + space.p = Vec3fa::loadu(&ptr->p); + return space; + } + + #undef VectorT + #undef ScalarT +} diff --git a/thirdparty/embree/common/math/bbox.h b/thirdparty/embree/common/math/bbox.h new file mode 100644 index 000000000000..24d5b87223a2 --- /dev/null +++ b/thirdparty/embree/common/math/bbox.h @@ -0,0 +1,331 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "vec2.h" +#include "vec3.h" + +namespace embree +{ + namespace internal { + + template __forceinline T divideByTwo(const T& v) { return v / T(2); } + template <> __forceinline float divideByTwo(const float& v) { return v * 0.5f; } + template <> __forceinline double divideByTwo(const double& v) { return v * 0.5; } + + } // namespace internal + template + struct BBox + { + T lower, upper; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline BBox ( ) { } + template + __forceinline BBox ( const BBox& other ) : lower(other.lower), upper(other.upper) {} + __forceinline BBox& operator=( const BBox& other ) { lower = other.lower; upper = other.upper; return *this; } + + __forceinline BBox ( const T& v ) : lower(v), upper(v) {} + __forceinline BBox ( const T& lower, const T& upper ) : lower(lower), upper(upper) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Extending Bounds + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const BBox& extend(const BBox& other) { lower = min(lower,other.lower); upper = max(upper,other.upper); return *this; } + __forceinline const BBox& extend(const T & other) { lower = min(lower,other ); upper = 
max(upper,other ); return *this; } + + /*! tests if box is empty */ + __forceinline bool empty() const { for (int i=0; i upper[i]) return true; return false; } + + /*! computes the size of the box */ + __forceinline T size() const { return upper - lower; } + + /*! computes the center of the box */ + __forceinline T center() const { return internal::divideByTwo(lower+upper); } + + /*! computes twice the center of the box */ + __forceinline T center2() const { return lower+upper; } + + /*! merges two boxes */ + __forceinline static const BBox merge (const BBox& a, const BBox& b) { + return BBox(min(a.lower, b.lower), max(a.upper, b.upper)); + } + + /*! enlarge box by some scaling factor */ + __forceinline BBox enlarge_by(const float a) const { + return BBox(lower - T(a)*abs(lower), upper + T(a)*abs(upper)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline BBox( EmptyTy ) : lower(pos_inf), upper(neg_inf) {} + __forceinline BBox( FullTy ) : lower(neg_inf), upper(pos_inf) {} + __forceinline BBox( FalseTy ) : lower(pos_inf), upper(neg_inf) {} + __forceinline BBox( TrueTy ) : lower(neg_inf), upper(pos_inf) {} + __forceinline BBox( NegInfTy ): lower(pos_inf), upper(neg_inf) {} + __forceinline BBox( PosInfTy ): lower(neg_inf), upper(pos_inf) {} + }; + + template<> __forceinline bool BBox::empty() const { + return lower > upper; + } + +#if defined(__SSE__) + template<> __forceinline bool BBox::empty() const { + return !all(le_mask(lower,upper)); + } + template<> __forceinline bool BBox::empty() const { + return !all(le_mask(lower,upper)); + } +#endif + + /*! tests if box is finite */ + __forceinline bool isvalid( const BBox& v ) { + return all(gt_mask(v.lower,Vec3fa_t(-FLT_LARGE)) & lt_mask(v.upper,Vec3fa_t(+FLT_LARGE))); + } + + /*! tests if box is finite and non-empty*/ + __forceinline bool isvalid_non_empty( const BBox& v ) { + return all(gt_mask(v.lower,Vec3fa_t(-FLT_LARGE)) & lt_mask(v.upper,Vec3fa_t(+FLT_LARGE)) & le_mask(v.lower,v.upper)); + } + + /*! tests if box has finite entries */ + __forceinline bool is_finite( const BBox& b) { + return is_finite(b.lower) && is_finite(b.upper); + } + + /*! test if point contained in box */ + __forceinline bool inside ( const BBox& b, const Vec3fa& p ) { return all(ge_mask(p,b.lower) & le_mask(p,b.upper)); } + + /*! computes the center of the box */ + template __forceinline const T center2(const BBox& box) { return box.lower + box.upper; } + template __forceinline const T center (const BBox& box) { return internal::divideByTwo(center2(box)); } + + /*! computes the volume of a bounding box */ + __forceinline float volume ( const BBox& b ) { return reduce_mul(b.size()); } + __forceinline float safeVolume( const BBox& b ) { if (b.empty()) return 0.0f; else return volume(b); } + + /*! computes the volume of a bounding box */ + __forceinline float volume( const BBox& b ) { return reduce_mul(b.size()); } + + /*! 
computes the surface area of a bounding box */ + template __forceinline const T area( const BBox >& b ) { const Vec2 d = b.size(); return d.x*d.y; } + + template __forceinline const T halfArea( const BBox >& b ) { return halfArea(b.size()); } + template __forceinline const T area( const BBox >& b ) { return T(2)*halfArea(b); } + + __forceinline float halfArea( const BBox& b ) { return halfArea(b.size()); } + __forceinline float area( const BBox& b ) { return 2.0f*halfArea(b); } + + __forceinline float halfArea( const BBox& b ) { return halfArea(b.size()); } + __forceinline float area( const BBox& b ) { return 2.0f*halfArea(b); } + + template __forceinline float safeArea( const BBox& b ) { if (b.empty()) return 0.0f; else return area(b); } + + template __forceinline float expectedApproxHalfArea(const BBox& box) { + return halfArea(box); + } + + /*! merges bounding boxes and points */ + template __forceinline const BBox merge( const BBox& a, const T& b ) { return BBox(min(a.lower, b ), max(a.upper, b )); } + template __forceinline const BBox merge( const T& a, const BBox& b ) { return BBox(min(a , b.lower), max(a , b.upper)); } + template __forceinline const BBox merge( const BBox& a, const BBox& b ) { return BBox(min(a.lower, b.lower), max(a.upper, b.upper)); } + + /*! Merges three boxes. */ + template __forceinline const BBox merge( const BBox& a, const BBox& b, const BBox& c ) { return merge(a,merge(b,c)); } + + /*! Merges four boxes. */ + template __forceinline BBox merge(const BBox& a, const BBox& b, const BBox& c, const BBox& d) { + return merge(merge(a,b),merge(c,d)); + } + + /*! Comparison Operators */ + template __forceinline bool operator==( const BBox& a, const BBox& b ) { return a.lower == b.lower && a.upper == b.upper; } + template __forceinline bool operator!=( const BBox& a, const BBox& b ) { return a.lower != b.lower || a.upper != b.upper; } + + /*! scaling */ + template __forceinline BBox operator *( const float& a, const BBox& b ) { return BBox(a*b.lower,a*b.upper); } + template __forceinline BBox operator *( const T& a, const BBox& b ) { return BBox(a*b.lower,a*b.upper); } + + /*! translations */ + template __forceinline BBox operator +( const BBox& a, const BBox& b ) { return BBox(a.lower+b.lower,a.upper+b.upper); } + template __forceinline BBox operator -( const BBox& a, const BBox& b ) { return BBox(a.lower-b.lower,a.upper-b.upper); } + template __forceinline BBox operator +( const BBox& a, const T & b ) { return BBox(a.lower+b ,a.upper+b ); } + template __forceinline BBox operator -( const BBox& a, const T & b ) { return BBox(a.lower-b ,a.upper-b ); } + + /*! extension */ + template __forceinline BBox enlarge(const BBox& a, const T& b) { return BBox(a.lower-b, a.upper+b); } + + /*! intersect bounding boxes */ + template __forceinline const BBox intersect( const BBox& a, const BBox& b ) { return BBox(max(a.lower, b.lower), min(a.upper, b.upper)); } + template __forceinline const BBox intersect( const BBox& a, const BBox& b, const BBox& c ) { return intersect(a,intersect(b,c)); } + template __forceinline const BBox intersect( const BBox& a, const BBox& b, const BBox& c, const BBox& d ) { return intersect(intersect(a,b),intersect(c,d)); } + + /*! subtract bounds from each other */ + template __forceinline void subtract(const BBox& a, const BBox& b, BBox& c, BBox& d) + { + c.lower = a.lower; + c.upper = min(a.upper,b.lower); + d.lower = max(a.lower,b.upper); + d.upper = a.upper; + } + + /*! 
tests if bounding boxes (and points) are disjoint (empty intersection) */ + template __inline bool disjoint( const BBox& a, const BBox& b ) { return intersect(a,b).empty(); } + template __inline bool disjoint( const BBox& a, const T& b ) { return disjoint(a,BBox(b)); } + template __inline bool disjoint( const T& a, const BBox& b ) { return disjoint(BBox(a),b); } + + /*! tests if bounding boxes (and points) are conjoint (non-empty intersection) */ + template __inline bool conjoint( const BBox& a, const BBox& b ) { return !intersect(a,b).empty(); } + template __inline bool conjoint( const BBox& a, const T& b ) { return conjoint(a,BBox(b)); } + template __inline bool conjoint( const T& a, const BBox& b ) { return conjoint(BBox(a),b); } + + /*! subset relation */ + template __inline bool subset( const BBox& a, const BBox& b ) + { + for ( size_t i = 0; i < T::N; i++ ) if ( a.lower[i] < b.lower[i] ) return false; + for ( size_t i = 0; i < T::N; i++ ) if ( a.upper[i] > b.upper[i] ) return false; + return true; + } + + template<> __inline bool subset( const BBox& a, const BBox& b ) { + return all(ge_mask(a.lower,b.lower)) & all(le_mask(a.upper,b.upper)); + } + + template<> __inline bool subset( const BBox& a, const BBox& b ) { + return all(ge_mask(a.lower,b.lower)) & all(le_mask(a.upper,b.upper)); + } + + /*! blending */ + template + __forceinline BBox lerp(const BBox& b0, const BBox& b1, const float t) { + return BBox(lerp(b0.lower,b1.lower,t),lerp(b0.upper,b1.upper,t)); + } + + /*! output operator */ + template __forceinline embree_ostream operator<<(embree_ostream cout, const BBox& box) { + return cout << "[" << box.lower << "; " << box.upper << "]"; + } + + /*! default template instantiations */ + typedef BBox BBox1f; + typedef BBox BBox2f; + typedef BBox BBox2fa; + typedef BBox BBox3f; + typedef BBox BBox3fa; + typedef BBox BBox3fx; + typedef BBox BBox3ff; +} + +//////////////////////////////////////////////////////////////////////////////// +/// SSE / AVX / MIC specializations +//////////////////////////////////////////////////////////////////////////////// + +#if defined __SSE__ +#include "../simd/sse.h" +#endif + +#if defined __AVX__ +#include "../simd/avx.h" +#endif + +#if defined(__AVX512F__) +#include "../simd/avx512.h" +#endif + +namespace embree +{ + template + __forceinline BBox>> transpose(const BBox3fa* bounds); + + template<> + __forceinline BBox> transpose<4>(const BBox3fa* bounds) + { + BBox> dest; + + transpose((vfloat4&)bounds[0].lower, + (vfloat4&)bounds[1].lower, + (vfloat4&)bounds[2].lower, + (vfloat4&)bounds[3].lower, + dest.lower.x, + dest.lower.y, + dest.lower.z); + + transpose((vfloat4&)bounds[0].upper, + (vfloat4&)bounds[1].upper, + (vfloat4&)bounds[2].upper, + (vfloat4&)bounds[3].upper, + dest.upper.x, + dest.upper.y, + dest.upper.z); + + return dest; + } + +#if defined(__AVX__) + template<> + __forceinline BBox> transpose<8>(const BBox3fa* bounds) + { + BBox> dest; + + transpose((vfloat4&)bounds[0].lower, + (vfloat4&)bounds[1].lower, + (vfloat4&)bounds[2].lower, + (vfloat4&)bounds[3].lower, + (vfloat4&)bounds[4].lower, + (vfloat4&)bounds[5].lower, + (vfloat4&)bounds[6].lower, + (vfloat4&)bounds[7].lower, + dest.lower.x, + dest.lower.y, + dest.lower.z); + + transpose((vfloat4&)bounds[0].upper, + (vfloat4&)bounds[1].upper, + (vfloat4&)bounds[2].upper, + (vfloat4&)bounds[3].upper, + (vfloat4&)bounds[4].upper, + (vfloat4&)bounds[5].upper, + (vfloat4&)bounds[6].upper, + (vfloat4&)bounds[7].upper, + dest.upper.x, + dest.upper.y, + dest.upper.z); + + return dest; + } 
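// ---------------------------------------------------------------------------
// Illustrative usage sketch (editorial example, not part of the upstream
// sources): basic use of the BBox utilities defined above. The helper name
// boundAndClip is hypothetical.
// ---------------------------------------------------------------------------
static embree::BBox3fa boundAndClip(const embree::Vec3fa* points, size_t count,
                                    const embree::BBox3fa& clip)
{
  embree::BBox3fa bounds(embree::empty);      // starts as lower=+inf, upper=-inf
  for (size_t i = 0; i < count; i++)
    bounds.extend(points[i]);                 // grow the box point by point
  const embree::BBox3fa clipped = embree::intersect(bounds, clip);
  return clipped.empty() ? embree::BBox3fa(embree::empty) : clipped;  // normalize an empty result
}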
+#endif + + template + __forceinline BBox3fa merge(const BBox3fa* bounds); + + template<> + __forceinline BBox3fa merge<4>(const BBox3fa* bounds) + { + const Vec3fa lower = min(min(bounds[0].lower,bounds[1].lower), + min(bounds[2].lower,bounds[3].lower)); + const Vec3fa upper = max(max(bounds[0].upper,bounds[1].upper), + max(bounds[2].upper,bounds[3].upper)); + return BBox3fa(lower,upper); + } + +#if defined(__AVX__) + template<> + __forceinline BBox3fa merge<8>(const BBox3fa* bounds) + { + const Vec3fa lower = min(min(min(bounds[0].lower,bounds[1].lower),min(bounds[2].lower,bounds[3].lower)), + min(min(bounds[4].lower,bounds[5].lower),min(bounds[6].lower,bounds[7].lower))); + const Vec3fa upper = max(max(max(bounds[0].upper,bounds[1].upper),max(bounds[2].upper,bounds[3].upper)), + max(max(bounds[4].upper,bounds[5].upper),max(bounds[6].upper,bounds[7].upper))); + return BBox3fa(lower,upper); + } +#endif +} + diff --git a/thirdparty/embree/common/math/col3.h b/thirdparty/embree/common/math/col3.h new file mode 100644 index 000000000000..2a477ec1319b --- /dev/null +++ b/thirdparty/embree/common/math/col3.h @@ -0,0 +1,47 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "math.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// RGB Color Class + //////////////////////////////////////////////////////////////////////////////// + + template struct Col3 + { + T r, g, b; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Col3 ( ) { } + __forceinline Col3 ( const Col3& other ) { r = other.r; g = other.g; b = other.b; } + __forceinline Col3& operator=( const Col3& other ) { r = other.r; g = other.g; b = other.b; return *this; } + + __forceinline explicit Col3 (const T& v) : r(v), g(v), b(v) {} + __forceinline Col3 (const T& r, const T& g, const T& b) : r(r), g(g), b(b) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Col3 (ZeroTy) : r(zero) , g(zero) , b(zero) {} + __forceinline Col3 (OneTy) : r(one) , g(one) , b(one) {} + __forceinline Col3 (PosInfTy) : r(pos_inf), g(pos_inf), b(pos_inf) {} + __forceinline Col3 (NegInfTy) : r(neg_inf), g(neg_inf), b(neg_inf) {} + }; + + /*! output operator */ + template __forceinline embree_ostream operator<<(embree_ostream cout, const Col3& a) { + return cout << "(" << a.r << ", " << a.g << ", " << a.b << ")"; + } + + /*! 
default template instantiations */ + typedef Col3 Col3uc; + typedef Col3 Col3f; +} diff --git a/thirdparty/embree/common/math/col4.h b/thirdparty/embree/common/math/col4.h new file mode 100644 index 000000000000..27849840ec5b --- /dev/null +++ b/thirdparty/embree/common/math/col4.h @@ -0,0 +1,47 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "math.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// RGBA Color Class + //////////////////////////////////////////////////////////////////////////////// + + template struct Col4 + { + T r, g, b, a; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Col4 ( ) { } + __forceinline Col4 ( const Col4& other ) { r = other.r; g = other.g; b = other.b; a = other.a; } + __forceinline Col4& operator=( const Col4& other ) { r = other.r; g = other.g; b = other.b; a = other.a; return *this; } + + __forceinline explicit Col4 (const T& v) : r(v), g(v), b(v), a(v) {} + __forceinline Col4 (const T& r, const T& g, const T& b, const T& a) : r(r), g(g), b(b), a(a) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Col4 (ZeroTy) : r(zero) , g(zero) , b(zero) , a(zero) {} + __forceinline Col4 (OneTy) : r(one) , g(one) , b(one) , a(one) {} + __forceinline Col4 (PosInfTy) : r(pos_inf), g(pos_inf), b(pos_inf), a(pos_inf) {} + __forceinline Col4 (NegInfTy) : r(neg_inf), g(neg_inf), b(neg_inf), a(neg_inf) {} + }; + + /*! output operator */ + template __forceinline embree_ostream operator<<(embree_ostream cout, const Col4& a) { + return cout << "(" << a.r << ", " << a.g << ", " << a.b << ", " << a.a << ")"; + } + + /*! 
default template instantiations */ + typedef Col4 Col4uc; + typedef Col4 Col4f; +} diff --git a/thirdparty/embree/common/math/color.h b/thirdparty/embree/common/math/color.h new file mode 100644 index 000000000000..eae7b72ecfe0 --- /dev/null +++ b/thirdparty/embree/common/math/color.h @@ -0,0 +1,241 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "constants.h" +#include "col3.h" +#include "col4.h" + +#include "../simd/sse.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// SSE RGBA Color Class + //////////////////////////////////////////////////////////////////////////////// + + struct Color4 + { + union { + __m128 m128; + struct { float r,g,b,a; }; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Color4 () {} + __forceinline Color4 ( const __m128 a ) : m128(a) {} + + __forceinline explicit Color4 (const float v) : m128(_mm_set1_ps(v)) {} + __forceinline Color4 (const float r, const float g, const float b, const float a) : m128(_mm_set_ps(a,b,g,r)) {} + + __forceinline explicit Color4 ( const Col3uc& other ) { m128 = _mm_mul_ps(_mm_set_ps(255.0f,other.b,other.g,other.r),_mm_set1_ps(one_over_255)); } + __forceinline explicit Color4 ( const Col3f& other ) { m128 = _mm_set_ps(1.0f,other.b,other.g,other.r); } + __forceinline explicit Color4 ( const Col4uc& other ) { m128 = _mm_mul_ps(_mm_set_ps(other.a,other.b,other.g,other.r),_mm_set1_ps(one_over_255)); } + __forceinline explicit Color4 ( const Col4f& other ) { m128 = _mm_set_ps(other.a,other.b,other.g,other.r); } + + __forceinline Color4 ( const Color4& other ) : m128(other.m128) {} + __forceinline Color4& operator=( const Color4& other ) { m128 = other.m128; return *this; } + + __forceinline operator const __m128&() const { return m128; } + __forceinline operator __m128&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Set + //////////////////////////////////////////////////////////////////////////////// + + __forceinline void set(Col3f& d) const { d.r = r; d.g = g; d.b = b; } + __forceinline void set(Col4f& d) const { d.r = r; d.g = g; d.b = b; d.a = a; } + __forceinline void set(Col3uc& d) const + { + vfloat4 s = clamp(vfloat4(m128))*255.0f; + d.r = (unsigned char)(s[0]); + d.g = (unsigned char)(s[1]); + d.b = (unsigned char)(s[2]); + } + __forceinline void set(Col4uc& d) const + { + vfloat4 s = clamp(vfloat4(m128))*255.0f; + d.r = (unsigned char)(s[0]); + d.g = (unsigned char)(s[1]); + d.b = (unsigned char)(s[2]); + d.a = (unsigned char)(s[3]); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Color4( ZeroTy ) : m128(_mm_set1_ps(0.0f)) {} + __forceinline Color4( OneTy ) : m128(_mm_set1_ps(1.0f)) {} + __forceinline Color4( PosInfTy ) : m128(_mm_set1_ps(pos_inf)) {} + __forceinline Color4( NegInfTy ) : m128(_mm_set1_ps(neg_inf)) {} + }; + + //////////////////////////////////////////////////////////////////////////////// + /// SSE RGB Color Class + //////////////////////////////////////////////////////////////////////////////// + + struct Color + { + union { + __m128 m128; + struct { float r,g,b; }; + }; + + 
//////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Color () {} + __forceinline Color ( const __m128 a ) : m128(a) {} + + __forceinline explicit Color (const float v) : m128(_mm_set1_ps(v)) {} + __forceinline Color (const float r, const float g, const float b) : m128(_mm_set_ps(0.0f,b,g,r)) {} + + __forceinline Color ( const Color& other ) : m128(other.m128) {} + __forceinline Color& operator=( const Color& other ) { m128 = other.m128; return *this; } + + __forceinline Color ( const Color4& other ) : m128(other.m128) {} + __forceinline Color& operator=( const Color4& other ) { m128 = other.m128; return *this; } + + __forceinline operator const __m128&() const { return m128; } + __forceinline operator __m128&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Set + //////////////////////////////////////////////////////////////////////////////// + + __forceinline void set(Col3f& d) const { d.r = r; d.g = g; d.b = b; } + __forceinline void set(Col4f& d) const { d.r = r; d.g = g; d.b = b; d.a = 1.0f; } + __forceinline void set(Col3uc& d) const + { + vfloat4 s = clamp(vfloat4(m128))*255.0f; + d.r = (unsigned char)(s[0]); + d.g = (unsigned char)(s[1]); + d.b = (unsigned char)(s[2]); + } + __forceinline void set(Col4uc& d) const + { + vfloat4 s = clamp(vfloat4(m128))*255.0f; + d.r = (unsigned char)(s[0]); + d.g = (unsigned char)(s[1]); + d.b = (unsigned char)(s[2]); + d.a = 255; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Color( ZeroTy ) : m128(_mm_set1_ps(0.0f)) {} + __forceinline Color( OneTy ) : m128(_mm_set1_ps(1.0f)) {} + __forceinline Color( PosInfTy ) : m128(_mm_set1_ps(pos_inf)) {} + __forceinline Color( NegInfTy ) : m128(_mm_set1_ps(neg_inf)) {} + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const Color operator +( const Color& a ) { return a; } + __forceinline const Color operator -( const Color& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000)); + return _mm_xor_ps(a.m128, mask); + } + __forceinline const Color abs ( const Color& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)); + return _mm_and_ps(a.m128, mask); + } + __forceinline const Color rcp ( const Color& a ) + { +#if defined(__AVX512VL__) + const Color r = _mm_rcp14_ps(a.m128); +#else + const Color r = _mm_rcp_ps(a.m128); +#endif + return _mm_sub_ps(_mm_add_ps(r, r), _mm_mul_ps(_mm_mul_ps(r, r), a)); + } + __forceinline const Color rsqrt( const Color& a ) + { +#if defined(__AVX512VL__) + __m128 r = _mm_rsqrt14_ps(a.m128); +#else + __m128 r = _mm_rsqrt_ps(a.m128); +#endif + return _mm_add_ps(_mm_mul_ps(_mm_set1_ps(1.5f),r), _mm_mul_ps(_mm_mul_ps(_mm_mul_ps(a, _mm_set1_ps(-0.5f)), r), _mm_mul_ps(r, r))); + } + __forceinline const Color sqrt ( const Color& a ) { return _mm_sqrt_ps(a.m128); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const Color operator +( const Color& a, const Color& b ) { 
return _mm_add_ps(a.m128, b.m128); } + __forceinline const Color operator -( const Color& a, const Color& b ) { return _mm_sub_ps(a.m128, b.m128); } + __forceinline const Color operator *( const Color& a, const Color& b ) { return _mm_mul_ps(a.m128, b.m128); } + __forceinline const Color operator *( const Color& a, const float b ) { return a * Color(b); } + __forceinline const Color operator *( const float a, const Color& b ) { return Color(a) * b; } + __forceinline const Color operator /( const Color& a, const Color& b ) { return a * rcp(b); } + __forceinline const Color operator /( const Color& a, const float b ) { return a * rcp(b); } + + __forceinline const Color min( const Color& a, const Color& b ) { return _mm_min_ps(a.m128,b.m128); } + __forceinline const Color max( const Color& a, const Color& b ) { return _mm_max_ps(a.m128,b.m128); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const Color operator+=(Color& a, const Color& b) { return a = a + b; } + __forceinline const Color operator-=(Color& a, const Color& b) { return a = a - b; } + __forceinline const Color operator*=(Color& a, const Color& b) { return a = a * b; } + __forceinline const Color operator/=(Color& a, const Color& b) { return a = a / b; } + __forceinline const Color operator*=(Color& a, const float b ) { return a = a * b; } + __forceinline const Color operator/=(Color& a, const float b ) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline float reduce_add(const Color& v) { return v.r+v.g+v.b; } + __forceinline float reduce_mul(const Color& v) { return v.r*v.g*v.b; } + __forceinline float reduce_min(const Color& v) { return min(v.r,v.g,v.b); } + __forceinline float reduce_max(const Color& v) { return max(v.r,v.g,v.b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator ==( const Color& a, const Color& b ) { return (_mm_movemask_ps(_mm_cmpeq_ps (a.m128, b.m128)) & 7) == 7; } + __forceinline bool operator !=( const Color& a, const Color& b ) { return (_mm_movemask_ps(_mm_cmpneq_ps(a.m128, b.m128)) & 7) != 0; } + __forceinline bool operator < ( const Color& a, const Color& b ) { + if (a.r != b.r) return a.r < b.r; + if (a.g != b.g) return a.g < b.g; + if (a.b != b.b) return a.b < b.b; + return false; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const Color select( bool s, const Color& t, const Color& f ) { + __m128 mask = s ? _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128())) : _mm_setzero_ps(); + return blendv_ps(f, t, mask); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Special Operators + //////////////////////////////////////////////////////////////////////////////// + + /*! computes luminance of a color */ + __forceinline float luminance (const Color& a) { return madd(0.212671f,a.r,madd(0.715160f,a.g,0.072169f*a.b)); } + + /*! 
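(illustrative note, not part of the upstream Embree sources) */

// The operators above let per-channel color math read like scalar code; note
// that rcp() refines the hardware reciprocal estimate with one Newton-Raphson
// step (r' = r*(2 - a*r)), so operator/ is fast but not bit-exact.
// `reinhard_tonemap` is a hypothetical helper shown only as a usage sketch.
__forceinline Color reinhard_tonemap(const Color& c) {
  // simple Reinhard curve applied per channel: c / (1 + c)
  return c / (Color(1.0f) + c);
}

/*!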
output operator */ + __forceinline embree_ostream operator<<(embree_ostream cout, const Color& a) { + return cout << "(" << a.r << ", " << a.g << ", " << a.b << ")"; + } +} diff --git a/thirdparty/embree/common/math/constants.cpp b/thirdparty/embree/common/math/constants.cpp new file mode 100644 index 000000000000..26968297d951 --- /dev/null +++ b/thirdparty/embree/common/math/constants.cpp @@ -0,0 +1,27 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "constants.h" + +namespace embree +{ + TrueTy True; + FalseTy False; + ZeroTy zero; + OneTy one; + NegInfTy neg_inf; + PosInfTy inf; + PosInfTy pos_inf; + NaNTy nan; + UlpTy ulp; + PiTy pi; + OneOverPiTy one_over_pi; + TwoPiTy two_pi; + OneOverTwoPiTy one_over_two_pi; + FourPiTy four_pi; + OneOverFourPiTy one_over_four_pi; + StepTy step; + ReverseStepTy reverse_step; + EmptyTy empty; + UndefinedTy undefined; +} diff --git a/thirdparty/embree/common/math/constants.h b/thirdparty/embree/common/math/constants.h new file mode 100644 index 000000000000..77c2b7aec2bf --- /dev/null +++ b/thirdparty/embree/common/math/constants.h @@ -0,0 +1,197 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" + +#include + +#define _USE_MATH_DEFINES +#include // using cmath causes issues under Windows +#include +#include + +namespace embree +{ + static MAYBE_UNUSED const float one_over_255 = 1.0f/255.0f; + static MAYBE_UNUSED const float min_rcp_input = 1E-18f; // for abs(x) >= min_rcp_input the newton raphson rcp calculation does not fail + + /* we consider floating point numbers in that range as valid input numbers */ + static MAYBE_UNUSED float FLT_LARGE = 1.844E18f; + + struct TrueTy { + __forceinline operator bool( ) const { return true; } + }; + + extern MAYBE_UNUSED TrueTy True; + + struct FalseTy { + __forceinline operator bool( ) const { return false; } + }; + + extern MAYBE_UNUSED FalseTy False; + + struct ZeroTy + { + __forceinline operator double ( ) const { return 0; } + __forceinline operator float ( ) const { return 0; } + __forceinline operator long long( ) const { return 0; } + __forceinline operator unsigned long long( ) const { return 0; } + __forceinline operator long ( ) const { return 0; } + __forceinline operator unsigned long ( ) const { return 0; } + __forceinline operator int ( ) const { return 0; } + __forceinline operator unsigned int ( ) const { return 0; } + __forceinline operator short ( ) const { return 0; } + __forceinline operator unsigned short ( ) const { return 0; } + __forceinline operator char ( ) const { return 0; } + __forceinline operator unsigned char ( ) const { return 0; } + }; + + extern MAYBE_UNUSED ZeroTy zero; + + struct OneTy + { + __forceinline operator double ( ) const { return 1; } + __forceinline operator float ( ) const { return 1; } + __forceinline operator long long( ) const { return 1; } + __forceinline operator unsigned long long( ) const { return 1; } + __forceinline operator long ( ) const { return 1; } + __forceinline operator unsigned long ( ) const { return 1; } + __forceinline operator int ( ) const { return 1; } + __forceinline operator unsigned int ( ) const { return 1; } + __forceinline operator short ( ) const { return 1; } + __forceinline operator unsigned short ( ) const { return 1; } + __forceinline operator char ( ) const { return 1; } + __forceinline operator unsigned char ( ) const { return 1; } + }; + + extern MAYBE_UNUSED OneTy one; + + struct NegInfTy + { + 
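    // Like TrueTy/FalseTy/ZeroTy/OneTy above, this is a tag type: the global
    // `neg_inf` converts implicitly to the most negative representable value of
    // whatever arithmetic type it is assigned to (negative infinity for the
    // floating-point types, numeric_limits<T>::min() for the integer types).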
__forceinline operator double ( ) const { return -std::numeric_limits::infinity(); } + __forceinline operator float ( ) const { return -std::numeric_limits::infinity(); } + __forceinline operator long long( ) const { return std::numeric_limits::min(); } + __forceinline operator unsigned long long( ) const { return std::numeric_limits::min(); } + __forceinline operator long ( ) const { return std::numeric_limits::min(); } + __forceinline operator unsigned long ( ) const { return std::numeric_limits::min(); } + __forceinline operator int ( ) const { return std::numeric_limits::min(); } + __forceinline operator unsigned int ( ) const { return std::numeric_limits::min(); } + __forceinline operator short ( ) const { return std::numeric_limits::min(); } + __forceinline operator unsigned short ( ) const { return std::numeric_limits::min(); } + __forceinline operator char ( ) const { return std::numeric_limits::min(); } + __forceinline operator unsigned char ( ) const { return std::numeric_limits::min(); } + + }; + + extern MAYBE_UNUSED NegInfTy neg_inf; + + struct PosInfTy + { + __forceinline operator double ( ) const { return std::numeric_limits::infinity(); } + __forceinline operator float ( ) const { return std::numeric_limits::infinity(); } + __forceinline operator long long( ) const { return std::numeric_limits::max(); } + __forceinline operator unsigned long long( ) const { return std::numeric_limits::max(); } + __forceinline operator long ( ) const { return std::numeric_limits::max(); } + __forceinline operator unsigned long ( ) const { return std::numeric_limits::max(); } + __forceinline operator int ( ) const { return std::numeric_limits::max(); } + __forceinline operator unsigned int ( ) const { return std::numeric_limits::max(); } + __forceinline operator short ( ) const { return std::numeric_limits::max(); } + __forceinline operator unsigned short ( ) const { return std::numeric_limits::max(); } + __forceinline operator char ( ) const { return std::numeric_limits::max(); } + __forceinline operator unsigned char ( ) const { return std::numeric_limits::max(); } + }; + + extern MAYBE_UNUSED PosInfTy inf; + extern MAYBE_UNUSED PosInfTy pos_inf; + + struct NaNTy + { + __forceinline operator double( ) const { return std::numeric_limits::quiet_NaN(); } + __forceinline operator float ( ) const { return std::numeric_limits::quiet_NaN(); } + }; + + extern MAYBE_UNUSED NaNTy nan; + + struct UlpTy + { + __forceinline operator double( ) const { return std::numeric_limits::epsilon(); } + __forceinline operator float ( ) const { return std::numeric_limits::epsilon(); } + }; + + extern MAYBE_UNUSED UlpTy ulp; + + struct PiTy + { + __forceinline operator double( ) const { return double(M_PI); } + __forceinline operator float ( ) const { return float(M_PI); } + }; + + extern MAYBE_UNUSED PiTy pi; + + struct OneOverPiTy + { + __forceinline operator double( ) const { return double(M_1_PI); } + __forceinline operator float ( ) const { return float(M_1_PI); } + }; + + extern MAYBE_UNUSED OneOverPiTy one_over_pi; + + struct TwoPiTy + { + __forceinline operator double( ) const { return double(2.0*M_PI); } + __forceinline operator float ( ) const { return float(2.0*M_PI); } + }; + + extern MAYBE_UNUSED TwoPiTy two_pi; + + struct OneOverTwoPiTy + { + __forceinline operator double( ) const { return double(0.5*M_1_PI); } + __forceinline operator float ( ) const { return float(0.5*M_1_PI); } + }; + + extern MAYBE_UNUSED OneOverTwoPiTy one_over_two_pi; + + struct FourPiTy + { + __forceinline operator double( ) 
const { return double(4.0*M_PI); } + __forceinline operator float ( ) const { return float(4.0*M_PI); } + }; + + extern MAYBE_UNUSED FourPiTy four_pi; + + struct OneOverFourPiTy + { + __forceinline operator double( ) const { return double(0.25*M_1_PI); } + __forceinline operator float ( ) const { return float(0.25*M_1_PI); } + }; + + extern MAYBE_UNUSED OneOverFourPiTy one_over_four_pi; + + struct StepTy { + }; + + extern MAYBE_UNUSED StepTy step; + + struct ReverseStepTy { + }; + + extern MAYBE_UNUSED ReverseStepTy reverse_step; + + struct EmptyTy { + }; + + extern MAYBE_UNUSED EmptyTy empty; + + struct FullTy { + }; + + extern MAYBE_UNUSED FullTy full; + + struct UndefinedTy { + }; + + extern MAYBE_UNUSED UndefinedTy undefined; +} diff --git a/thirdparty/embree/common/math/interval.h b/thirdparty/embree/common/math/interval.h new file mode 100644 index 000000000000..f06478e881fc --- /dev/null +++ b/thirdparty/embree/common/math/interval.h @@ -0,0 +1,161 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "vec2.h" +#include "vec3.h" +#include "bbox.h" + +namespace embree +{ + template + struct Interval + { + V lower, upper; + + __forceinline Interval() {} + __forceinline Interval ( const Interval& other ) { lower = other.lower; upper = other.upper; } + __forceinline Interval& operator=( const Interval& other ) { lower = other.lower; upper = other.upper; return *this; } + + __forceinline Interval(const V& a) : lower(a), upper(a) {} + __forceinline Interval(const V& lower, const V& upper) : lower(lower), upper(upper) {} + __forceinline Interval(const BBox& a) : lower(a.lower), upper(a.upper) {} + + /*! tests if box is empty */ + //__forceinline bool empty() const { return lower > upper; } + + /*! computes the size of the interval */ + __forceinline V size() const { return upper - lower; } + + __forceinline V center() const { return 0.5f*(lower+upper); } + + __forceinline const Interval& extend(const Interval& other) { lower = min(lower,other.lower); upper = max(upper,other.upper); return *this; } + __forceinline const Interval& extend(const V & other) { lower = min(lower,other ); upper = max(upper,other ); return *this; } + + __forceinline friend Interval operator +( const Interval& a, const Interval& b ) { + return Interval(a.lower+b.lower,a.upper+b.upper); + } + + __forceinline friend Interval operator -( const Interval& a, const Interval& b ) { + return Interval(a.lower-b.upper,a.upper-b.lower); + } + + __forceinline friend Interval operator -( const Interval& a, const V& b ) { + return Interval(a.lower-b,a.upper-b); + } + + __forceinline friend Interval operator *( const Interval& a, const Interval& b ) + { + const V ll = a.lower*b.lower; + const V lu = a.lower*b.upper; + const V ul = a.upper*b.lower; + const V uu = a.upper*b.upper; + return Interval(min(ll,lu,ul,uu),max(ll,lu,ul,uu)); + } + + __forceinline friend Interval merge( const Interval& a, const Interval& b) { + return Interval(min(a.lower,b.lower),max(a.upper,b.upper)); + } + + __forceinline friend Interval merge( const Interval& a, const Interval& b, const Interval& c) { + return merge(merge(a,b),c); + } + + __forceinline friend Interval merge( const Interval& a, const Interval& b, const Interval& c, const Interval& d) { + return merge(merge(a,b),merge(c,d)); + } + + /*! 
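interval arithmetic note: operator* above has to look at all four endpoint
    products because either operand may straddle zero, so the result is
    [min(ll,lu,ul,uu), max(ll,lu,ul,uu)]; for the same reason operator- pairs
    a.lower with b.upper and a.upper with b.lower. */

    /*!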
intersect bounding boxes */ + __forceinline friend const Interval intersect( const Interval& a, const Interval& b ) { return Interval(max(a.lower, b.lower), min(a.upper, b.upper)); } + __forceinline friend const Interval intersect( const Interval& a, const Interval& b, const Interval& c ) { return intersect(a,intersect(b,c)); } + __forceinline friend const Interval intersect( const Interval& a, const Interval& b, const Interval& c, const Interval& d ) { return intersect(intersect(a,b),intersect(c,d)); } + + friend embree_ostream operator<<(embree_ostream cout, const Interval& a) { + return cout << "[" << a.lower << ", " << a.upper << "]"; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Interval( EmptyTy ) : lower(pos_inf), upper(neg_inf) {} + __forceinline Interval( FullTy ) : lower(neg_inf), upper(pos_inf) {} + }; + + __forceinline bool isEmpty(const Interval& v) { + return v.lower > v.upper; + } + + __forceinline vboolx isEmpty(const Interval& v) { + return v.lower > v.upper; + } + + /*! subset relation */ + template __forceinline bool subset( const Interval& a, const Interval& b ) { + return (a.lower > b.lower) && (a.upper < b.upper); + } + + template __forceinline bool subset( const Vec2>& a, const Vec2>& b ) { + return subset(a.x,b.x) && subset(a.y,b.y); + } + + template __forceinline const Vec2> intersect( const Vec2>& a, const Vec2>& b ) { + return Vec2>(intersect(a.x,b.x),intersect(a.y,b.y)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Interval select ( bool s, const Interval& t, const Interval& f ) { + return Interval(select(s,t.lower,f.lower),select(s,t.upper,f.upper)); + } + + template __forceinline Interval select ( const typename T::Bool& s, const Interval& t, const Interval& f ) { + return Interval(select(s,t.lower,f.lower),select(s,t.upper,f.upper)); + } + + __forceinline int numRoots(const Interval& p0, const Interval& p1) + { + float eps = 1E-4f; + bool neg0 = p0.lower < eps; bool pos0 = p0.upper > -eps; + bool neg1 = p1.lower < eps; bool pos1 = p1.upper > -eps; + return (neg0 && pos1) || (pos0 && neg1) || (neg0 && pos0) || (neg1 && pos1); + } + + typedef Interval Interval1f; + typedef Vec2> Interval2f; + typedef Vec3> Interval3f; + +inline void swap(float& a, float& b) { float tmp = a; a = b; b = tmp; } + +inline Interval1f shift(const Interval1f& v, float shift) { return Interval1f(v.lower + shift, v.upper + shift); } + +#define TWO_PI (2.0*M_PI) +inline Interval1f sin(Interval1f interval) +{ + if (interval.upper-interval.lower >= M_PI) { return Interval1f(-1.0, 1.0); } + if (interval.upper > TWO_PI) { interval = shift(interval, -TWO_PI*floor(interval.upper/TWO_PI)); } + if (interval.lower < 0) { interval = shift(interval, -TWO_PI*floor(interval.lower/TWO_PI)); } + float sinLower = sin(interval.lower); + float sinUpper = sin(interval.upper); + if (sinLower > sinUpper) swap(sinLower, sinUpper); + if (interval.lower < M_PI / 2.0 && interval.upper > M_PI / 2.0) sinUpper = 1.0; + if (interval.lower < 3.0 * M_PI / 2.0 && interval.upper > 3.0 * M_PI / 2.0) sinLower = -1.0; + return Interval1f(sinLower, sinUpper); +} + +inline Interval1f cos(Interval1f interval) +{ + if (interval.upper-interval.lower >= M_PI) { return Interval1f(-1.0, 1.0); } + if 
(interval.upper > TWO_PI) { interval = shift(interval, -TWO_PI*floor(interval.upper/TWO_PI)); } + if (interval.lower < 0) { interval = shift(interval, -TWO_PI*floor(interval.lower/TWO_PI)); } + float cosLower = cos(interval.lower); + float cosUpper = cos(interval.upper); + if (cosLower > cosUpper) swap(cosLower, cosUpper); + if (interval.lower < M_PI && interval.upper > M_PI) cosLower = -1.0; + return Interval1f(cosLower, cosUpper); +} +#undef TWO_PI +} diff --git a/thirdparty/embree/common/math/lbbox.h b/thirdparty/embree/common/math/lbbox.h new file mode 100644 index 000000000000..95df4a918d7b --- /dev/null +++ b/thirdparty/embree/common/math/lbbox.h @@ -0,0 +1,289 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bbox.h" +#include "range.h" + +namespace embree +{ + template + __forceinline std::pair globalLinear(const std::pair& v, const BBox1f& dt) + { + const float rcp_dt_size = float(1.0f)/dt.size(); + const T g0 = lerp(v.first,v.second,-dt.lower*rcp_dt_size); + const T g1 = lerp(v.first,v.second,(1.0f-dt.lower)*rcp_dt_size); + return std::make_pair(g0,g1); + } + + template + struct LBBox + { + public: + __forceinline LBBox () {} + + template + __forceinline LBBox ( const LBBox& other ) + : bounds0(other.bounds0), bounds1(other.bounds1) {} + + __forceinline LBBox& operator= ( const LBBox& other ) { + bounds0 = other.bounds0; bounds1 = other.bounds1; return *this; + } + + __forceinline LBBox (EmptyTy) + : bounds0(EmptyTy()), bounds1(EmptyTy()) {} + + __forceinline explicit LBBox ( const BBox& bounds) + : bounds0(bounds), bounds1(bounds) { } + + __forceinline LBBox ( const BBox& bounds0, const BBox& bounds1) + : bounds0(bounds0), bounds1(bounds1) { } + + LBBox ( const avector>& bounds ) + { + assert(bounds.size()); + BBox b0 = bounds.front(); + BBox b1 = bounds.back(); + for (size_t i=1; i bt = lerp(b0,b1,f); + const T dlower = min(bounds[i].lower-bt.lower,T(zero)); + const T dupper = max(bounds[i].upper-bt.upper,T(zero)); + b0.lower += dlower; b1.lower += dlower; + b0.upper += dupper; b1.upper += dupper; + } + bounds0 = b0; + bounds1 = b1; + } + + /*! calculates the linear bounds of a primitive for the specified time range */ + template + __forceinline LBBox(const BoundsFunc& bounds, const BBox1f& time_range, float numTimeSegments) + { + const float lower = time_range.lower*numTimeSegments; + const float upper = time_range.upper*numTimeSegments; + const float ilowerf = floor(lower); + const float iupperf = ceil(upper); + const int ilower = (int)ilowerf; + const int iupper = (int)iupperf; + + const BBox blower0 = bounds(ilower); + const BBox bupper1 = bounds(iupper); + + if (iupper-ilower == 1) { + bounds0 = lerp(blower0, bupper1, lower-ilowerf); + bounds1 = lerp(bupper1, blower0, iupperf-upper); + return; + } + + const BBox blower1 = bounds(ilower+1); + const BBox bupper0 = bounds(iupper-1); + BBox b0 = lerp(blower0, blower1, lower-ilowerf); + BBox b1 = lerp(bupper1, bupper0, iupperf-upper); + + for (int i = ilower+1; i < iupper; i++) + { + const float f = (float(i)/numTimeSegments - time_range.lower) / time_range.size(); + const BBox bt = lerp(b0, b1, f); + const BBox bi = bounds(i); + const T dlower = min(bi.lower-bt.lower, T(zero)); + const T dupper = max(bi.upper-bt.upper, T(zero)); + b0.lower += dlower; b1.lower += dlower; + b0.upper += dupper; b1.upper += dupper; + } + + bounds0 = b0; + bounds1 = b1; + } + + /*! 
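how these constructors stay conservative: only the two end boxes bounds0 and
    bounds1 are stored, and whenever the box bi of an intermediate time step
    sticks out of the interpolated box bt = lerp(b0, b1, f), both end boxes are
    pushed outward by the deficit (dlower/dupper), so that lerp(bounds0, bounds1, t)
    contains the true bounds at every sampled time step. */

    /*!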
calculates the linear bounds of a primitive for the specified time range */ + template + __forceinline LBBox(const BoundsFunc& bounds, const BBox1f& time_range_in, const BBox1f& geom_time_range, float geom_time_segments) + { + /* normalize global time_range_in to local geom_time_range */ + const BBox1f time_range((time_range_in.lower-geom_time_range.lower)/geom_time_range.size(), + (time_range_in.upper-geom_time_range.lower)/geom_time_range.size()); + + const float lower = time_range.lower*geom_time_segments; + const float upper = time_range.upper*geom_time_segments; + const float ilowerf = floor(lower); + const float iupperf = ceil(upper); + const float ilowerfc = max(0.0f,ilowerf); + const float iupperfc = min(iupperf,geom_time_segments); + const int ilowerc = (int)ilowerfc; + const int iupperc = (int)iupperfc; + assert(iupperc-ilowerc > 0); + + /* this larger iteration range guarantees that we process borders of geom_time_range is (partially) inside time_range_in */ + const int ilower_iter = max(-1,(int)ilowerf); + const int iupper_iter = min((int)iupperf,(int)geom_time_segments+1); + + const BBox blower0 = bounds(ilowerc); + const BBox bupper1 = bounds(iupperc); + if (iupper_iter-ilower_iter == 1) { + bounds0 = lerp(blower0, bupper1, max(0.0f,lower-ilowerfc)); + bounds1 = lerp(bupper1, blower0, max(0.0f,iupperfc-upper)); + return; + } + + const BBox blower1 = bounds(ilowerc+1); + const BBox bupper0 = bounds(iupperc-1); + BBox b0 = lerp(blower0, blower1, max(0.0f,lower-ilowerfc)); + BBox b1 = lerp(bupper1, bupper0, max(0.0f,iupperfc-upper)); + + for (int i = ilower_iter+1; i < iupper_iter; i++) + { + const float f = (float(i)/geom_time_segments - time_range.lower) / time_range.size(); + const BBox bt = lerp(b0, b1, f); + const BBox bi = bounds(i); + const T dlower = min(bi.lower-bt.lower, T(zero)); + const T dupper = max(bi.upper-bt.upper, T(zero)); + b0.lower += dlower; b1.lower += dlower; + b0.upper += dupper; b1.upper += dupper; + } + + bounds0 = b0; + bounds1 = b1; + } + + /*! 
calculates the linear bounds of a primitive for the specified time range */ + template + __forceinline LBBox(const BoundsFunc& bounds, const range& time_range, int numTimeSegments) + { + const int ilower = time_range.begin(); + const int iupper = time_range.end(); + + BBox b0 = bounds(ilower); + BBox b1 = bounds(iupper); + + if (iupper-ilower == 1) + { + bounds0 = b0; + bounds1 = b1; + return; + } + + for (int i = ilower+1; i bt = lerp(b0, b1, f); + const BBox bi = bounds(i); + const T dlower = min(bi.lower-bt.lower, T(zero)); + const T dupper = max(bi.upper-bt.upper, T(zero)); + b0.lower += dlower; b1.lower += dlower; + b0.upper += dupper; b1.upper += dupper; + } + + bounds0 = b0; + bounds1 = b1; + } + + public: + + __forceinline bool empty() const { + return bounds().empty(); + } + + __forceinline BBox bounds () const { + return merge(bounds0,bounds1); + } + + __forceinline BBox interpolate( const float t ) const { + return lerp(bounds0,bounds1,t); + } + + __forceinline LBBox interpolate( const BBox1f& dt ) const { + return LBBox(interpolate(dt.lower),interpolate(dt.upper)); + } + + __forceinline void extend( const LBBox& other ) { + bounds0.extend(other.bounds0); + bounds1.extend(other.bounds1); + } + + __forceinline float expectedHalfArea() const; + + __forceinline float expectedHalfArea(const BBox1f& dt) const { + return interpolate(dt).expectedHalfArea(); + } + + __forceinline float expectedApproxHalfArea() const { + return 0.5f*(halfArea(bounds0) + halfArea(bounds1)); + } + + /* calculates bounds for [0,1] time range from bounds in dt time range */ + __forceinline LBBox global(const BBox1f& dt) const + { + const float rcp_dt_size = 1.0f/dt.size(); + const BBox b0 = interpolate(-dt.lower*rcp_dt_size); + const BBox b1 = interpolate((1.0f-dt.lower)*rcp_dt_size); + return LBBox(b0,b1); + } + + /*! Comparison Operators */ + //template friend __forceinline bool operator==( const LBBox& a, const LBBox& b ) { return a.bounds0 == b.bounds0 && a.bounds1 == b.bounds1; } + //template friend __forceinline bool operator!=( const LBBox& a, const LBBox& b ) { return a.bounds0 != b.bounds0 || a.bounds1 != b.bounds1; } + friend __forceinline bool operator==( const LBBox& a, const LBBox& b ) { return a.bounds0 == b.bounds0 && a.bounds1 == b.bounds1; } + friend __forceinline bool operator!=( const LBBox& a, const LBBox& b ) { return a.bounds0 != b.bounds0 || a.bounds1 != b.bounds1; } + + /*! output operator */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const LBBox& box) { + return cout << "LBBox { " << box.bounds0 << "; " << box.bounds1 << " }"; + } + + public: + BBox bounds0, bounds1; + }; + + /*! 
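(illustrative note, not part of the upstream Embree sources) */

  // Minimal usage sketch of the LBBox interface above: `boundsOver` is a
  // hypothetical helper returning a conservative box for a moving primitive
  // over the normalized time span dt.
  template<typename T>
  __forceinline BBox<T> boundsOver(const LBBox<T>& lb, const BBox1f& dt) {
    // interpolate() re-evaluates the linear bounds at dt.lower and dt.upper,
    // bounds() then merges the two end boxes of that sub-range.
    return lb.interpolate(dt).bounds();
  }

  /*!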
tests if box is finite */ + template + __forceinline bool isvalid( const LBBox& v ) { + return isvalid(v.bounds0) && isvalid(v.bounds1); + } + + template + __forceinline bool isvalid_non_empty( const LBBox& v ) { + return isvalid_non_empty(v.bounds0) && isvalid_non_empty(v.bounds1); + } + + template + __forceinline T expectedArea(const T& a0, const T& a1, const T& b0, const T& b1) + { + const T da = a1-a0; + const T db = b1-b0; + return a0*b0+(a0*db+da*b0)*T(0.5f) + da*db*T(1.0f/3.0f); + } + + template<> __forceinline float LBBox::expectedHalfArea() const + { + const Vec3fa d0 = bounds0.size(); + const Vec3fa d1 = bounds1.size(); + return reduce_add(expectedArea(Vec3fa(d0.x,d0.y,d0.z), + Vec3fa(d1.x,d1.y,d1.z), + Vec3fa(d0.y,d0.z,d0.x), + Vec3fa(d1.y,d1.z,d1.x))); + } + + template + __forceinline float expectedApproxHalfArea(const LBBox& box) { + return box.expectedApproxHalfArea(); + } + + template + __forceinline LBBox merge(const LBBox& a, const LBBox& b) { + return LBBox(merge(a.bounds0, b.bounds0), merge(a.bounds1, b.bounds1)); + } + + /*! subset relation */ + template __inline bool subset( const LBBox& a, const LBBox& b ) { + return subset(a.bounds0,b.bounds0) && subset(a.bounds1,b.bounds1); + } + + /*! default template instantiations */ + typedef LBBox LBBox1f; + typedef LBBox LBBox2f; + typedef LBBox LBBox3f; + typedef LBBox LBBox3fa; + typedef LBBox LBBox3fx; +} diff --git a/thirdparty/embree/common/math/linearspace2.h b/thirdparty/embree/common/math/linearspace2.h new file mode 100644 index 000000000000..b9a382962cf6 --- /dev/null +++ b/thirdparty/embree/common/math/linearspace2.h @@ -0,0 +1,148 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "vec2.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// 2D Linear Transform (2x2 Matrix) + //////////////////////////////////////////////////////////////////////////////// + + template struct LinearSpace2 + { + typedef T Vector; + typedef typename T::Scalar Scalar; + + /*! default matrix constructor */ + __forceinline LinearSpace2 ( ) {} + __forceinline LinearSpace2 ( const LinearSpace2& other ) { vx = other.vx; vy = other.vy; } + __forceinline LinearSpace2& operator=( const LinearSpace2& other ) { vx = other.vx; vy = other.vy; return *this; } + + template __forceinline LinearSpace2( const LinearSpace2& s ) : vx(s.vx), vy(s.vy) {} + + /*! matrix construction from column vectors */ + __forceinline LinearSpace2(const Vector& vx, const Vector& vy) + : vx(vx), vy(vy) {} + + /*! matrix construction from row mayor data */ + __forceinline LinearSpace2(const Scalar& m00, const Scalar& m01, + const Scalar& m10, const Scalar& m11) + : vx(m00,m10), vy(m01,m11) {} + + /*! compute the determinant of the matrix */ + __forceinline const Scalar det() const { return vx.x*vy.y - vx.y*vy.x; } + + /*! compute adjoint matrix */ + __forceinline const LinearSpace2 adjoint() const { return LinearSpace2(vy.y,-vy.x,-vx.y,vx.x); } + + /*! compute inverse matrix */ + __forceinline const LinearSpace2 inverse() const { return adjoint()/det(); } + + /*! compute transposed matrix */ + __forceinline const LinearSpace2 transposed() const { return LinearSpace2(vx.x,vx.y,vy.x,vy.y); } + + /*! returns first row of matrix */ + __forceinline Vector row0() const { return Vector(vx.x,vy.x); } + + /*! 
returns second row of matrix */ + __forceinline Vector row1() const { return Vector(vx.y,vy.y); } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline LinearSpace2( ZeroTy ) : vx(zero), vy(zero) {} + __forceinline LinearSpace2( OneTy ) : vx(one, zero), vy(zero, one) {} + + /*! return matrix for scaling */ + static __forceinline LinearSpace2 scale(const Vector& s) { + return LinearSpace2(s.x, 0, + 0 , s.y); + } + + /*! return matrix for rotation */ + static __forceinline LinearSpace2 rotate(const Scalar& r) { + Scalar s = sin(r), c = cos(r); + return LinearSpace2(c, -s, + s, c); + } + + /*! return closest orthogonal matrix (i.e. a general rotation including reflection) */ + LinearSpace2 orthogonal() const + { + LinearSpace2 m = *this; + + // mirrored? + Scalar mirror(one); + if (m.det() < Scalar(zero)) { + m.vx = -m.vx; + mirror = -mirror; + } + + // rotation + for (int i = 0; i < 99; i++) { + const LinearSpace2 m_next = 0.5 * (m + m.transposed().inverse()); + const LinearSpace2 d = m_next - m; + m = m_next; + // norm^2 of difference small enough? + if (max(dot(d.vx, d.vx), dot(d.vy, d.vy)) < 1e-8) + break; + } + + // rotation * mirror_x + return LinearSpace2(mirror*m.vx, m.vy); + } + + public: + + /*! the column vectors of the matrix */ + Vector vx,vy; + }; + + //////////////////////////////////////////////////////////////////////////////// + // Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline LinearSpace2 operator -( const LinearSpace2& a ) { return LinearSpace2(-a.vx,-a.vy); } + template __forceinline LinearSpace2 operator +( const LinearSpace2& a ) { return LinearSpace2(+a.vx,+a.vy); } + template __forceinline LinearSpace2 rcp ( const LinearSpace2& a ) { return a.inverse(); } + + //////////////////////////////////////////////////////////////////////////////// + // Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline LinearSpace2 operator +( const LinearSpace2& a, const LinearSpace2& b ) { return LinearSpace2(a.vx+b.vx,a.vy+b.vy); } + template __forceinline LinearSpace2 operator -( const LinearSpace2& a, const LinearSpace2& b ) { return LinearSpace2(a.vx-b.vx,a.vy-b.vy); } + + template __forceinline LinearSpace2 operator*(const typename T::Scalar & a, const LinearSpace2& b) { return LinearSpace2(a*b.vx, a*b.vy); } + template __forceinline T operator*(const LinearSpace2& a, const T & b) { return b.x*a.vx + b.y*a.vy; } + template __forceinline LinearSpace2 operator*(const LinearSpace2& a, const LinearSpace2& b) { return LinearSpace2(a*b.vx, a*b.vy); } + + template __forceinline LinearSpace2 operator/(const LinearSpace2& a, const typename T::Scalar & b) { return LinearSpace2(a.vx/b, a.vy/b); } + template __forceinline LinearSpace2 operator/(const LinearSpace2& a, const LinearSpace2& b) { return a * rcp(b); } + + template __forceinline LinearSpace2& operator *=( LinearSpace2& a, const LinearSpace2& b ) { return a = a * b; } + template __forceinline LinearSpace2& operator /=( LinearSpace2& a, const LinearSpace2& b ) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline bool operator ==( const LinearSpace2& a, 
const LinearSpace2& b ) { return a.vx == b.vx && a.vy == b.vy; } + template __forceinline bool operator !=( const LinearSpace2& a, const LinearSpace2& b ) { return a.vx != b.vx || a.vy != b.vy; } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + template static embree_ostream operator<<(embree_ostream cout, const LinearSpace2& m) { + return cout << "{ vx = " << m.vx << ", vy = " << m.vy << "}"; + } + + /*! Shortcuts for common linear spaces. */ + typedef LinearSpace2 LinearSpace2f; + typedef LinearSpace2 LinearSpace2fa; +} diff --git a/thirdparty/embree/common/math/linearspace3.h b/thirdparty/embree/common/math/linearspace3.h new file mode 100644 index 000000000000..12b5bb776bfb --- /dev/null +++ b/thirdparty/embree/common/math/linearspace3.h @@ -0,0 +1,213 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "vec3.h" +#include "quaternion.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// 3D Linear Transform (3x3 Matrix) + //////////////////////////////////////////////////////////////////////////////// + + template struct LinearSpace3 + { + typedef T Vector; + typedef typename T::Scalar Scalar; + + /*! default matrix constructor */ + __forceinline LinearSpace3 ( ) {} + __forceinline LinearSpace3 ( const LinearSpace3& other ) { vx = other.vx; vy = other.vy; vz = other.vz; } + __forceinline LinearSpace3& operator=( const LinearSpace3& other ) { vx = other.vx; vy = other.vy; vz = other.vz; return *this; } + + template __forceinline LinearSpace3( const LinearSpace3& s ) : vx(s.vx), vy(s.vy), vz(s.vz) {} + + /*! matrix construction from column vectors */ + __forceinline LinearSpace3(const Vector& vx, const Vector& vy, const Vector& vz) + : vx(vx), vy(vy), vz(vz) {} + + /*! construction from quaternion */ + __forceinline LinearSpace3( const QuaternionT& q ) + : vx((q.r*q.r + q.i*q.i - q.j*q.j - q.k*q.k), 2.0f*(q.i*q.j + q.r*q.k), 2.0f*(q.i*q.k - q.r*q.j)) + , vy(2.0f*(q.i*q.j - q.r*q.k), (q.r*q.r - q.i*q.i + q.j*q.j - q.k*q.k), 2.0f*(q.j*q.k + q.r*q.i)) + , vz(2.0f*(q.i*q.k + q.r*q.j), 2.0f*(q.j*q.k - q.r*q.i), (q.r*q.r - q.i*q.i - q.j*q.j + q.k*q.k)) {} + + /*! matrix construction from row mayor data */ + __forceinline LinearSpace3(const Scalar& m00, const Scalar& m01, const Scalar& m02, + const Scalar& m10, const Scalar& m11, const Scalar& m12, + const Scalar& m20, const Scalar& m21, const Scalar& m22) + : vx(m00,m10,m20), vy(m01,m11,m21), vz(m02,m12,m22) {} + + /*! compute the determinant of the matrix */ + __forceinline const Scalar det() const { return dot(vx,cross(vy,vz)); } + + /*! compute adjoint matrix */ + __forceinline const LinearSpace3 adjoint() const { return LinearSpace3(cross(vy,vz),cross(vz,vx),cross(vx,vy)).transposed(); } + + /*! compute inverse matrix */ + __forceinline const LinearSpace3 inverse() const { return adjoint()/det(); } + + /*! compute transposed matrix */ + __forceinline const LinearSpace3 transposed() const { return LinearSpace3(vx.x,vx.y,vx.z,vy.x,vy.y,vy.z,vz.x,vz.y,vz.z); } + + /*! returns first row of matrix */ + __forceinline Vector row0() const { return Vector(vx.x,vy.x,vz.x); } + + /*! returns second row of matrix */ + __forceinline Vector row1() const { return Vector(vx.y,vy.y,vz.y); } + + /*! 
returns third row of matrix */ + __forceinline Vector row2() const { return Vector(vx.z,vy.z,vz.z); } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline LinearSpace3( ZeroTy ) : vx(zero), vy(zero), vz(zero) {} + __forceinline LinearSpace3( OneTy ) : vx(one, zero, zero), vy(zero, one, zero), vz(zero, zero, one) {} + + /*! return matrix for scaling */ + static __forceinline LinearSpace3 scale(const Vector& s) { + return LinearSpace3(s.x, 0, 0, + 0 , s.y, 0, + 0 , 0, s.z); + } + + /*! return matrix for rotation around arbitrary axis */ + static __forceinline LinearSpace3 rotate(const Vector& _u, const Scalar& r) { + Vector u = normalize(_u); + Scalar s = sin(r), c = cos(r); + return LinearSpace3(u.x*u.x+(1-u.x*u.x)*c, u.x*u.y*(1-c)-u.z*s, u.x*u.z*(1-c)+u.y*s, + u.x*u.y*(1-c)+u.z*s, u.y*u.y+(1-u.y*u.y)*c, u.y*u.z*(1-c)-u.x*s, + u.x*u.z*(1-c)-u.y*s, u.y*u.z*(1-c)+u.x*s, u.z*u.z+(1-u.z*u.z)*c); + } + + public: + + /*! the column vectors of the matrix */ + Vector vx,vy,vz; + }; + + /*! compute transposed matrix */ + template<> __forceinline const LinearSpace3 LinearSpace3::transposed() const { + vfloat4 rx,ry,rz; transpose((vfloat4&)vx,(vfloat4&)vy,(vfloat4&)vz,vfloat4(zero),rx,ry,rz); + return LinearSpace3(Vec3fa(rx),Vec3fa(ry),Vec3fa(rz)); + } + + template + __forceinline const LinearSpace3 transposed(const LinearSpace3& xfm) { + return xfm.transposed(); + } + + //////////////////////////////////////////////////////////////////////////////// + // Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline LinearSpace3 operator -( const LinearSpace3& a ) { return LinearSpace3(-a.vx,-a.vy,-a.vz); } + template __forceinline LinearSpace3 operator +( const LinearSpace3& a ) { return LinearSpace3(+a.vx,+a.vy,+a.vz); } + template __forceinline LinearSpace3 rcp ( const LinearSpace3& a ) { return a.inverse(); } + + /* constructs a coordinate frame form a normalized normal */ + template __forceinline LinearSpace3 frame(const T& N) + { + const T dx0(0,N.z,-N.y); + const T dx1(-N.z,0,N.x); + const T dx = normalize(select(dot(dx0,dx0) > dot(dx1,dx1),dx0,dx1)); + const T dy = normalize(cross(N,dx)); + return LinearSpace3(dx,dy,N); + } + + /* constructs a coordinate frame from a normal and approximate x-direction */ + template __forceinline LinearSpace3 frame(const T& N, const T& dxi) + { + if (abs(dot(dxi,N)) > 0.99f) return frame(N); // fallback in case N and dxi are very parallel + const T dx = normalize(cross(dxi,N)); + const T dy = normalize(cross(N,dx)); + return LinearSpace3(dx,dy,N); + } + + /* clamps linear space to range -1 to +1 */ + template __forceinline LinearSpace3 clamp(const LinearSpace3& space) { + return LinearSpace3(clamp(space.vx,T(-1.0f),T(1.0f)), + clamp(space.vy,T(-1.0f),T(1.0f)), + clamp(space.vz,T(-1.0f),T(1.0f))); + } + + //////////////////////////////////////////////////////////////////////////////// + // Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline LinearSpace3 operator +( const LinearSpace3& a, const LinearSpace3& b ) { return LinearSpace3(a.vx+b.vx,a.vy+b.vy,a.vz+b.vz); } + template __forceinline LinearSpace3 operator -( const LinearSpace3& a, const LinearSpace3& b ) { return LinearSpace3(a.vx-b.vx,a.vy-b.vy,a.vz-b.vz); } + + template __forceinline LinearSpace3 operator*(const typename 
T::Scalar & a, const LinearSpace3& b) { return LinearSpace3(a*b.vx, a*b.vy, a*b.vz); } + template __forceinline T operator*(const LinearSpace3& a, const T & b) { return madd(T(b.x),a.vx,madd(T(b.y),a.vy,T(b.z)*a.vz)); } + template __forceinline LinearSpace3 operator*(const LinearSpace3& a, const LinearSpace3& b) { return LinearSpace3(a*b.vx, a*b.vy, a*b.vz); } + + template __forceinline LinearSpace3 operator/(const LinearSpace3& a, const typename T::Scalar & b) { return LinearSpace3(a.vx/b, a.vy/b, a.vz/b); } + template __forceinline LinearSpace3 operator/(const LinearSpace3& a, const LinearSpace3& b) { return a * rcp(b); } + + template __forceinline LinearSpace3& operator *=( LinearSpace3& a, const LinearSpace3& b ) { return a = a * b; } + template __forceinline LinearSpace3& operator /=( LinearSpace3& a, const LinearSpace3& b ) { return a = a / b; } + + template __forceinline T xfmPoint (const LinearSpace3& s, const T & a) { return madd(T(a.x),s.vx,madd(T(a.y),s.vy,T(a.z)*s.vz)); } + template __forceinline T xfmVector(const LinearSpace3& s, const T & a) { return madd(T(a.x),s.vx,madd(T(a.y),s.vy,T(a.z)*s.vz)); } + template __forceinline T xfmNormal(const LinearSpace3& s, const T & a) { return xfmVector(s.inverse().transposed(),a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline bool operator ==( const LinearSpace3& a, const LinearSpace3& b ) { return a.vx == b.vx && a.vy == b.vy && a.vz == b.vz; } + template __forceinline bool operator !=( const LinearSpace3& a, const LinearSpace3& b ) { return a.vx != b.vx || a.vy != b.vy || a.vz != b.vz; } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline LinearSpace3 select ( const typename T::Scalar::Bool& s, const LinearSpace3& t, const LinearSpace3& f ) { + return LinearSpace3(select(s,t.vx,f.vx),select(s,t.vy,f.vy),select(s,t.vz,f.vz)); + } + + /*! blending */ + template + __forceinline LinearSpace3 lerp(const LinearSpace3& l0, const LinearSpace3& l1, const float t) + { + return LinearSpace3(lerp(l0.vx,l1.vx,t), + lerp(l0.vy,l1.vy,t), + lerp(l0.vz,l1.vz,t)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + template static embree_ostream operator<<(embree_ostream cout, const LinearSpace3& m) { + return cout << "{ vx = " << m.vx << ", vy = " << m.vy << ", vz = " << m.vz << "}"; + } + + /*! Shortcuts for common linear spaces. */ + typedef LinearSpace3 LinearSpace3f; + typedef LinearSpace3 LinearSpace3fa; + typedef LinearSpace3 LinearSpace3fx; + typedef LinearSpace3 LinearSpace3ff; + + template using LinearSpace3vf = LinearSpace3>>; + typedef LinearSpace3>> LinearSpace3vf4; + typedef LinearSpace3>> LinearSpace3vf8; + typedef LinearSpace3>> LinearSpace3vf16; + + /*! 
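(illustrative note, not part of the upstream Embree sources) */

  // Minimal sketch of how frame() and the xfm* helpers above are typically
  // combined; `toLocal` is a hypothetical function that expresses a world-space
  // direction in the orthonormal frame built around a unit normal N.
  __forceinline Vec3fa toLocal(const Vec3fa& N, const Vec3fa& dir) {
    const LinearSpace3fa f = frame(N);      // column vectors: tangent, bitangent, N
    return xfmVector(f.transposed(), dir);  // multiplying by the transpose projects dir onto those axes
  }

  /*!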
blending */ + template + __forceinline LinearSpace3 lerp(const LinearSpace3& l0, + const LinearSpace3& l1, + const S& t) + { + return LinearSpace3(lerp(l0.vx,l1.vx,t), + lerp(l0.vy,l1.vy,t), + lerp(l0.vz,l1.vz,t)); + } + +} diff --git a/thirdparty/embree/common/math/math.h b/thirdparty/embree/common/math/math.h new file mode 100644 index 000000000000..1982c27c15e4 --- /dev/null +++ b/thirdparty/embree/common/math/math.h @@ -0,0 +1,352 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../sys/intrinsics.h" +#include "constants.h" +#include + +#include +#include +#include + +#if defined(__WIN32__) && !defined(__MINGW32__) +#if (__MSV_VER <= 1700) +namespace std +{ + __forceinline bool isinf ( const float x ) { return _finite(x) == 0; } + __forceinline bool isnan ( const float x ) { return _isnan(x) != 0; } + __forceinline bool isfinite (const float x) { return _finite(x) != 0; } +} +#endif +#endif + +namespace embree +{ + __forceinline bool isvalid ( const float& v ) { + return (v > -FLT_LARGE) & (v < +FLT_LARGE); + } + + __forceinline int cast_f2i(float f) { + union { float f; int i; } v; v.f = f; return v.i; + } + + __forceinline float cast_i2f(int i) { + union { float f; int i; } v; v.i = i; return v.f; + } + + __forceinline int toInt (const float& a) { return int(a); } + __forceinline float toFloat(const int& a) { return float(a); } + +#if defined(__WIN32__) + __forceinline bool finite ( const float x ) { return _finite(x) != 0; } +#endif + + __forceinline float sign ( const float x ) { return x<0?-1.0f:1.0f; } + __forceinline float sqr ( const float x ) { return x*x; } + + __forceinline float rcp ( const float x ) + { + const __m128 a = _mm_set_ss(x); + +#if defined(__AVX512VL__) + const __m128 r = _mm_rcp14_ss(_mm_set_ss(0.0f),a); +#else + const __m128 r = _mm_rcp_ss(a); +#endif + +#if defined(__AVX2__) + return _mm_cvtss_f32(_mm_mul_ss(r,_mm_fnmadd_ss(r, a, _mm_set_ss(2.0f)))); +#else + return _mm_cvtss_f32(_mm_mul_ss(r,_mm_sub_ss(_mm_set_ss(2.0f), _mm_mul_ss(r, a)))); +#endif + } + + __forceinline float signmsk ( const float x ) { + return _mm_cvtss_f32(_mm_and_ps(_mm_set_ss(x),_mm_castsi128_ps(_mm_set1_epi32(0x80000000)))); + } + __forceinline float xorf( const float x, const float y ) { + return _mm_cvtss_f32(_mm_xor_ps(_mm_set_ss(x),_mm_set_ss(y))); + } + __forceinline float andf( const float x, const unsigned y ) { + return _mm_cvtss_f32(_mm_and_ps(_mm_set_ss(x),_mm_castsi128_ps(_mm_set1_epi32(y)))); + } + __forceinline float rsqrt( const float x ) + { + const __m128 a = _mm_set_ss(x); +#if defined(__AVX512VL__) + const __m128 r = _mm_rsqrt14_ss(_mm_set_ss(0.0f),a); +#else + const __m128 r = _mm_rsqrt_ss(a); +#endif + const __m128 c = _mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r), + _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(a, _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r))); + return _mm_cvtss_f32(c); + } + +#if defined(__WIN32__) && (__MSC_VER <= 1700) + __forceinline float nextafter(float x, float y) { if ((x0)) return x*(1.1f+float(ulp)); else return x*(0.9f-float(ulp)); } + __forceinline double nextafter(double x, double y) { return _nextafter(x, y); } + __forceinline int roundf(float f) { return (int)(f + 0.5f); } +#else + __forceinline float nextafter(float x, float y) { return ::nextafterf(x, y); } + __forceinline double nextafter(double x, double y) { return ::nextafter(x, y); } +#endif + + __forceinline float abs ( const float x ) { return ::fabsf(x); } + __forceinline float acos ( const float x 
) { return ::acosf (x); } + __forceinline float asin ( const float x ) { return ::asinf (x); } + __forceinline float atan ( const float x ) { return ::atanf (x); } + __forceinline float atan2( const float y, const float x ) { return ::atan2f(y, x); } + __forceinline float cos ( const float x ) { return ::cosf (x); } + __forceinline float cosh ( const float x ) { return ::coshf (x); } + __forceinline float exp ( const float x ) { return ::expf (x); } + __forceinline float fmod ( const float x, const float y ) { return ::fmodf (x, y); } + __forceinline float log ( const float x ) { return ::logf (x); } + __forceinline float log10( const float x ) { return ::log10f(x); } + __forceinline float pow ( const float x, const float y ) { return ::powf (x, y); } + __forceinline float sin ( const float x ) { return ::sinf (x); } + __forceinline float sinh ( const float x ) { return ::sinhf (x); } + __forceinline float sqrt ( const float x ) { return ::sqrtf (x); } + __forceinline float tan ( const float x ) { return ::tanf (x); } + __forceinline float tanh ( const float x ) { return ::tanhf (x); } + __forceinline float floor( const float x ) { return ::floorf (x); } + __forceinline float ceil ( const float x ) { return ::ceilf (x); } + __forceinline float frac ( const float x ) { return x-floor(x); } + + __forceinline double abs ( const double x ) { return ::fabs(x); } + __forceinline double sign ( const double x ) { return x<0?-1.0:1.0; } + __forceinline double acos ( const double x ) { return ::acos (x); } + __forceinline double asin ( const double x ) { return ::asin (x); } + __forceinline double atan ( const double x ) { return ::atan (x); } + __forceinline double atan2( const double y, const double x ) { return ::atan2(y, x); } + __forceinline double cos ( const double x ) { return ::cos (x); } + __forceinline double cosh ( const double x ) { return ::cosh (x); } + __forceinline double exp ( const double x ) { return ::exp (x); } + __forceinline double fmod ( const double x, const double y ) { return ::fmod (x, y); } + __forceinline double log ( const double x ) { return ::log (x); } + __forceinline double log10( const double x ) { return ::log10(x); } + __forceinline double pow ( const double x, const double y ) { return ::pow (x, y); } + __forceinline double rcp ( const double x ) { return 1.0/x; } + __forceinline double rsqrt( const double x ) { return 1.0/::sqrt(x); } + __forceinline double sin ( const double x ) { return ::sin (x); } + __forceinline double sinh ( const double x ) { return ::sinh (x); } + __forceinline double sqr ( const double x ) { return x*x; } + __forceinline double sqrt ( const double x ) { return ::sqrt (x); } + __forceinline double tan ( const double x ) { return ::tan (x); } + __forceinline double tanh ( const double x ) { return ::tanh (x); } + __forceinline double floor( const double x ) { return ::floor (x); } + __forceinline double ceil ( const double x ) { return ::ceil (x); } + +#if defined(__SSE4_1__) + __forceinline float mini(float a, float b) { + const __m128i ai = _mm_castps_si128(_mm_set_ss(a)); + const __m128i bi = _mm_castps_si128(_mm_set_ss(b)); + const __m128i ci = _mm_min_epi32(ai,bi); + return _mm_cvtss_f32(_mm_castsi128_ps(ci)); + } +#endif + +#if defined(__SSE4_1__) + __forceinline float maxi(float a, float b) { + const __m128i ai = _mm_castps_si128(_mm_set_ss(a)); + const __m128i bi = _mm_castps_si128(_mm_set_ss(b)); + const __m128i ci = _mm_max_epi32(ai,bi); + return _mm_cvtss_f32(_mm_castsi128_ps(ci)); + } +#endif + + template + __forceinline T 
twice(const T& a) { return a+a; } + + __forceinline int min(int a, int b) { return a __forceinline T min(const T& a, const T& b, const T& c) { return min(min(a,b),c); } + template __forceinline T min(const T& a, const T& b, const T& c, const T& d) { return min(min(a,b),min(c,d)); } + template __forceinline T min(const T& a, const T& b, const T& c, const T& d, const T& e) { return min(min(min(a,b),min(c,d)),e); } + + template __forceinline T mini(const T& a, const T& b, const T& c) { return mini(mini(a,b),c); } + template __forceinline T mini(const T& a, const T& b, const T& c, const T& d) { return mini(mini(a,b),mini(c,d)); } + template __forceinline T mini(const T& a, const T& b, const T& c, const T& d, const T& e) { return mini(mini(mini(a,b),mini(c,d)),e); } + + __forceinline int max(int a, int b) { return a __forceinline T max(const T& a, const T& b, const T& c) { return max(max(a,b),c); } + template __forceinline T max(const T& a, const T& b, const T& c, const T& d) { return max(max(a,b),max(c,d)); } + template __forceinline T max(const T& a, const T& b, const T& c, const T& d, const T& e) { return max(max(max(a,b),max(c,d)),e); } + + template __forceinline T maxi(const T& a, const T& b, const T& c) { return maxi(maxi(a,b),c); } + template __forceinline T maxi(const T& a, const T& b, const T& c, const T& d) { return maxi(maxi(a,b),maxi(c,d)); } + template __forceinline T maxi(const T& a, const T& b, const T& c, const T& d, const T& e) { return maxi(maxi(maxi(a,b),maxi(c,d)),e); } + +#if defined(__MACOSX__) + __forceinline ssize_t min(ssize_t a, ssize_t b) { return a __forceinline T clamp(const T& x, const T& lower = T(zero), const T& upper = T(one)) { return max(min(x,upper),lower); } + template __forceinline T clampz(const T& x, const T& upper) { return max(T(zero), min(x,upper)); } + + template __forceinline T deg2rad ( const T& x ) { return x * T(1.74532925199432957692e-2f); } + template __forceinline T rad2deg ( const T& x ) { return x * T(5.72957795130823208768e1f); } + template __forceinline T sin2cos ( const T& x ) { return sqrt(max(T(zero),T(one)-x*x)); } + template __forceinline T cos2sin ( const T& x ) { return sin2cos(x); } + +#if defined(__AVX2__) + __forceinline float madd ( const float a, const float b, const float c) { return _mm_cvtss_f32(_mm_fmadd_ss(_mm_set_ss(a),_mm_set_ss(b),_mm_set_ss(c))); } + __forceinline float msub ( const float a, const float b, const float c) { return _mm_cvtss_f32(_mm_fmsub_ss(_mm_set_ss(a),_mm_set_ss(b),_mm_set_ss(c))); } + __forceinline float nmadd ( const float a, const float b, const float c) { return _mm_cvtss_f32(_mm_fnmadd_ss(_mm_set_ss(a),_mm_set_ss(b),_mm_set_ss(c))); } + __forceinline float nmsub ( const float a, const float b, const float c) { return _mm_cvtss_f32(_mm_fnmsub_ss(_mm_set_ss(a),_mm_set_ss(b),_mm_set_ss(c))); } +#else + __forceinline float madd ( const float a, const float b, const float c) { return a*b+c; } + __forceinline float msub ( const float a, const float b, const float c) { return a*b-c; } + __forceinline float nmadd ( const float a, const float b, const float c) { return -a*b+c;} + __forceinline float nmsub ( const float a, const float b, const float c) { return -a*b-c; } +#endif + + /*! 
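(illustrative note, not part of the upstream Embree sources) */

  // madd() above maps to a fused multiply-add on the AVX2 path, which makes
  // Horner evaluation of polynomials both fast and accurate. `horner3` is a
  // hypothetical helper shown only as an example of that pattern.
  __forceinline float horner3(const float x, const float c0, const float c1,
                              const float c2, const float c3) {
    // evaluates c0 + x*(c1 + x*(c2 + x*c3)) with one madd per coefficient
    return madd(x, madd(x, madd(x, c3, c2), c1), c0);
  }

  /*!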
random functions */ + template T random() { return T(0); } +#if defined(_WIN32) + template<> __forceinline int random() { return int(rand()) ^ (int(rand()) << 8) ^ (int(rand()) << 16); } + template<> __forceinline uint32_t random() { return uint32_t(rand()) ^ (uint32_t(rand()) << 8) ^ (uint32_t(rand()) << 16); } +#else + template<> __forceinline int random() { return int(rand()); } + template<> __forceinline uint32_t random() { return uint32_t(rand()) ^ (uint32_t(rand()) << 16); } +#endif + template<> __forceinline float random() { return rand()/float(RAND_MAX); } + template<> __forceinline double random() { return rand()/double(RAND_MAX); } + +#if _WIN32 + __forceinline double drand48() { + return double(rand())/double(RAND_MAX); + } + + __forceinline void srand48(long seed) { + return srand(seed); + } +#endif + + /*! selects */ + __forceinline bool select(bool s, bool t , bool f) { return s ? t : f; } + __forceinline int select(bool s, int t, int f) { return s ? t : f; } + __forceinline float select(bool s, float t, float f) { return s ? t : f; } + + __forceinline bool all(bool s) { return s; } + + __forceinline float lerp(const float v0, const float v1, const float t) { + return madd(1.0f-t,v0,t*v1); + } + + template + __forceinline T lerp2(const float x0, const float x1, const float x2, const float x3, const T& u, const T& v) { + return madd((1.0f-u),madd((1.0f-v),T(x0),v*T(x2)),u*madd((1.0f-v),T(x1),v*T(x3))); + } + + /*! exchange */ + template __forceinline void xchg ( T& a, T& b ) { const T tmp = a; a = b; b = tmp; } + + /*! bit reverse operation */ + template + __forceinline T bitReverse(const T& vin) + { + T v = vin; + v = ((v >> 1) & 0x55555555) | ((v & 0x55555555) << 1); + v = ((v >> 2) & 0x33333333) | ((v & 0x33333333) << 2); + v = ((v >> 4) & 0x0F0F0F0F) | ((v & 0x0F0F0F0F) << 4); + v = ((v >> 8) & 0x00FF00FF) | ((v & 0x00FF00FF) << 8); + v = ( v >> 16 ) | ( v << 16); + return v; + } + + /*! bit interleave operation */ + template + __forceinline T bitInterleave(const T& xin, const T& yin, const T& zin) + { + T x = xin, y = yin, z = zin; + x = (x | (x << 16)) & 0x030000FF; + x = (x | (x << 8)) & 0x0300F00F; + x = (x | (x << 4)) & 0x030C30C3; + x = (x | (x << 2)) & 0x09249249; + + y = (y | (y << 16)) & 0x030000FF; + y = (y | (y << 8)) & 0x0300F00F; + y = (y | (y << 4)) & 0x030C30C3; + y = (y | (y << 2)) & 0x09249249; + + z = (z | (z << 16)) & 0x030000FF; + z = (z | (z << 8)) & 0x0300F00F; + z = (z | (z << 4)) & 0x030C30C3; + z = (z | (z << 2)) & 0x09249249; + + return x | (y << 1) | (z << 2); + } + +#if defined(__AVX2__) + + template<> + __forceinline unsigned int bitInterleave(const unsigned int &xi, const unsigned int& yi, const unsigned int& zi) + { + const unsigned int xx = pdep(xi,0x49249249 /* 0b01001001001001001001001001001001 */ ); + const unsigned int yy = pdep(yi,0x92492492 /* 0b10010010010010010010010010010010 */); + const unsigned int zz = pdep(zi,0x24924924 /* 0b00100100100100100100100100100100 */); + return xx | yy | zz; + } + +#endif + + /*! 
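(illustrative note, not part of the upstream Embree sources) */

  // bitInterleave() above spreads three 10-bit coordinates so their bits
  // alternate x,y,z, which is exactly a 30-bit Morton code, the usual building
  // block for spatial sorting. `morton30` is a hypothetical helper shown only
  // as an example.
  __forceinline unsigned int morton30(unsigned int x, unsigned int y, unsigned int z) {
    // inputs are expected to be quantized to 10 bits (0..1023) each
    return bitInterleave(x & 0x3FFu, y & 0x3FFu, z & 0x3FFu);
  }

  /*!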
bit interleave operation for 64bit data types*/ + template + __forceinline T bitInterleave64(const T& xin, const T& yin, const T& zin){ + T x = xin & 0x1fffff; + T y = yin & 0x1fffff; + T z = zin & 0x1fffff; + + x = (x | x << 32) & 0x1f00000000ffff; + x = (x | x << 16) & 0x1f0000ff0000ff; + x = (x | x << 8) & 0x100f00f00f00f00f; + x = (x | x << 4) & 0x10c30c30c30c30c3; + x = (x | x << 2) & 0x1249249249249249; + + y = (y | y << 32) & 0x1f00000000ffff; + y = (y | y << 16) & 0x1f0000ff0000ff; + y = (y | y << 8) & 0x100f00f00f00f00f; + y = (y | y << 4) & 0x10c30c30c30c30c3; + y = (y | y << 2) & 0x1249249249249249; + + z = (z | z << 32) & 0x1f00000000ffff; + z = (z | z << 16) & 0x1f0000ff0000ff; + z = (z | z << 8) & 0x100f00f00f00f00f; + z = (z | z << 4) & 0x10c30c30c30c30c3; + z = (z | z << 2) & 0x1249249249249249; + + return x | (y << 1) | (z << 2); + } +} diff --git a/thirdparty/embree/common/math/obbox.h b/thirdparty/embree/common/math/obbox.h new file mode 100644 index 000000000000..032b56904eef --- /dev/null +++ b/thirdparty/embree/common/math/obbox.h @@ -0,0 +1,39 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bbox.h" +#include "linearspace3.h" + +namespace embree +{ + /*! Oriented bounding box */ + template + struct OBBox + { + public: + + __forceinline OBBox () {} + + __forceinline OBBox (EmptyTy) + : space(one), bounds(empty) {} + + __forceinline OBBox (const BBox& bounds) + : space(one), bounds(bounds) {} + + __forceinline OBBox (const LinearSpace3& space, const BBox& bounds) + : space(space), bounds(bounds) {} + + friend embree_ostream operator<<(embree_ostream cout, const OBBox& p) { + return cout << "{ space = " << p.space << ", bounds = " << p.bounds << "}"; + } + + public: + LinearSpace3 space; //!< orthonormal transformation + BBox bounds; //!< bounds in transformed space + }; + + typedef OBBox OBBox3f; + typedef OBBox OBBox3fa; +} diff --git a/thirdparty/embree/common/math/quaternion.h b/thirdparty/embree/common/math/quaternion.h new file mode 100644 index 000000000000..20c69bc62f45 --- /dev/null +++ b/thirdparty/embree/common/math/quaternion.h @@ -0,0 +1,254 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "vec3.h" +#include "vec4.h" + +#include "transcendental.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////// + // Quaternion Struct + //////////////////////////////////////////////////////////////// + + template + struct QuaternionT + { + typedef Vec3 Vector; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline QuaternionT () { } + __forceinline QuaternionT ( const QuaternionT& other ) { r = other.r; i = other.i; j = other.j; k = other.k; } + __forceinline QuaternionT& operator=( const QuaternionT& other ) { r = other.r; i = other.i; j = other.j; k = other.k; return *this; } + + __forceinline QuaternionT( const T& r ) : r(r), i(zero), j(zero), k(zero) {} + __forceinline explicit QuaternionT( const Vec3& v ) : r(zero), i(v.x), j(v.y), k(v.z) {} + __forceinline explicit QuaternionT( const Vec4& v ) : r(v.x), i(v.y), j(v.z), k(v.w) {} + __forceinline QuaternionT( const T& r, const T& i, const T& j, const T& k ) : r(r), i(i), j(j), k(k) {} + __forceinline QuaternionT( const T& r, const Vec3& v ) : r(r), i(v.x), j(v.y), k(v.z) {} + + __inline QuaternionT( const 
Vec3& vx, const Vec3& vy, const Vec3& vz ); + __inline QuaternionT( const T& yaw, const T& pitch, const T& roll ); + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline QuaternionT( ZeroTy ) : r(zero), i(zero), j(zero), k(zero) {} + __forceinline QuaternionT( OneTy ) : r( one), i(zero), j(zero), k(zero) {} + + /*! return quaternion for rotation around arbitrary axis */ + static __forceinline QuaternionT rotate(const Vec3& u, const T& r) { + return QuaternionT(cos(T(0.5)*r),sin(T(0.5)*r)*normalize(u)); + } + + /*! returns the rotation axis of the quaternion as a vector */ + __forceinline Vec3 v( ) const { return Vec3(i, j, k); } + + public: + T r, i, j, k; + }; + + template __forceinline QuaternionT operator *( const T & a, const QuaternionT& b ) { return QuaternionT(a * b.r, a * b.i, a * b.j, a * b.k); } + template __forceinline QuaternionT operator *( const QuaternionT& a, const T & b ) { return QuaternionT(a.r * b, a.i * b, a.j * b, a.k * b); } + + //////////////////////////////////////////////////////////////// + // Unary Operators + //////////////////////////////////////////////////////////////// + + template __forceinline QuaternionT operator +( const QuaternionT& a ) { return QuaternionT(+a.r, +a.i, +a.j, +a.k); } + template __forceinline QuaternionT operator -( const QuaternionT& a ) { return QuaternionT(-a.r, -a.i, -a.j, -a.k); } + template __forceinline QuaternionT conj ( const QuaternionT& a ) { return QuaternionT(a.r, -a.i, -a.j, -a.k); } + template __forceinline T abs ( const QuaternionT& a ) { return sqrt(a.r*a.r + a.i*a.i + a.j*a.j + a.k*a.k); } + template __forceinline QuaternionT rcp ( const QuaternionT& a ) { return conj(a)*rcp(a.r*a.r + a.i*a.i + a.j*a.j + a.k*a.k); } + template __forceinline QuaternionT normalize ( const QuaternionT& a ) { return a*rsqrt(a.r*a.r + a.i*a.i + a.j*a.j + a.k*a.k); } + + // evaluates a*q-r + template __forceinline QuaternionT + msub(const T& a, const QuaternionT& q, const QuaternionT& p) + { + return QuaternionT(msub(a, q.r, p.r), + msub(a, q.i, p.i), + msub(a, q.j, p.j), + msub(a, q.k, p.k)); + } + // evaluates a*q-r + template __forceinline QuaternionT + madd (const T& a, const QuaternionT& q, const QuaternionT& p) + { + return QuaternionT(madd(a, q.r, p.r), + madd(a, q.i, p.i), + madd(a, q.j, p.j), + madd(a, q.k, p.k)); + } + + //////////////////////////////////////////////////////////////// + // Binary Operators + //////////////////////////////////////////////////////////////// + + template __forceinline QuaternionT operator +( const T & a, const QuaternionT& b ) { return QuaternionT(a + b.r, b.i, b.j, b.k); } + template __forceinline QuaternionT operator +( const QuaternionT& a, const T & b ) { return QuaternionT(a.r + b, a.i, a.j, a.k); } + template __forceinline QuaternionT operator +( const QuaternionT& a, const QuaternionT& b ) { return QuaternionT(a.r + b.r, a.i + b.i, a.j + b.j, a.k + b.k); } + template __forceinline QuaternionT operator -( const T & a, const QuaternionT& b ) { return QuaternionT(a - b.r, -b.i, -b.j, -b.k); } + template __forceinline QuaternionT operator -( const QuaternionT& a, const T & b ) { return QuaternionT(a.r - b, a.i, a.j, a.k); } + template __forceinline QuaternionT operator -( const QuaternionT& a, const QuaternionT& b ) { return QuaternionT(a.r - b.r, a.i - b.i, a.j - b.j, a.k - b.k); } + + template __forceinline Vec3 operator *( const QuaternionT& a, const 
Vec3 & b ) { return (a*QuaternionT(b)*conj(a)).v(); } + template __forceinline QuaternionT operator *( const QuaternionT& a, const QuaternionT& b ) { + return QuaternionT(a.r*b.r - a.i*b.i - a.j*b.j - a.k*b.k, + a.r*b.i + a.i*b.r + a.j*b.k - a.k*b.j, + a.r*b.j - a.i*b.k + a.j*b.r + a.k*b.i, + a.r*b.k + a.i*b.j - a.j*b.i + a.k*b.r); + } + template __forceinline QuaternionT operator /( const T & a, const QuaternionT& b ) { return a*rcp(b); } + template __forceinline QuaternionT operator /( const QuaternionT& a, const T & b ) { return a*rcp(b); } + template __forceinline QuaternionT operator /( const QuaternionT& a, const QuaternionT& b ) { return a*rcp(b); } + + template __forceinline QuaternionT& operator +=( QuaternionT& a, const T & b ) { return a = a+b; } + template __forceinline QuaternionT& operator +=( QuaternionT& a, const QuaternionT& b ) { return a = a+b; } + template __forceinline QuaternionT& operator -=( QuaternionT& a, const T & b ) { return a = a-b; } + template __forceinline QuaternionT& operator -=( QuaternionT& a, const QuaternionT& b ) { return a = a-b; } + template __forceinline QuaternionT& operator *=( QuaternionT& a, const T & b ) { return a = a*b; } + template __forceinline QuaternionT& operator *=( QuaternionT& a, const QuaternionT& b ) { return a = a*b; } + template __forceinline QuaternionT& operator /=( QuaternionT& a, const T & b ) { return a = a*rcp(b); } + template __forceinline QuaternionT& operator /=( QuaternionT& a, const QuaternionT& b ) { return a = a*rcp(b); } + + template __forceinline QuaternionT + select(const M& m, const QuaternionT& q, const QuaternionT& p) + { + return QuaternionT(select(m, q.r, p.r), + select(m, q.i, p.i), + select(m, q.j, p.j), + select(m, q.k, p.k)); + } + + + template __forceinline Vec3 xfmPoint ( const QuaternionT& a, const Vec3& b ) { return (a*QuaternionT(b)*conj(a)).v(); } + template __forceinline Vec3 xfmVector( const QuaternionT& a, const Vec3& b ) { return (a*QuaternionT(b)*conj(a)).v(); } + template __forceinline Vec3 xfmNormal( const QuaternionT& a, const Vec3& b ) { return (a*QuaternionT(b)*conj(a)).v(); } + + template __forceinline T dot(const QuaternionT& a, const QuaternionT& b) { return a.r*b.r + a.i*b.i + a.j*b.j + a.k*b.k; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline bool operator ==( const QuaternionT& a, const QuaternionT& b ) { return a.r == b.r && a.i == b.i && a.j == b.j && a.k == b.k; } + template __forceinline bool operator !=( const QuaternionT& a, const QuaternionT& b ) { return a.r != b.r || a.i != b.i || a.j != b.j || a.k != b.k; } + + + //////////////////////////////////////////////////////////////////////////////// + /// Orientation Functions + //////////////////////////////////////////////////////////////////////////////// + + template QuaternionT::QuaternionT( const Vec3& vx, const Vec3& vy, const Vec3& vz ) + { + if ( vx.x + vy.y + vz.z >= T(zero) ) + { + const T t = T(one) + (vx.x + vy.y + vz.z); + const T s = rsqrt(t)*T(0.5f); + r = t*s; + i = (vy.z - vz.y)*s; + j = (vz.x - vx.z)*s; + k = (vx.y - vy.x)*s; + } + else if ( vx.x >= max(vy.y, vz.z) ) + { + const T t = (T(one) + vx.x) - (vy.y + vz.z); + const T s = rsqrt(t)*T(0.5f); + r = (vy.z - vz.y)*s; + i = t*s; + j = (vx.y + vy.x)*s; + k = (vz.x + vx.z)*s; + } + else if ( vy.y >= vz.z ) // if ( vy.y >= max(vz.z, vx.x) ) + { + const T t = (T(one) + vy.y) - (vz.z + 
vx.x); + const T s = rsqrt(t)*T(0.5f); + r = (vz.x - vx.z)*s; + i = (vx.y + vy.x)*s; + j = t*s; + k = (vy.z + vz.y)*s; + } + else //if ( vz.z >= max(vy.y, vx.x) ) + { + const T t = (T(one) + vz.z) - (vx.x + vy.y); + const T s = rsqrt(t)*T(0.5f); + r = (vx.y - vy.x)*s; + i = (vz.x + vx.z)*s; + j = (vy.z + vz.y)*s; + k = t*s; + } + } + + template QuaternionT::QuaternionT( const T& yaw, const T& pitch, const T& roll ) + { + const T cya = cos(yaw *T(0.5f)); + const T cpi = cos(pitch*T(0.5f)); + const T cro = cos(roll *T(0.5f)); + const T sya = sin(yaw *T(0.5f)); + const T spi = sin(pitch*T(0.5f)); + const T sro = sin(roll *T(0.5f)); + r = cro*cya*cpi + sro*sya*spi; + i = cro*cya*spi + sro*sya*cpi; + j = cro*sya*cpi - sro*cya*spi; + k = sro*cya*cpi - cro*sya*spi; + } + + ////////////////////////////////////////////////////////////////////////////// + /// Output Operators + ////////////////////////////////////////////////////////////////////////////// + + template static embree_ostream operator<<(embree_ostream cout, const QuaternionT& q) { + return cout << "{ r = " << q.r << ", i = " << q.i << ", j = " << q.j << ", k = " << q.k << " }"; + } + + /*! default template instantiations */ + typedef QuaternionT Quaternion3f; + typedef QuaternionT Quaternion3d; + + template using Quaternion3vf = QuaternionT>; + typedef QuaternionT> Quaternion3vf4; + typedef QuaternionT> Quaternion3vf8; + typedef QuaternionT> Quaternion3vf16; + + ////////////////////////////////////////////////////////////////////////////// + /// Interpolation + ////////////////////////////////////////////////////////////////////////////// + template + __forceinline QuaternionTlerp(const QuaternionT& q0, + const QuaternionT& q1, + const T& factor) + { + QuaternionT q; + q.r = lerp(q0.r, q1.r, factor); + q.i = lerp(q0.i, q1.i, factor); + q.j = lerp(q0.j, q1.j, factor); + q.k = lerp(q0.k, q1.k, factor); + return q; + } + + template + __forceinline QuaternionT slerp(const QuaternionT& q0, + const QuaternionT& q1_, + const T& t) + { + T cosTheta = dot(q0, q1_); + QuaternionT q1 = select(cosTheta < 0.f, -q1_, q1_); + cosTheta = select(cosTheta < 0.f, -cosTheta, cosTheta); + if (unlikely(all(cosTheta > 0.9995f))) { + return normalize(lerp(q0, q1, t)); + } + const T phi = t * fastapprox::acos(cosTheta); + T sinPhi, cosPhi; + fastapprox::sincos(phi, sinPhi, cosPhi); + QuaternionT qperp = sinPhi * normalize(msub(cosTheta, q0, q1)); + return msub(cosPhi, q0, qperp); + } +} diff --git a/thirdparty/embree/common/math/range.h b/thirdparty/embree/common/math/range.h new file mode 100644 index 000000000000..762d9cd9eab5 --- /dev/null +++ b/thirdparty/embree/common/math/range.h @@ -0,0 +1,137 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../math/math.h" + +namespace embree +{ + template + struct range + { + __forceinline range() {} + + __forceinline range(const Ty& begin) + : _begin(begin), _end(begin+1) {} + + __forceinline range(const Ty& begin, const Ty& end) + : _begin(begin), _end(end) {} + + __forceinline range(const range& other) + : _begin(other._begin), _end(other._end) {} + + template + __forceinline range(const range& other) + : _begin(Ty(other._begin)), _end(Ty(other._end)) {} + + template + __forceinline range& operator =(const range& other) { + _begin = other._begin; + _end = other._end; + return *this; + } + + __forceinline Ty begin() const { + return _begin; + } + + __forceinline Ty end() const { + return _end; + } + + __forceinline range 
intersect(const range& r) const { + return range (max(_begin,r._begin),min(_end,r._end)); + } + + __forceinline Ty size() const { + return _end - _begin; + } + + __forceinline bool empty() const { + return _end <= _begin; + } + + __forceinline Ty center() const { + return (_begin + _end)/2; + } + + __forceinline std::pair split() const + { + const Ty _center = center(); + return std::make_pair(range(_begin,_center),range(_center,_end)); + } + + __forceinline void split(range& left_o, range& right_o) const + { + const Ty _center = center(); + left_o = range(_begin,_center); + right_o = range(_center,_end); + } + + __forceinline friend bool operator< (const range& r0, const range& r1) { + return r0.size() < r1.size(); + } + + friend embree_ostream operator<<(embree_ostream cout, const range& r) { + return cout << "range [" << r.begin() << ", " << r.end() << "]"; + } + + Ty _begin, _end; + }; + + template + range make_range(const Ty& begin, const Ty& end) { + return range(begin,end); + } + + template + struct extended_range : public range + { + __forceinline extended_range () {} + + __forceinline extended_range (const Ty& begin) + : range(begin), _ext_end(begin+1) {} + + __forceinline extended_range (const Ty& begin, const Ty& end) + : range(begin,end), _ext_end(end) {} + + __forceinline extended_range (const Ty& begin, const Ty& end, const Ty& ext_end) + : range(begin,end), _ext_end(ext_end) {} + + __forceinline Ty ext_end() const { + return _ext_end; + } + + __forceinline Ty ext_size() const { + return _ext_end - range::_begin; + } + + __forceinline Ty ext_range_size() const { + return _ext_end - range::_end; + } + + __forceinline bool has_ext_range() const { + assert(_ext_end >= range::_end); + return (_ext_end - range::_end) > 0; + } + + __forceinline void set_ext_range(const size_t ext_end){ + assert(ext_end >= range::_end); + _ext_end = ext_end; + } + + __forceinline void move_right(const size_t plus){ + range::_begin += plus; + range::_end += plus; + _ext_end += plus; + } + + friend embree_ostream operator<<(embree_ostream cout, const extended_range& r) { + return cout << "extended_range [" << r.begin() << ", " << r.end() << " (" << r.ext_end() << ")]"; + } + + Ty _ext_end; + }; +} diff --git a/thirdparty/embree/common/math/transcendental.h b/thirdparty/embree/common/math/transcendental.h new file mode 100644 index 000000000000..6855d82b5341 --- /dev/null +++ b/thirdparty/embree/common/math/transcendental.h @@ -0,0 +1,525 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +// Transcendental functions from "ispc": https://github.com/ispc/ispc/ +// Most of the transcendental implementations in ispc code come from +// Solomon Boulos's "syrah": https://github.com/boulos/syrah/ + +#include "../simd/simd.h" + +namespace embree +{ + +namespace fastapprox +{ + +template +__forceinline T sin(const T &v) +{ + static const float piOverTwoVec = 1.57079637050628662109375; + static const float twoOverPiVec = 0.636619746685028076171875; + auto scaled = v * twoOverPiVec; + auto kReal = floor(scaled); + auto k = toInt(kReal); + + // Reduced range version of x + auto x = v - kReal * piOverTwoVec; + auto kMod4 = k & 3; + auto sinUseCos = (kMod4 == 1 | kMod4 == 3); + auto flipSign = (kMod4 > 1); + + // These coefficients are from sollya with fpminimax(sin(x)/x, [|0, 2, + // 4, 6, 8, 10|], [|single...|], [0;Pi/2]); + static const float sinC2 = -0.16666667163372039794921875; + static const float sinC4 = +8.333347737789154052734375e-3; + static const float 
sinC6 = -1.9842604524455964565277099609375e-4; + static const float sinC8 = +2.760012648650445044040679931640625e-6; + static const float sinC10 = -2.50293279435709337121807038784027099609375e-8; + + static const float cosC2 = -0.5; + static const float cosC4 = +4.166664183139801025390625e-2; + static const float cosC6 = -1.388833043165504932403564453125e-3; + static const float cosC8 = +2.47562347794882953166961669921875e-5; + static const float cosC10 = -2.59630184018533327616751194000244140625e-7; + + auto outside = select(sinUseCos, 1., x); + auto c2 = select(sinUseCos, T(cosC2), T(sinC2)); + auto c4 = select(sinUseCos, T(cosC4), T(sinC4)); + auto c6 = select(sinUseCos, T(cosC6), T(sinC6)); + auto c8 = select(sinUseCos, T(cosC8), T(sinC8)); + auto c10 = select(sinUseCos, T(cosC10), T(sinC10)); + + auto x2 = x * x; + auto formula = x2 * c10 + c8; + formula = x2 * formula + c6; + formula = x2 * formula + c4; + formula = x2 * formula + c2; + formula = x2 * formula + 1.; + formula *= outside; + + formula = select(flipSign, -formula, formula); + return formula; +} + +template +__forceinline T cos(const T &v) +{ + static const float piOverTwoVec = 1.57079637050628662109375; + static const float twoOverPiVec = 0.636619746685028076171875; + auto scaled = v * twoOverPiVec; + auto kReal = floor(scaled); + auto k = toInt(kReal); + + // Reduced range version of x + auto x = v - kReal * piOverTwoVec; + + auto kMod4 = k & 3; + auto cosUseCos = (kMod4 == 0 | kMod4 == 2); + auto flipSign = (kMod4 == 1 | kMod4 == 2); + + const float sinC2 = -0.16666667163372039794921875; + const float sinC4 = +8.333347737789154052734375e-3; + const float sinC6 = -1.9842604524455964565277099609375e-4; + const float sinC8 = +2.760012648650445044040679931640625e-6; + const float sinC10 = -2.50293279435709337121807038784027099609375e-8; + + const float cosC2 = -0.5; + const float cosC4 = +4.166664183139801025390625e-2; + const float cosC6 = -1.388833043165504932403564453125e-3; + const float cosC8 = +2.47562347794882953166961669921875e-5; + const float cosC10 = -2.59630184018533327616751194000244140625e-7; + + auto outside = select(cosUseCos, 1., x); + auto c2 = select(cosUseCos, T(cosC2), T(sinC2)); + auto c4 = select(cosUseCos, T(cosC4), T(sinC4)); + auto c6 = select(cosUseCos, T(cosC6), T(sinC6)); + auto c8 = select(cosUseCos, T(cosC8), T(sinC8)); + auto c10 = select(cosUseCos, T(cosC10), T(sinC10)); + + auto x2 = x * x; + auto formula = x2 * c10 + c8; + formula = x2 * formula + c6; + formula = x2 * formula + c4; + formula = x2 * formula + c2; + formula = x2 * formula + 1.; + formula *= outside; + + formula = select(flipSign, -formula, formula); + return formula; +} + +template +__forceinline void sincos(const T &v, T &sinResult, T &cosResult) +{ + const float piOverTwoVec = 1.57079637050628662109375; + const float twoOverPiVec = 0.636619746685028076171875; + auto scaled = v * twoOverPiVec; + auto kReal = floor(scaled); + auto k = toInt(kReal); + + // Reduced range version of x + auto x = v - kReal * piOverTwoVec; + auto kMod4 = k & 3; + auto cosUseCos = ((kMod4 == 0) | (kMod4 == 2)); + auto sinUseCos = ((kMod4 == 1) | (kMod4 == 3)); + auto sinFlipSign = (kMod4 > 1); + auto cosFlipSign = ((kMod4 == 1) | (kMod4 == 2)); + + const float oneVec = +1.; + const float sinC2 = -0.16666667163372039794921875; + const float sinC4 = +8.333347737789154052734375e-3; + const float sinC6 = -1.9842604524455964565277099609375e-4; + const float sinC8 = +2.760012648650445044040679931640625e-6; + const float sinC10 = 
-2.50293279435709337121807038784027099609375e-8; + + const float cosC2 = -0.5; + const float cosC4 = +4.166664183139801025390625e-2; + const float cosC6 = -1.388833043165504932403564453125e-3; + const float cosC8 = +2.47562347794882953166961669921875e-5; + const float cosC10 = -2.59630184018533327616751194000244140625e-7; + + auto x2 = x * x; + + auto sinFormula = x2 * sinC10 + sinC8; + auto cosFormula = x2 * cosC10 + cosC8; + sinFormula = x2 * sinFormula + sinC6; + cosFormula = x2 * cosFormula + cosC6; + + sinFormula = x2 * sinFormula + sinC4; + cosFormula = x2 * cosFormula + cosC4; + + sinFormula = x2 * sinFormula + sinC2; + cosFormula = x2 * cosFormula + cosC2; + + sinFormula = x2 * sinFormula + oneVec; + cosFormula = x2 * cosFormula + oneVec; + + sinFormula *= x; + + sinResult = select(sinUseCos, cosFormula, sinFormula); + cosResult = select(cosUseCos, cosFormula, sinFormula); + + sinResult = select(sinFlipSign, -sinResult, sinResult); + cosResult = select(cosFlipSign, -cosResult, cosResult); +} + +template +__forceinline T tan(const T &v) +{ + const float piOverFourVec = 0.785398185253143310546875; + const float fourOverPiVec = 1.27323949337005615234375; + + auto xLt0 = v < 0.; + auto y = select(xLt0, -v, v); + auto scaled = y * fourOverPiVec; + + auto kReal = floor(scaled); + auto k = toInt(kReal); + + auto x = y - kReal * piOverFourVec; + + // If k & 1, x -= Pi/4 + auto needOffset = (k & 1) != 0; + x = select(needOffset, x - piOverFourVec, x); + + // If k & 3 == (0 or 3) let z = tan_In...(y) otherwise z = -cot_In0To... + auto kMod4 = k & 3; + auto useCotan = (kMod4 == 1) | (kMod4 == 2); + + const float oneVec = 1.0; + + const float tanC2 = +0.33333075046539306640625; + const float tanC4 = +0.13339905440807342529296875; + const float tanC6 = +5.3348250687122344970703125e-2; + const float tanC8 = +2.46033705770969390869140625e-2; + const float tanC10 = +2.892402000725269317626953125e-3; + const float tanC12 = +9.500005282461643218994140625e-3; + + const float cotC2 = -0.3333333432674407958984375; + const float cotC4 = -2.222204394638538360595703125e-2; + const float cotC6 = -2.11752182804048061370849609375e-3; + const float cotC8 = -2.0846328698098659515380859375e-4; + const float cotC10 = -2.548247357481159269809722900390625e-5; + const float cotC12 = -3.5257363606433500535786151885986328125e-7; + + auto x2 = x * x; + T z; + if (any(useCotan)) + { + auto cotVal = x2 * cotC12 + cotC10; + cotVal = x2 * cotVal + cotC8; + cotVal = x2 * cotVal + cotC6; + cotVal = x2 * cotVal + cotC4; + cotVal = x2 * cotVal + cotC2; + cotVal = x2 * cotVal + oneVec; + // The equation is for x * cot(x) but we need -x * cot(x) for the tan part. 
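+ // (Clarification: dividing the even polynomial, which approximates x*cot(x),
+ // by -x below yields -cot(x); since tan(theta) == -cot(theta - pi/2), this is
+ // exactly tan of the original argument in the quadrants where (k & 3) is 1 or 2
+ // after the pi/4 range reduction above.)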
+ cotVal /= -x; + z = cotVal; + } + auto useTan = !useCotan; + if (any(useTan)) + { + auto tanVal = x2 * tanC12 + tanC10; + tanVal = x2 * tanVal + tanC8; + tanVal = x2 * tanVal + tanC6; + tanVal = x2 * tanVal + tanC4; + tanVal = x2 * tanVal + tanC2; + tanVal = x2 * tanVal + oneVec; + // Equation was for tan(x)/x + tanVal *= x; + z = select(useTan, tanVal, z); + } + return select(xLt0, -z, z); +} + +template +__forceinline T asin(const T &x0) +{ + auto isneg = (x0 < 0.f); + auto x = abs(x0); + auto isnan = (x > 1.f); + + // sollya + // fpminimax(((asin(x)-pi/2)/-sqrt(1-x)), [|0,1,2,3,4,5|],[|single...|], + // [1e-20;.9999999999999999]); + // avg error: 1.1105439e-06, max error 1.3187528e-06 + auto v = 1.57079517841339111328125f + + x * (-0.21450997889041900634765625f + + x * (8.78556668758392333984375e-2f + + x * (-4.489909112453460693359375e-2f + + x * (1.928029954433441162109375e-2f + + x * (-4.3095736764371395111083984375e-3f))))); + + v *= -sqrt(1.f - x); + v = v + 1.57079637050628662109375f; + + v = select(v < 0.f, T(0.f), v); + v = select(isneg, -v, v); + v = select(isnan, T(cast_i2f(0x7fc00000)), v); + + return v; +} + +template +__forceinline T acos(const T &v) +{ + return 1.57079637050628662109375f - asin(v); +} + +template +__forceinline T atan(const T &v) +{ + const float piOverTwoVec = 1.57079637050628662109375; + // atan(-x) = -atan(x) (so flip from negative to positive first) + // If x > 1 -> atan(x) = Pi/2 - atan(1/x) + auto xNeg = v < 0.f; + auto xFlipped = select(xNeg, -v, v); + + auto xGt1 = xFlipped > 1.; + auto x = select(xGt1, rcpSafe(xFlipped), xFlipped); + + // These coefficients approximate atan(x)/x + const float atanC0 = +0.99999988079071044921875; + const float atanC2 = -0.3333191573619842529296875; + const float atanC4 = +0.199689209461212158203125; + const float atanC6 = -0.14015688002109527587890625; + const float atanC8 = +9.905083477497100830078125e-2; + const float atanC10 = -5.93664981424808502197265625e-2; + const float atanC12 = +2.417283318936824798583984375e-2; + const float atanC14 = -4.6721356920897960662841796875e-3; + + auto x2 = x * x; + auto result = x2 * atanC14 + atanC12; + result = x2 * result + atanC10; + result = x2 * result + atanC8; + result = x2 * result + atanC6; + result = x2 * result + atanC4; + result = x2 * result + atanC2; + result = x2 * result + atanC0; + result *= x; + + result = select(xGt1, piOverTwoVec - result, result); + result = select(xNeg, -result, result); + return result; +} + +template +__forceinline T atan2(const T &y, const T &x) +{ + const float piVec = 3.1415926536; + // atan2(y, x) = + // + // atan2(y > 0, x = +-0) -> Pi/2 + // atan2(y < 0, x = +-0) -> -Pi/2 + // atan2(y = +-0, x < +0) -> +-Pi + // atan2(y = +-0, x >= +0) -> +-0 + // + // atan2(y >= 0, x < 0) -> Pi + atan(y/x) + // atan2(y < 0, x < 0) -> -Pi + atan(y/x) + // atan2(y, x > 0) -> atan(y/x) + // + // and then a bunch of code for dealing with infinities. 
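+ // The code below realizes this table for finite inputs: atan(y/x) is evaluated
+ // through a safe reciprocal (so x == 0 does not divide by zero), then offset by
+ // +pi (y >= 0) or -pi (y < 0) whenever x < 0 to select the correct half-plane.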
+ auto yOverX = y * rcpSafe(x); + auto atanArg = atan(yOverX); + auto xLt0 = x < 0.f; + auto yLt0 = y < 0.f; + auto offset = select(xLt0, + select(yLt0, T(-piVec), T(piVec)), 0.f); + return offset + atanArg; +} + +template +__forceinline T exp(const T &v) +{ + const float ln2Part1 = 0.6931457519; + const float ln2Part2 = 1.4286067653e-6; + const float oneOverLn2 = 1.44269502162933349609375; + + auto scaled = v * oneOverLn2; + auto kReal = floor(scaled); + auto k = toInt(kReal); + + // Reduced range version of x + auto x = v - kReal * ln2Part1; + x -= kReal * ln2Part2; + + // These coefficients are for e^x in [0, ln(2)] + const float one = 1.; + const float c2 = 0.4999999105930328369140625; + const float c3 = 0.166668415069580078125; + const float c4 = 4.16539050638675689697265625e-2; + const float c5 = 8.378830738365650177001953125e-3; + const float c6 = 1.304379315115511417388916015625e-3; + const float c7 = 2.7555381529964506626129150390625e-4; + + auto result = x * c7 + c6; + result = x * result + c5; + result = x * result + c4; + result = x * result + c3; + result = x * result + c2; + result = x * result + one; + result = x * result + one; + + // Compute 2^k (should differ for float and double, but I'll avoid + // it for now and just do floats) + const int fpbias = 127; + auto biasedN = k + fpbias; + auto overflow = kReal > fpbias; + // Minimum exponent is -126, so if k is <= -127 (k + 127 <= 0) + // we've got underflow. -127 * ln(2) -> -88.02. So the most + // negative float input that doesn't result in zero is like -88. + auto underflow = kReal <= -fpbias; + const int infBits = 0x7f800000; + biasedN <<= 23; + // Reinterpret this thing as float + auto twoToTheN = asFloat(biasedN); + // Handle both doubles and floats (hopefully eliding the copy for float) + auto elemtype2n = twoToTheN; + result *= elemtype2n; + result = select(overflow, cast_i2f(infBits), result); + result = select(underflow, 0., result); + return result; +} + +// Range reduction for logarithms takes log(x) -> log(2^n * y) -> n +// * log(2) + log(y) where y is the reduced range (usually in [1/2, 1)). +template +__forceinline void __rangeReduceLog(const T &input, + T &reduced, + R &exponent) +{ + auto intVersion = asInt(input); + // single precision = SEEE EEEE EMMM MMMM MMMM MMMM MMMM MMMM + // exponent mask = 0111 1111 1000 0000 0000 0000 0000 0000 + // 0x7 0xF 0x8 0x0 0x0 0x0 0x0 0x0 + // non-exponent = 1000 0000 0111 1111 1111 1111 1111 1111 + // = 0x8 0x0 0x7 0xF 0xF 0xF 0xF 0xF + + //const int exponentMask(0x7F800000) + static const int nonexponentMask = 0x807FFFFF; + + // We want the reduced version to have an exponent of -1 which is + // -1 + 127 after biasing or 126 + static const int exponentNeg1 = (126l << 23); + // NOTE(boulos): We don't need to mask anything out since we know + // the sign bit has to be 0. If it's 1, we need to return infinity/nan + // anyway (log(x), x = +-0 -> infinity, x < 0 -> NaN). 
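+ // The lines below rewrite input = reduced * 2^exponent with reduced in [0.5, 1):
+ // the biased exponent field is shifted out and incremented (so exponent = e + 1),
+ // and the mantissa bits are re-tagged with a stored exponent of 126, i.e. 2^-1.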
+ auto biasedExponent = intVersion >> 23; // This number is [0, 255] but it means [-127, 128] + + auto offsetExponent = biasedExponent + 1; // Treat the number as if it were 2^{e+1} * (1.m)/2 + exponent = offsetExponent - 127; // get the real value + + // Blend the offset_exponent with the original input (do this in + // int for now, until I decide if float can have & and ¬) + auto blended = (intVersion & nonexponentMask) | (exponentNeg1); + reduced = asFloat(blended); +} + +template struct ExponentType { }; +template struct ExponentType> { typedef vint Ty; }; +template <> struct ExponentType { typedef int Ty; }; + +template +__forceinline T log(const T &v) +{ + T reduced; + typename ExponentType::Ty exponent; + + const int nanBits = 0x7fc00000; + const int negInfBits = 0xFF800000; + const float nan = cast_i2f(nanBits); + const float negInf = cast_i2f(negInfBits); + auto useNan = v < 0.; + auto useInf = v == 0.; + auto exceptional = useNan | useInf; + const float one = 1.0; + + auto patched = select(exceptional, one, v); + __rangeReduceLog(patched, reduced, exponent); + + const float ln2 = 0.693147182464599609375; + + auto x1 = one - reduced; + const float c1 = +0.50000095367431640625; + const float c2 = +0.33326041698455810546875; + const float c3 = +0.2519190013408660888671875; + const float c4 = +0.17541764676570892333984375; + const float c5 = +0.3424419462680816650390625; + const float c6 = -0.599632322788238525390625; + const float c7 = +1.98442304134368896484375; + const float c8 = -2.4899270534515380859375; + const float c9 = +1.7491014003753662109375; + + auto result = x1 * c9 + c8; + result = x1 * result + c7; + result = x1 * result + c6; + result = x1 * result + c5; + result = x1 * result + c4; + result = x1 * result + c3; + result = x1 * result + c2; + result = x1 * result + c1; + result = x1 * result + one; + + // Equation was for -(ln(red)/(1-red)) + result *= -x1; + result += toFloat(exponent) * ln2; + + return select(exceptional, + select(useNan, T(nan), T(negInf)), + result); +} + +template +__forceinline T pow(const T &x, const T &y) +{ + auto x1 = abs(x); + auto z = exp(y * log(x1)); + + // Handle special cases + const float twoOver23 = 8388608.0f; + auto yInt = y == round(y); + auto yOddInt = select(yInt, asInt(abs(y) + twoOver23) << 31, 0); // set sign bit + + // x == 0 + z = select(x == 0.0f, + select(y < 0.0f, T(inf) | signmsk(x), + select(y == 0.0f, T(1.0f), asFloat(yOddInt) & x)), z); + + // x < 0 + auto xNegative = x < 0.0f; + if (any(xNegative)) + { + auto z1 = z | asFloat(yOddInt); + z1 = select(yInt, z1, std::numeric_limits::quiet_NaN()); + z = select(xNegative, z1, z); + } + + auto xFinite = isfinite(x); + auto yFinite = isfinite(y); + if (all(xFinite & yFinite)) + return z; + + // x finite and y infinite + z = select(andn(xFinite, yFinite), + select(x1 == 1.0f, 1.0f, + select((x1 > 1.0f) ^ (y < 0.0f), inf, T(0.0f))), z); + + // x infinite + z = select(xFinite, z, + select(y == 0.0f, 1.0f, + select(y < 0.0f, T(0.0f), inf) | (asFloat(yOddInt) & x))); + + return z; +} + +template +__forceinline T pow(const T &x, float y) +{ + return pow(x, T(y)); +} + +} // namespace fastapprox + +} // namespace embree diff --git a/thirdparty/embree/common/math/vec2.h b/thirdparty/embree/common/math/vec2.h new file mode 100644 index 000000000000..0ecf8c6384d5 --- /dev/null +++ b/thirdparty/embree/common/math/vec2.h @@ -0,0 +1,235 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "math.h" + +namespace embree +{ + 
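+ // Vec2<T> below is the generic two-component vector. It is instantiated with
+ // plain scalars (Vec2f, Vec2i, Vec2b) as well as with the SIMD types pulled in
+ // by the SSE/AVX includes near the bottom of this header. Vec2fa, forward-declared
+ // next, is the 16-byte SSE-backed variant defined in vec2fa.h; the converting
+ // constructors from Vec2fa are specialized at the end of this file.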
struct Vec2fa; + + //////////////////////////////////////////////////////////////////////////////// + /// Generic 2D vector Class + //////////////////////////////////////////////////////////////////////////////// + + template struct Vec2 + { + enum { N = 2 }; + union { + struct { T x, y; }; +#if !(defined(__WIN32__) && _MSC_VER == 1800) // workaround for older VS 2013 compiler + T components[N]; +#endif + }; + + typedef T Scalar; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2( ) {} + __forceinline explicit Vec2( const T& a ) : x(a), y(a) {} + __forceinline Vec2( const T& x, const T& y ) : x(x), y(y) {} + + __forceinline Vec2( const Vec2& other ) { x = other.x; y = other.y; } + __forceinline Vec2( const Vec2fa& other ); + + template __forceinline Vec2( const Vec2& a ) : x(T(a.x)), y(T(a.y)) {} + template __forceinline Vec2& operator =( const Vec2& other ) { x = other.x; y = other.y; return *this; } + + __forceinline Vec2& operator =( const Vec2& other ) { x = other.x; y = other.y; return *this; } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2( ZeroTy ) : x(zero), y(zero) {} + __forceinline Vec2( OneTy ) : x(one), y(one) {} + __forceinline Vec2( PosInfTy ) : x(pos_inf), y(pos_inf) {} + __forceinline Vec2( NegInfTy ) : x(neg_inf), y(neg_inf) {} + +#if defined(__WIN32__) && _MSC_VER == 1800 // workaround for older VS 2013 compiler + __forceinline const T& operator [](const size_t axis) const { assert(axis < 2); return (&x)[axis]; } + __forceinline T& operator [](const size_t axis) { assert(axis < 2); return (&x)[axis]; } +#else + __forceinline const T& operator [](const size_t axis) const { assert(axis < 2); return components[axis]; } + __forceinline T& operator [](const size_t axis ) { assert(axis < 2); return components[axis]; } +#endif + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec2 operator +( const Vec2& a ) { return Vec2(+a.x, +a.y); } + template __forceinline Vec2 operator -( const Vec2& a ) { return Vec2(-a.x, -a.y); } + template __forceinline Vec2 abs ( const Vec2& a ) { return Vec2(abs (a.x), abs (a.y)); } + template __forceinline Vec2 rcp ( const Vec2& a ) { return Vec2(rcp (a.x), rcp (a.y)); } + template __forceinline Vec2 rsqrt ( const Vec2& a ) { return Vec2(rsqrt(a.x), rsqrt(a.y)); } + template __forceinline Vec2 sqrt ( const Vec2& a ) { return Vec2(sqrt (a.x), sqrt (a.y)); } + template __forceinline Vec2 frac ( const Vec2& a ) { return Vec2(frac (a.x), frac (a.y)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec2 operator +( const Vec2& a, const Vec2& b ) { return Vec2(a.x + b.x, a.y + b.y); } + template __forceinline Vec2 operator +( const Vec2& a, const T& b ) { return Vec2(a.x + b , a.y + b ); } + template __forceinline Vec2 operator +( const T& a, const Vec2& b ) { return Vec2(a + b.x, a + b.y); } + template __forceinline Vec2 operator -( const Vec2& a, const Vec2& b ) { return Vec2(a.x - b.x, a.y - b.y); } + 
template __forceinline Vec2 operator -( const Vec2& a, const T& b ) { return Vec2(a.x - b , a.y - b ); } + template __forceinline Vec2 operator -( const T& a, const Vec2& b ) { return Vec2(a - b.x, a - b.y); } + template __forceinline Vec2 operator *( const Vec2& a, const Vec2& b ) { return Vec2(a.x * b.x, a.y * b.y); } + template __forceinline Vec2 operator *( const T& a, const Vec2& b ) { return Vec2(a * b.x, a * b.y); } + template __forceinline Vec2 operator *( const Vec2& a, const T& b ) { return Vec2(a.x * b , a.y * b ); } + template __forceinline Vec2 operator /( const Vec2& a, const Vec2& b ) { return Vec2(a.x / b.x, a.y / b.y); } + template __forceinline Vec2 operator /( const Vec2& a, const T& b ) { return Vec2(a.x / b , a.y / b ); } + template __forceinline Vec2 operator /( const T& a, const Vec2& b ) { return Vec2(a / b.x, a / b.y); } + + template __forceinline Vec2 min(const Vec2& a, const Vec2& b) { return Vec2(min(a.x, b.x), min(a.y, b.y)); } + template __forceinline Vec2 max(const Vec2& a, const Vec2& b) { return Vec2(max(a.x, b.x), max(a.y, b.y)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec2 madd ( const Vec2& a, const Vec2& b, const Vec2& c) { return Vec2( madd(a.x,b.x,c.x), madd(a.y,b.y,c.y) ); } + template __forceinline Vec2 msub ( const Vec2& a, const Vec2& b, const Vec2& c) { return Vec2( msub(a.x,b.x,c.x), msub(a.y,b.y,c.y) ); } + template __forceinline Vec2 nmadd ( const Vec2& a, const Vec2& b, const Vec2& c) { return Vec2(nmadd(a.x,b.x,c.x),nmadd(a.y,b.y,c.y) ); } + template __forceinline Vec2 nmsub ( const Vec2& a, const Vec2& b, const Vec2& c) { return Vec2(nmsub(a.x,b.x,c.x),nmsub(a.y,b.y,c.y) ); } + + template __forceinline Vec2 madd ( const T& a, const Vec2& b, const Vec2& c) { return Vec2( madd(a,b.x,c.x), madd(a,b.y,c.y) ); } + template __forceinline Vec2 msub ( const T& a, const Vec2& b, const Vec2& c) { return Vec2( msub(a,b.x,c.x), msub(a,b.y,c.y) ); } + template __forceinline Vec2 nmadd ( const T& a, const Vec2& b, const Vec2& c) { return Vec2(nmadd(a,b.x,c.x),nmadd(a,b.y,c.y) ); } + template __forceinline Vec2 nmsub ( const T& a, const Vec2& b, const Vec2& c) { return Vec2(nmsub(a,b.x,c.x),nmsub(a,b.y,c.y) ); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec2& operator +=( Vec2& a, const Vec2& b ) { a.x += b.x; a.y += b.y; return a; } + template __forceinline Vec2& operator -=( Vec2& a, const Vec2& b ) { a.x -= b.x; a.y -= b.y; return a; } + template __forceinline Vec2& operator *=( Vec2& a, const T& b ) { a.x *= b ; a.y *= b ; return a; } + template __forceinline Vec2& operator /=( Vec2& a, const T& b ) { a.x /= b ; a.y /= b ; return a; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline T reduce_add( const Vec2& a ) { return a.x + a.y; } + template __forceinline T reduce_mul( const Vec2& a ) { return a.x * a.y; } + template __forceinline T reduce_min( const Vec2& a ) { return min(a.x, a.y); } + template __forceinline T reduce_max( const Vec2& a ) { return max(a.x, a.y); } + + 
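+ /* Usage sketch (illustrative only, not part of the upstream header): the generic
+ helpers above compose with any scalar or SIMD element type T, e.g.
+
+ Vec2f a(1.0f, 2.0f), b(3.0f, 5.0f);
+ Vec2f s = madd(2.0f, a, b); // fused multiply-add per component: (5, 9)
+ float m = reduce_max(s); // horizontal max: 9
+ Vec2f c = min(a, b) * 0.5f; // component-wise min, then scale: (0.5, 1)
+ */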
//////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline bool operator ==( const Vec2& a, const Vec2& b ) { return a.x == b.x && a.y == b.y; } + template __forceinline bool operator !=( const Vec2& a, const Vec2& b ) { return a.x != b.x || a.y != b.y; } + template __forceinline bool operator < ( const Vec2& a, const Vec2& b ) { + if (a.x != b.x) return a.x < b.x; + if (a.y != b.y) return a.y < b.y; + return false; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Shift Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec2 shift_right_1( const Vec2& a ) { + return Vec2(shift_right_1(a.x),shift_right_1(a.y)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline T dot ( const Vec2& a, const Vec2& b ) { return madd(a.x,b.x,a.y*b.y); } + template __forceinline Vec2 cross ( const Vec2& a ) { return Vec2(-a.y,a.x); } + template __forceinline T length ( const Vec2& a ) { return sqrt(dot(a,a)); } + template __forceinline Vec2 normalize( const Vec2& a ) { return a*rsqrt(dot(a,a)); } + template __forceinline T distance ( const Vec2& a, const Vec2& b ) { return length(a-b); } + template __forceinline T det ( const Vec2& a, const Vec2& b ) { return a.x*b.y - a.y*b.x; } + + template __forceinline Vec2 normalize_safe( const Vec2& a ) { + const T d = dot(a,a); return select(d == T( zero ),a, a*rsqrt(d) ); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec2 select ( bool s, const Vec2& t, const Vec2& f ) { + return Vec2(select(s,t.x,f.x),select(s,t.y,f.y)); + } + + template __forceinline Vec2 select ( const Vec2& s, const Vec2& t, const Vec2& f ) { + return Vec2(select(s.x,t.x,f.x),select(s.y,t.y,f.y)); + } + + template __forceinline Vec2 select ( const typename T::Bool& s, const Vec2& t, const Vec2& f ) { + return Vec2(select(s,t.x,f.x),select(s,t.y,f.y)); + } + + template + __forceinline Vec2 lerp(const Vec2& v0, const Vec2& v1, const T& t) { + return madd(Vec2(T(1.0f)-t),v0,t*v1); + } + + template __forceinline int maxDim ( const Vec2& a ) + { + const Vec2 b = abs(a); + if (b.x > b.y) return 0; + else return 1; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline embree_ostream operator<<(embree_ostream cout, const Vec2& a) { + return cout << "(" << a.x << ", " << a.y << ")"; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Default template instantiations + //////////////////////////////////////////////////////////////////////////////// + + typedef Vec2 Vec2b; + typedef Vec2 Vec2i; + typedef Vec2 Vec2f; +} + +#include "vec2fa.h" + +#if defined __SSE__ +#include "../simd/sse.h" +#endif + +#if defined __AVX__ +#include "../simd/avx.h" +#endif + +#if defined(__AVX512F__) +#include "../simd/avx512.h" +#endif + +namespace embree +{ + template<> __forceinline Vec2::Vec2(const Vec2fa& a) : 
x(a.x), y(a.y) {} + +#if defined(__SSE__) + template<> __forceinline Vec2::Vec2(const Vec2fa& a) : x(a.x), y(a.y) {} +#endif + +#if defined(__AVX__) + template<> __forceinline Vec2::Vec2(const Vec2fa& a) : x(a.x), y(a.y) {} +#endif + +#if defined(__AVX512F__) + template<> __forceinline Vec2::Vec2(const Vec2fa& a) : x(a.x), y(a.y) {} +#endif +} diff --git a/thirdparty/embree/common/math/vec2fa.h b/thirdparty/embree/common/math/vec2fa.h new file mode 100644 index 000000000000..6b1b6f33f237 --- /dev/null +++ b/thirdparty/embree/common/math/vec2fa.h @@ -0,0 +1,297 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/alloc.h" +#include "math.h" +#include "../simd/sse.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// SSE Vec2fa Type + //////////////////////////////////////////////////////////////////////////////// + + struct __aligned(16) Vec2fa + { + ALIGNED_STRUCT_(16); + + typedef float Scalar; + enum { N = 2 }; + union { + __m128 m128; + struct { float x,y,az,aw; }; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2fa( ) {} + __forceinline Vec2fa( const __m128 a ) : m128(a) {} + + __forceinline Vec2fa ( const Vec2& other ) { x = other.x; y = other.y; } + __forceinline Vec2fa& operator =( const Vec2& other ) { x = other.x; y = other.y; return *this; } + + __forceinline Vec2fa ( const Vec2fa& other ) { m128 = other.m128; } + __forceinline Vec2fa& operator =( const Vec2fa& other ) { m128 = other.m128; return *this; } + + __forceinline explicit Vec2fa( const float a ) : m128(_mm_set1_ps(a)) {} + __forceinline Vec2fa( const float x, const float y) : m128(_mm_set_ps(y, y, y, x)) {} + + __forceinline explicit Vec2fa( const __m128i a ) : m128(_mm_cvtepi32_ps(a)) {} + + __forceinline operator const __m128&() const { return m128; } + __forceinline operator __m128&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline Vec2fa load( const void* const a ) { + return Vec2fa(_mm_and_ps(_mm_load_ps((float*)a),_mm_castsi128_ps(_mm_set_epi32(0, 0, -1, -1)))); + } + + static __forceinline Vec2fa loadu( const void* const a ) { + return Vec2fa(_mm_and_ps(_mm_loadu_ps((float*)a),_mm_castsi128_ps(_mm_set_epi32(0, 0, -1, -1)))); + } + + static __forceinline void storeu ( void* ptr, const Vec2fa& v ) { + _mm_storeu_ps((float*)ptr,v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2fa( ZeroTy ) : m128(_mm_setzero_ps()) {} + __forceinline Vec2fa( OneTy ) : m128(_mm_set1_ps(1.0f)) {} + __forceinline Vec2fa( PosInfTy ) : m128(_mm_set1_ps(pos_inf)) {} + __forceinline Vec2fa( NegInfTy ) : m128(_mm_set1_ps(neg_inf)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const float& operator []( const size_t index ) const { assert(index < 2); return (&x)[index]; } + __forceinline float& operator []( const size_t index 
) { assert(index < 2); return (&x)[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2fa operator +( const Vec2fa& a ) { return a; } + __forceinline Vec2fa operator -( const Vec2fa& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000)); + return _mm_xor_ps(a.m128, mask); + } + __forceinline Vec2fa abs ( const Vec2fa& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)); + return _mm_and_ps(a.m128, mask); + } + __forceinline Vec2fa sign ( const Vec2fa& a ) { + return blendv_ps(Vec2fa(one), -Vec2fa(one), _mm_cmplt_ps (a,Vec2fa(zero))); + } + + __forceinline Vec2fa rcp ( const Vec2fa& a ) + { +#if defined(__AVX512VL__) + const Vec2fa r = _mm_rcp14_ps(a.m128); +#else + const Vec2fa r = _mm_rcp_ps(a.m128); +#endif + +#if defined(__AVX2__) + const Vec2fa res = _mm_mul_ps(r,_mm_fnmadd_ps(r, a, vfloat4(2.0f))); +#else + const Vec2fa res = _mm_mul_ps(r,_mm_sub_ps(vfloat4(2.0f), _mm_mul_ps(r, a))); + //return _mm_sub_ps(_mm_add_ps(r, r), _mm_mul_ps(_mm_mul_ps(r, r), a)); +#endif + + return res; + } + + __forceinline Vec2fa sqrt ( const Vec2fa& a ) { return _mm_sqrt_ps(a.m128); } + __forceinline Vec2fa sqr ( const Vec2fa& a ) { return _mm_mul_ps(a,a); } + + __forceinline Vec2fa rsqrt( const Vec2fa& a ) + { +#if defined(__AVX512VL__) + __m128 r = _mm_rsqrt14_ps(a.m128); +#else + __m128 r = _mm_rsqrt_ps(a.m128); +#endif + return _mm_add_ps(_mm_mul_ps(_mm_set1_ps(1.5f),r), _mm_mul_ps(_mm_mul_ps(_mm_mul_ps(a, _mm_set1_ps(-0.5f)), r), _mm_mul_ps(r, r))); + } + + __forceinline Vec2fa zero_fix(const Vec2fa& a) { + return blendv_ps(a, _mm_set1_ps(min_rcp_input), _mm_cmplt_ps (abs(a).m128, _mm_set1_ps(min_rcp_input))); + } + __forceinline Vec2fa rcp_safe(const Vec2fa& a) { + return rcp(zero_fix(a)); + } + __forceinline Vec2fa log ( const Vec2fa& a ) { + return Vec2fa(logf(a.x),logf(a.y)); + } + + __forceinline Vec2fa exp ( const Vec2fa& a ) { + return Vec2fa(expf(a.x),expf(a.y)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2fa operator +( const Vec2fa& a, const Vec2fa& b ) { return _mm_add_ps(a.m128, b.m128); } + __forceinline Vec2fa operator -( const Vec2fa& a, const Vec2fa& b ) { return _mm_sub_ps(a.m128, b.m128); } + __forceinline Vec2fa operator *( const Vec2fa& a, const Vec2fa& b ) { return _mm_mul_ps(a.m128, b.m128); } + __forceinline Vec2fa operator *( const Vec2fa& a, const float b ) { return a * Vec2fa(b); } + __forceinline Vec2fa operator *( const float a, const Vec2fa& b ) { return Vec2fa(a) * b; } + __forceinline Vec2fa operator /( const Vec2fa& a, const Vec2fa& b ) { return _mm_div_ps(a.m128,b.m128); } + __forceinline Vec2fa operator /( const Vec2fa& a, const float b ) { return _mm_div_ps(a.m128,_mm_set1_ps(b)); } + __forceinline Vec2fa operator /( const float a, const Vec2fa& b ) { return _mm_div_ps(_mm_set1_ps(a),b.m128); } + + __forceinline Vec2fa min( const Vec2fa& a, const Vec2fa& b ) { return _mm_min_ps(a.m128,b.m128); } + __forceinline Vec2fa max( const Vec2fa& a, const Vec2fa& b ) { return _mm_max_ps(a.m128,b.m128); } + +#if defined(__SSE4_1__) + __forceinline Vec2fa mini(const Vec2fa& a, const Vec2fa& b) { + const vint4 ai = _mm_castps_si128(a); + const vint4 bi = _mm_castps_si128(b); + const vint4 ci = 
_mm_min_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } +#endif + +#if defined(__SSE4_1__) + __forceinline Vec2fa maxi(const Vec2fa& a, const Vec2fa& b) { + const vint4 ai = _mm_castps_si128(a); + const vint4 bi = _mm_castps_si128(b); + const vint4 ci = _mm_max_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } +#endif + + __forceinline Vec2fa pow ( const Vec2fa& a, const float& b ) { + return Vec2fa(powf(a.x,b),powf(a.y,b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX2__) + __forceinline Vec2fa madd ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return _mm_fmadd_ps(a,b,c); } + __forceinline Vec2fa msub ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return _mm_fmsub_ps(a,b,c); } + __forceinline Vec2fa nmadd ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return _mm_fnmadd_ps(a,b,c); } + __forceinline Vec2fa nmsub ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return _mm_fnmsub_ps(a,b,c); } +#else + __forceinline Vec2fa madd ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return a*b+c; } + __forceinline Vec2fa msub ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return a*b-c; } + __forceinline Vec2fa nmadd ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return -a*b+c;} + __forceinline Vec2fa nmsub ( const Vec2fa& a, const Vec2fa& b, const Vec2fa& c) { return -a*b-c; } +#endif + + __forceinline Vec2fa madd ( const float a, const Vec2fa& b, const Vec2fa& c) { return madd(Vec2fa(a),b,c); } + __forceinline Vec2fa msub ( const float a, const Vec2fa& b, const Vec2fa& c) { return msub(Vec2fa(a),b,c); } + __forceinline Vec2fa nmadd ( const float a, const Vec2fa& b, const Vec2fa& c) { return nmadd(Vec2fa(a),b,c); } + __forceinline Vec2fa nmsub ( const float a, const Vec2fa& b, const Vec2fa& c) { return nmsub(Vec2fa(a),b,c); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2fa& operator +=( Vec2fa& a, const Vec2fa& b ) { return a = a + b; } + __forceinline Vec2fa& operator -=( Vec2fa& a, const Vec2fa& b ) { return a = a - b; } + __forceinline Vec2fa& operator *=( Vec2fa& a, const Vec2fa& b ) { return a = a * b; } + __forceinline Vec2fa& operator *=( Vec2fa& a, const float b ) { return a = a * b; } + __forceinline Vec2fa& operator /=( Vec2fa& a, const Vec2fa& b ) { return a = a / b; } + __forceinline Vec2fa& operator /=( Vec2fa& a, const float b ) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline float reduce_add(const Vec2fa& v) { return v.x+v.y; } + __forceinline float reduce_mul(const Vec2fa& v) { return v.x*v.y; } + __forceinline float reduce_min(const Vec2fa& v) { return min(v.x,v.y); } + __forceinline float reduce_max(const Vec2fa& v) { return max(v.x,v.y); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator ==( const Vec2fa& a, const Vec2fa& b ) { return (_mm_movemask_ps(_mm_cmpeq_ps (a.m128, b.m128)) & 3) == 3; } + __forceinline bool operator !=( const Vec2fa& a, 
const Vec2fa& b ) { return (_mm_movemask_ps(_mm_cmpneq_ps(a.m128, b.m128)) & 3) != 0; } + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__SSE4_1__) + __forceinline float dot ( const Vec2fa& a, const Vec2fa& b ) { + return _mm_cvtss_f32(_mm_dp_ps(a,b,0x3F)); + } +#else + __forceinline float dot ( const Vec2fa& a, const Vec2fa& b ) { + return reduce_add(a*b); + } +#endif + + __forceinline Vec2fa cross ( const Vec2fa& a ) { + return Vec2fa(-a.y,a.x); + } + + __forceinline float sqr_length ( const Vec2fa& a ) { return dot(a,a); } + __forceinline float rcp_length ( const Vec2fa& a ) { return rsqrt(dot(a,a)); } + __forceinline float rcp_length2( const Vec2fa& a ) { return rcp(dot(a,a)); } + __forceinline float length ( const Vec2fa& a ) { return sqrt(dot(a,a)); } + __forceinline Vec2fa normalize( const Vec2fa& a ) { return a*rsqrt(dot(a,a)); } + __forceinline float distance ( const Vec2fa& a, const Vec2fa& b ) { return length(a-b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec2fa select( bool s, const Vec2fa& t, const Vec2fa& f ) { + __m128 mask = s ? _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128())) : _mm_setzero_ps(); + return blendv_ps(f, t, mask); + } + + __forceinline Vec2fa lerp(const Vec2fa& v0, const Vec2fa& v1, const float t) { + return madd(1.0f-t,v0,t*v1); + } + + __forceinline int maxDim ( const Vec2fa& a ) + { + const Vec2fa b = abs(a); + if (b.x > b.y) return 0; + else return 1; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Rounding Functions + //////////////////////////////////////////////////////////////////////////////// + +#if defined (__SSE4_1__) + //__forceinline Vec2fa trunc( const Vec2fa& a ) { return _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT); } + __forceinline Vec2fa floor( const Vec2fa& a ) { return _mm_round_ps(a, _MM_FROUND_TO_NEG_INF ); } + __forceinline Vec2fa ceil ( const Vec2fa& a ) { return _mm_round_ps(a, _MM_FROUND_TO_POS_INF ); } +#else + //__forceinline Vec2fa trunc( const Vec2fa& a ) { return Vec2fa(truncf(a.x),truncf(a.y),truncf(a.z)); } + __forceinline Vec2fa floor( const Vec2fa& a ) { return Vec2fa(floorf(a.x),floorf(a.y)); } + __forceinline Vec2fa ceil ( const Vec2fa& a ) { return Vec2fa(ceilf (a.x),ceilf (a.y)); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator<<(embree_ostream cout, const Vec2fa& a) { + return cout << "(" << a.x << ", " << a.y << ")"; + } + + typedef Vec2fa Vec2fa_t; +} diff --git a/thirdparty/embree/common/math/vec3.h b/thirdparty/embree/common/math/vec3.h new file mode 100644 index 000000000000..ab4753545b80 --- /dev/null +++ b/thirdparty/embree/common/math/vec3.h @@ -0,0 +1,350 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "math.h" + +namespace embree +{ + struct Vec3fa; + + //////////////////////////////////////////////////////////////////////////////// + /// Generic 3D vector Class + //////////////////////////////////////////////////////////////////////////////// + + template struct 
Vec3
+  {
+    enum { N = 3 };
+
+    union {
+      struct {
+        T x, y, z;
+      };
+#if !(defined(__WIN32__) && _MSC_VER == 1800) // workaround for older VS 2013 compiler
+      T components[N];
+#endif
+    };
+
+    typedef T Scalar;
+
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Construction
+    ////////////////////////////////////////////////////////////////////////////////
+
+    __forceinline Vec3( ) {}
+    __forceinline explicit Vec3( const T& a ) : x(a), y(a), z(a) {}
+    __forceinline Vec3( const T& x, const T& y, const T& z ) : x(x), y(y), z(z) {}
+
+    __forceinline Vec3( const Vec3& other ) { x = other.x; y = other.y; z = other.z; }
+    __forceinline Vec3( const Vec3fa& other );
+
+    template<typename T1> __forceinline Vec3( const Vec3<T1>& a ) : x(T(a.x)), y(T(a.y)), z(T(a.z)) {}
+    template<typename T1> __forceinline Vec3& operator =(const Vec3<T1>& other) { x = other.x; y = other.y; z = other.z; return *this; }
+
+    __forceinline Vec3& operator =(const Vec3& other) { x = other.x; y = other.y; z = other.z; return *this; }
+
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Constants
+    ////////////////////////////////////////////////////////////////////////////////
+
+    __forceinline Vec3( ZeroTy   ) : x(zero), y(zero), z(zero) {}
+    __forceinline Vec3( OneTy    ) : x(one),  y(one),  z(one) {}
+    __forceinline Vec3( PosInfTy ) : x(pos_inf), y(pos_inf), z(pos_inf) {}
+    __forceinline Vec3( NegInfTy ) : x(neg_inf), y(neg_inf), z(neg_inf) {}
+
+#if defined(__WIN32__) && (_MSC_VER == 1800) // workaround for older VS 2013 compiler
+    __forceinline const T& operator []( const size_t axis ) const { assert(axis < 3); return (&x)[axis]; }
+    __forceinline       T& operator []( const size_t axis )       { assert(axis < 3); return (&x)[axis]; }
+#else
+    __forceinline const T& operator [](const size_t axis) const { assert(axis < 3); return components[axis]; }
+    __forceinline       T& operator [](const size_t axis)       { assert(axis < 3); return components[axis]; }
+#endif
+  };
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Unary Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<T> operator +( const Vec3<T>& a ) { return Vec3<T>(+a.x, +a.y, +a.z); }
+  template<typename T> __forceinline Vec3<T> operator -( const Vec3<T>& a ) { return Vec3<T>(-a.x, -a.y, -a.z); }
+  template<typename T> __forceinline Vec3<T> abs       ( const Vec3<T>& a ) { return Vec3<T>(abs  (a.x), abs  (a.y), abs  (a.z)); }
+  template<typename T> __forceinline Vec3<T> rcp       ( const Vec3<T>& a ) { return Vec3<T>(rcp  (a.x), rcp  (a.y), rcp  (a.z)); }
+  template<typename T> __forceinline Vec3<T> rsqrt     ( const Vec3<T>& a ) { return Vec3<T>(rsqrt(a.x), rsqrt(a.y), rsqrt(a.z)); }
+  template<typename T> __forceinline Vec3<T> sqrt      ( const Vec3<T>& a ) { return Vec3<T>(sqrt (a.x), sqrt (a.y), sqrt (a.z)); }
+
+  template<typename T> __forceinline Vec3<T> zero_fix( const Vec3<T>& a )
+  {
+    return Vec3<T>(select(abs(a.x)<min_rcp_input,T(min_rcp_input),a.x),
+                   select(abs(a.y)<min_rcp_input,T(min_rcp_input),a.y),
+                   select(abs(a.z)<min_rcp_input,T(min_rcp_input),a.z));
+  }
+  template<typename T> __forceinline Vec3<T> rcp_safe(const Vec3<T>& a) { return rcp(zero_fix(a)); }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Binary Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<T> operator +( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<T>(a.x + b.x, a.y + b.y, a.z + b.z); }
+  template<typename T> __forceinline Vec3<T> operator -( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<T>(a.x - b.x, a.y - b.y, a.z - b.z); }
+  template<typename T> __forceinline Vec3<T> operator *( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<T>(a.x * b.x, a.y * b.y, a.z * b.z); }
+  template<typename T> __forceinline Vec3<T> operator *( const       T& a, const Vec3<T>& b ) { return Vec3<T>(a * b.x, a * b.y, a * b.z); }
+  template<typename T> __forceinline Vec3<T> operator *( const Vec3<T>& a, const       T& b ) { return Vec3<T>(a.x * b , a.y * b , a.z * b ); }
+  template<typename T> __forceinline Vec3<T> operator /( const Vec3<T>& a, const       T& b ) { return Vec3<T>(a.x / b , a.y / b , a.z / b ); }
+  template<typename T> __forceinline Vec3<T> operator /( const       T& a, const Vec3<T>& b ) { return Vec3<T>(a / b.x, a / b.y, a / b.z); }
+  template<typename T> __forceinline Vec3<T> operator /( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<T>(a.x / b.x, a.y / b.y, a.z / b.z); }
+
+  template<typename T> __forceinline Vec3<T> min(const Vec3<T>& a, const Vec3<T>& b) { return Vec3<T>(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); }
+  template<typename T> __forceinline Vec3<T> max(const Vec3<T>& a, const Vec3<T>& b) { return Vec3<T>(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); }
+
+  template<typename T> __forceinline Vec3<T> operator >>( const Vec3<T>& a, const int b ) { return Vec3<T>(a.x >> b, a.y >> b, a.z >> b); }
+  template<typename T> __forceinline Vec3<T> operator <<( const Vec3<T>& a, const int b ) { return Vec3<T>(a.x << b, a.y << b, a.z << b); }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Ternary Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<T> madd  ( const Vec3<T>& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>( madd(a.x,b.x,c.x), madd(a.y,b.y,c.y), madd(a.z,b.z,c.z)); }
+  template<typename T> __forceinline Vec3<T> msub  ( const Vec3<T>& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>( msub(a.x,b.x,c.x), msub(a.y,b.y,c.y), msub(a.z,b.z,c.z)); }
+  template<typename T> __forceinline Vec3<T> nmadd ( const Vec3<T>& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>(nmadd(a.x,b.x,c.x),nmadd(a.y,b.y,c.y),nmadd(a.z,b.z,c.z));}
+  template<typename T> __forceinline Vec3<T> nmsub ( const Vec3<T>& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>(nmsub(a.x,b.x,c.x),nmsub(a.y,b.y,c.y),nmsub(a.z,b.z,c.z)); }
+
+  template<typename T> __forceinline Vec3<T> madd  ( const T& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>( madd(a,b.x,c.x), madd(a,b.y,c.y), madd(a,b.z,c.z)); }
+  template<typename T> __forceinline Vec3<T> msub  ( const T& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>( msub(a,b.x,c.x), msub(a,b.y,c.y), msub(a,b.z,c.z)); }
+  template<typename T> __forceinline Vec3<T> nmadd ( const T& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>(nmadd(a,b.x,c.x),nmadd(a,b.y,c.y),nmadd(a,b.z,c.z));}
+  template<typename T> __forceinline Vec3<T> nmsub ( const T& a, const Vec3<T>& b, const Vec3<T>& c) { return Vec3<T>(nmsub(a,b.x,c.x),nmsub(a,b.y,c.y),nmsub(a,b.z,c.z)); }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Assignment Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<T>& operator +=( Vec3<T>& a, const T        b ) { a.x += b;   a.y += b;   a.z += b;   return a; }
+  template<typename T> __forceinline Vec3<T>& operator +=( Vec3<T>& a, const Vec3<T>& b ) { a.x += b.x; a.y += b.y; a.z += b.z; return a; }
+  template<typename T> __forceinline Vec3<T>& operator -=( Vec3<T>& a, const Vec3<T>& b ) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; }
+  template<typename T> __forceinline Vec3<T>& operator *=( Vec3<T>& a, const       T& b ) { a.x *= b ; a.y *= b ; a.z *= b ; return a; }
+  template<typename T> __forceinline Vec3<T>& operator /=( Vec3<T>& a, const       T& b ) { a.x /= b ; a.y /= b ; a.z /= b ; return a; }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Reduction Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline T reduce_add( const Vec3<T>& a ) { return a.x + a.y + a.z; }
+  template<typename T> __forceinline T reduce_mul( const Vec3<T>& a ) { return a.x * a.y * a.z; }
+  template<typename T> __forceinline T reduce_min( const Vec3<T>& a ) { return min(a.x, a.y, a.z); }
+  template<typename T> __forceinline T reduce_max( const Vec3<T>& a ) { return max(a.x, a.y, a.z); }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Comparison Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline bool operator ==( const Vec3<T>& a, const Vec3<T>& b ) { return a.x == b.x && a.y == b.y && a.z == b.z; }
+  template<typename T> __forceinline bool operator !=( const Vec3<T>& a, const Vec3<T>& b ) { return a.x != b.x || a.y != b.y || a.z != b.z; }
+  template<typename T> __forceinline bool operator < ( const Vec3<T>& a, const Vec3<T>& b ) {
+    if (a.x != b.x) return a.x < b.x;
+    if (a.y != b.y) return a.y < b.y;
+    if (a.z != b.z) return a.z < b.z;
+    return false;
+  }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Shift Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<T> shift_right_1( const Vec3<T>& a ) {
+    return Vec3<T>(shift_right_1(a.x),shift_right_1(a.y),shift_right_1(a.z));
+  }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Select
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<T> select ( bool s, const Vec3<T>& t, const Vec3<T>& f ) {
+    return Vec3<T>(select(s,t.x,f.x),select(s,t.y,f.y),select(s,t.z,f.z));
+  }
+
+  template<typename T> __forceinline Vec3<T> select ( const Vec3<bool>& s, const Vec3<T>& t, const Vec3<T>& f ) {
+    return Vec3<T>(select(s.x,t.x,f.x),select(s.y,t.y,f.y),select(s.z,t.z,f.z));
+  }
+
+  template<typename T> __forceinline Vec3<T> select ( const typename T::Bool& s, const Vec3<T>& t, const Vec3<T>& f ) {
+    return Vec3<T>(select(s,t.x,f.x),select(s,t.y,f.y),select(s,t.z,f.z));
+  }
+
+  template<typename T>
+  __forceinline Vec3<T> lerp(const Vec3<T>& v0, const Vec3<T>& v1, const T& t) {
+    return madd(Vec3<T>(T(1.0f)-t),v0,t*v1);
+  }
+
+  template<typename T> __forceinline int maxDim ( const Vec3<T>& a )
+  {
+    const Vec3<T> b = abs(a);
+    if (b.x > b.y) {
+      if (b.x > b.z) return 0; else return 2;
+    } else {
+      if (b.y > b.z) return 1; else return 2;
+    }
+  }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Comparison Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline Vec3<bool> eq_mask( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<bool>(a.x==b.x,a.y==b.y,a.z==b.z); }
+  template<typename T> __forceinline Vec3<bool> neq_mask(const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<bool>(a.x!=b.x,a.y!=b.y,a.z!=b.z); }
+  template<typename T> __forceinline Vec3<bool> lt_mask( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<bool>(a.x< b.x,a.y< b.y,a.z< b.z); }
+  template<typename T> __forceinline Vec3<bool> le_mask( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<bool>(a.x<=b.x,a.y<=b.y,a.z<=b.z); }
+  template<typename T> __forceinline Vec3<bool> gt_mask( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<bool>(a.x> b.x,a.y> b.y,a.z> b.z); }
+  template<typename T> __forceinline Vec3<bool> ge_mask( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<bool>(a.x>=b.x,a.y>=b.y,a.z>=b.z); }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Euclidian Space Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline T       sqr      ( const Vec3<T>& a )                   { return dot(a,a); }
+  template<typename T> __forceinline T       dot      ( const Vec3<T>& a, const Vec3<T>& b ) { return madd(a.x,b.x,madd(a.y,b.y,a.z*b.z)); }
+  template<typename T> __forceinline T       length   ( const Vec3<T>& a )                   { return sqrt(sqr(a)); }
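The madd-based dot() above expands to the usual component-wise sum for scalar element types. A minimal standalone sketch of that expansion (not part of the vendored header; V3 and the scalar madd() below are illustrative stand-ins for Vec3<float> and embree's fused multiply-add helper):

#include <cassert>
#include <cmath>

static float madd(float a, float b, float c) { return a * b + c; } // stand-in for embree's madd

struct V3 { float x, y, z; }; // stand-in for Vec3<float>

static float dot(const V3& a, const V3& b) {
  // Same nesting as the template above: a.x*b.x + a.y*b.y + a.z*b.z
  return madd(a.x, b.x, madd(a.y, b.y, a.z * b.z));
}

int main() {
  V3 a{1.0f, 2.0f, 3.0f}, b{4.0f, 5.0f, 6.0f};
  assert(std::fabs(dot(a, b) - 32.0f) < 1e-6f); // 1*4 + 2*5 + 3*6 = 32
  return 0;
}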
+  template<typename T> __forceinline T       rcp_length( const Vec3<T>& a )                  { return rsqrt(sqr(a)); }
+  template<typename T> __forceinline Vec3<T> normalize( const Vec3<T>& a )                   { return a*rsqrt(sqr(a)); }
+  template<typename T> __forceinline T       distance ( const Vec3<T>& a, const Vec3<T>& b ) { return length(a-b); }
+  template<typename T> __forceinline Vec3<T> cross    ( const Vec3<T>& a, const Vec3<T>& b ) { return Vec3<T>(msub(a.y,b.z,a.z*b.y), msub(a.z,b.x,a.x*b.z), msub(a.x,b.y,a.y*b.x)); }
+
+  template<typename T> __forceinline Vec3<T> stable_triangle_normal( const Vec3<T>& a, const Vec3<T>& b, const Vec3<T>& c )
+  {
+    const T ab_x = a.z*b.y, ab_y = a.x*b.z, ab_z = a.y*b.x;
+    const T bc_x = b.z*c.y, bc_y = b.x*c.z, bc_z = b.y*c.x;
+    const Vec3<T> cross_ab(msub(a.y,b.z,ab_x), msub(a.z,b.x,ab_y), msub(a.x,b.y,ab_z));
+    const Vec3<T> cross_bc(msub(b.y,c.z,bc_x), msub(b.z,c.x,bc_y), msub(b.x,c.y,bc_z));
+    const auto sx = abs(ab_x) < abs(bc_x);
+    const auto sy = abs(ab_y) < abs(bc_y);
+    const auto sz = abs(ab_z) < abs(bc_z);
+    return Vec3<T>(select(sx,cross_ab.x,cross_bc.x),
+                   select(sy,cross_ab.y,cross_bc.y),
+                   select(sz,cross_ab.z,cross_bc.z));
+  }
+
+  template<typename T> __forceinline T sum ( const Vec3<T>& a ) { return a.x+a.y+a.z; }
+
+  template<typename T> __forceinline T halfArea ( const Vec3<T>& d ) { return madd(d.x,(d.y+d.z),d.y*d.z); }
+  template<typename T> __forceinline T area     ( const Vec3<T>& d ) { return 2.0f*halfArea(d); }
+
+  template<typename T> __forceinline Vec3<T> normalize_safe( const Vec3<T>& a ) {
+    const T d = dot(a,a); return select(d == T( zero ), a , a*rsqrt(d) );
+  }
+
+  template<typename T> __forceinline T sqr_point_to_line_distance(const Vec3<T>& P, const Vec3<T>& Q0, const Vec3<T>& Q1)
+  {
+    const Vec3<T> N = cross(P-Q0,Q1-Q0);
+    const Vec3<T> D = Q1-Q0;
+    return dot(N,N)*rcp(dot(D,D));
+  }
+
+  template<typename T> __forceinline T sqr_point_to_line_distance(const Vec3<T>& PmQ0, const Vec3<T>& Q1mQ0)
+  {
+    const Vec3<T> N = cross(PmQ0,Q1mQ0);
+    const Vec3<T> D = Q1mQ0;
+    return dot(N,N)*rcp(dot(D,D));
+  }
+
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Output Operators
+  ////////////////////////////////////////////////////////////////////////////////
+
+  template<typename T> __forceinline embree_ostream operator<<(embree_ostream cout, const Vec3<T>& a) {
+    return cout << "(" << a.x << ", " << a.y << ", " << a.z << ")";
+  }
+
+  typedef Vec3<bool > Vec3b;
+  typedef Vec3<int  > Vec3i;
+  typedef Vec3<float> Vec3f;
+}
+
+#include "vec3ba.h"
+#include "vec3ia.h"
+#include "vec3fa.h"
+
+////////////////////////////////////////////////////////////////////////////////
+/// SSE / AVX / MIC specializations
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined __SSE__
+#include "../simd/sse.h"
+#endif
+
+#if defined __AVX__
+#include "../simd/avx.h"
+#endif
+
+#if defined(__AVX512F__)
+#include "../simd/avx512.h"
+#endif
+
+namespace embree
+{
+  template<typename Out, typename In>
+  __forceinline Vec3<Out> broadcast(const Vec3<In>& a, const size_t k) {
+    return Vec3<Out>(Out(a.x[k]), Out(a.y[k]), Out(a.z[k]));
+  }
+
+  template<> __forceinline Vec3<float>::Vec3(const Vec3fa& a) { x = a.x; y = a.y; z = a.z; }
+
+#if defined(__AVX__)
+  template<> __forceinline Vec3<vfloat4>::Vec3(const Vec3fa& a) {
+    x = a.x; y = a.y; z = a.z;
+  }
+#elif defined(__SSE__)
+  template<>
+  __forceinline Vec3<vfloat4>::Vec3(const Vec3fa& a) {
+    const vfloat4 v = vfloat4(a.m128); x = shuffle<0,0,0,0>(v); y = shuffle<1,1,1,1>(v); z = shuffle<2,2,2,2>(v);
+  }
+#endif
+
+#if defined(__SSE__)
+  __forceinline Vec3<vfloat4> broadcast4f(const Vec3<vfloat4>& a, const size_t k) {
+    return Vec3<vfloat4>(vfloat4::broadcast(&a.x[k]), vfloat4::broadcast(&a.y[k]), vfloat4::broadcast(&a.z[k]));
+  }
+
+  template<>
+  __forceinline Vec3<vfloat4> broadcast<vfloat4,vfloat4>(const Vec3<vfloat4>& a, const size_t k) {
return Vec3(vfloat4::broadcast(&a.x[k]), vfloat4::broadcast(&a.y[k]), vfloat4::broadcast(&a.z[k])); + } + + template + __forceinline Vec3 shuffle(const Vec3& b) { + return Vec3(shuffle(b.x), shuffle(b.y), shuffle(b.z)); + } +#endif + +#if defined(__AVX__) + template<> + __forceinline Vec3::Vec3(const Vec3fa& a) { + x = a.x; y = a.y; z = a.z; + } + __forceinline Vec3 broadcast4f(const Vec3& a, const size_t k) { + return Vec3(vfloat4::broadcast(&a.x[k]), vfloat4::broadcast(&a.y[k]), vfloat4::broadcast(&a.z[k])); + } + __forceinline Vec3 broadcast8f(const Vec3& a, const size_t k) { + return Vec3(vfloat8::broadcast(&a.x[k]), vfloat8::broadcast(&a.y[k]), vfloat8::broadcast(&a.z[k])); + } + __forceinline Vec3 broadcast8f(const Vec3& a, const size_t k) { + return Vec3(vfloat8::broadcast(&a.x[k]), vfloat8::broadcast(&a.y[k]), vfloat8::broadcast(&a.z[k])); + } + + template<> + __forceinline Vec3 broadcast(const Vec3& a, const size_t k) { + return Vec3(vfloat8::broadcast(&a.x[k]), vfloat8::broadcast(&a.y[k]), vfloat8::broadcast(&a.z[k])); + } + template<> + __forceinline Vec3 broadcast(const Vec3& a, const size_t k) { + return Vec3(vfloat8::broadcast(&a.x[k]), vfloat8::broadcast(&a.y[k]), vfloat8::broadcast(&a.z[k])); + } + + template + __forceinline Vec3 shuffle(const Vec3& b) { + return Vec3(shuffle(b.x), shuffle(b.y), shuffle(b.z)); + } +#endif + +#if defined(__AVX512F__) + template<> __forceinline Vec3::Vec3(const Vec3fa& a) : x(a.x), y(a.y), z(a.z) {} +#endif +} diff --git a/thirdparty/embree/common/math/vec3ba.h b/thirdparty/embree/common/math/vec3ba.h new file mode 100644 index 000000000000..90f31739c2d6 --- /dev/null +++ b/thirdparty/embree/common/math/vec3ba.h @@ -0,0 +1,120 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/alloc.h" +#include "math.h" +#include "../simd/sse.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// SSE Vec3ba Type + //////////////////////////////////////////////////////////////////////////////// + + struct __aligned(16) Vec3ba + { + ALIGNED_STRUCT_(16); + + union { + __m128 m128; + struct { int x,y,z; }; + }; + + typedef int Scalar; + enum { N = 3 }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ba( ) {} + __forceinline Vec3ba( const __m128 input ) : m128(input) {} + __forceinline Vec3ba( const Vec3ba& other ) : m128(other.m128) {} + __forceinline Vec3ba& operator =(const Vec3ba& other) { m128 = other.m128; return *this; } + + __forceinline explicit Vec3ba( bool a ) + : m128(mm_lookupmask_ps[(size_t(a) << 3) | (size_t(a) << 2) | (size_t(a) << 1) | size_t(a)]) {} + __forceinline Vec3ba( bool a, bool b, bool c) + : m128(mm_lookupmask_ps[(size_t(c) << 2) | (size_t(b) << 1) | size_t(a)]) {} + + __forceinline operator const __m128&() const { return m128; } + __forceinline operator __m128&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ba( FalseTy ) : m128(_mm_setzero_ps()) {} + __forceinline Vec3ba( TrueTy ) : m128(_mm_castsi128_ps(_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128()))) {} + + 
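Vec3ba keeps its three boolean lanes as full-width SSE float lanes, so the reductions defined further down (all, any, none, movemask) are just _mm_movemask_ps masked to the low three bits. A minimal SSE2-only sketch of that behaviour (not part of the vendored header; make_mask3 is an illustrative stand-in for the mm_lookupmask_ps-based constructors):

#include <emmintrin.h>
#include <cassert>

// Build a 3-lane mask: each true lane becomes all-ones bits, false becomes zero.
static __m128 make_mask3(bool x, bool y, bool z) {
  __m128i m = _mm_set_epi32(0, z ? -1 : 0, y ? -1 : 0, x ? -1 : 0); // lane 3 (w) stays clear
  return _mm_castsi128_ps(m);
}

int main() {
  __m128 m = make_mask3(true, true, false);
  int bits = _mm_movemask_ps(m) & 0x7; // one sign bit per lane, w lane ignored
  assert(bits == 0x3);                 // x and y set, z clear
  assert((bits & 0x7) != 0x0);         // corresponds to any()
  assert((bits & 0x7) != 0x7);         // not every lane set, so all() would be false
  return 0;
}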
//////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const int& operator []( const size_t index ) const { assert(index < 3); return (&x)[index]; } + __forceinline int& operator []( const size_t index ) { assert(index < 3); return (&x)[index]; } + }; + + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ba operator !( const Vec3ba& a ) { return _mm_xor_ps(a.m128, Vec3ba(embree::True)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ba operator &( const Vec3ba& a, const Vec3ba& b ) { return _mm_and_ps(a.m128, b.m128); } + __forceinline Vec3ba operator |( const Vec3ba& a, const Vec3ba& b ) { return _mm_or_ps (a.m128, b.m128); } + __forceinline Vec3ba operator ^( const Vec3ba& a, const Vec3ba& b ) { return _mm_xor_ps(a.m128, b.m128); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ba& operator &=( Vec3ba& a, const Vec3ba& b ) { return a = a & b; } + __forceinline Vec3ba& operator |=( Vec3ba& a, const Vec3ba& b ) { return a = a | b; } + __forceinline Vec3ba& operator ^=( Vec3ba& a, const Vec3ba& b ) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator ==( const Vec3ba& a, const Vec3ba& b ) { + return (_mm_movemask_ps(_mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(a.m128), _mm_castps_si128(b.m128)))) & 7) == 7; + } + __forceinline bool operator !=( const Vec3ba& a, const Vec3ba& b ) { + return (_mm_movemask_ps(_mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(a.m128), _mm_castps_si128(b.m128)))) & 7) != 7; + } + __forceinline bool operator < ( const Vec3ba& a, const Vec3ba& b ) { + if (a.x != b.x) return a.x < b.x; + if (a.y != b.y) return a.y < b.y; + if (a.z != b.z) return a.z < b.z; + return false; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool reduce_and( const Vec3ba& a ) { return (_mm_movemask_ps(a) & 0x7) == 0x7; } + __forceinline bool reduce_or ( const Vec3ba& a ) { return (_mm_movemask_ps(a) & 0x7) != 0x0; } + + __forceinline bool all ( const Vec3ba& b ) { return (_mm_movemask_ps(b) & 0x7) == 0x7; } + __forceinline bool any ( const Vec3ba& b ) { return (_mm_movemask_ps(b) & 0x7) != 0x0; } + __forceinline bool none ( const Vec3ba& b ) { return (_mm_movemask_ps(b) & 0x7) == 0x0; } + + __forceinline size_t movemask(const Vec3ba& a) { return _mm_movemask_ps(a) & 0x7; } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator<<(embree_ostream cout, const Vec3ba& a) { + return cout << "(" << (a.x ? 
"1" : "0") << ", " << (a.y ? "1" : "0") << ", " << (a.z ? "1" : "0") << ")"; + } +} diff --git a/thirdparty/embree/common/math/vec3fa.h b/thirdparty/embree/common/math/vec3fa.h new file mode 100644 index 000000000000..6576a15b4f24 --- /dev/null +++ b/thirdparty/embree/common/math/vec3fa.h @@ -0,0 +1,723 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/alloc.h" +#include "math.h" +#include "../simd/sse.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// SSE Vec3fa Type + //////////////////////////////////////////////////////////////////////////////// + + struct __aligned(16) Vec3fa + { + ALIGNED_STRUCT_(16); + + typedef float Scalar; + enum { N = 3 }; + union { + __m128 m128; + struct { float x,y,z; }; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fa( ) {} + __forceinline Vec3fa( const __m128 a ) : m128(a) {} + + __forceinline Vec3fa ( const Vec3& other ) { m128 = _mm_set_ps(0, other.z, other.y, other.x); } + //__forceinline Vec3fa& operator =( const Vec3& other ) { m128 = _mm_set_ps(0, other.z, other.y, other.x); return *this; } + + __forceinline Vec3fa ( const Vec3fa& other ) { m128 = other.m128; } + __forceinline Vec3fa& operator =( const Vec3fa& other ) { m128 = other.m128; return *this; } + + __forceinline explicit Vec3fa( const float a ) : m128(_mm_set1_ps(a)) {} + __forceinline Vec3fa( const float x, const float y, const float z) : m128(_mm_set_ps(0, z, y, x)) {} + + __forceinline explicit Vec3fa( const __m128i a ) : m128(_mm_cvtepi32_ps(a)) {} + + __forceinline explicit operator const vfloat4() const { return vfloat4(m128); } + __forceinline explicit operator const vint4() const { return vint4(_mm_cvtps_epi32(m128)); } + __forceinline explicit operator const Vec2fa() const { return Vec2fa(m128); } + __forceinline explicit operator const Vec3ia() const { return Vec3ia(_mm_cvtps_epi32(m128)); } + + //__forceinline operator const __m128&() const { return m128; } + //__forceinline operator __m128&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline Vec3fa load( const void* const a ) { + return Vec3fa(_mm_and_ps(_mm_load_ps((float*)a),_mm_castsi128_ps(_mm_set_epi32(0, -1, -1, -1)))); + } + + static __forceinline Vec3fa loadu( const void* const a ) { + return Vec3fa(_mm_loadu_ps((float*)a)); + } + + static __forceinline void storeu ( void* ptr, const Vec3fa& v ) { + _mm_storeu_ps((float*)ptr,v.m128); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fa( ZeroTy ) : m128(_mm_setzero_ps()) {} + __forceinline Vec3fa( OneTy ) : m128(_mm_set1_ps(1.0f)) {} + __forceinline Vec3fa( PosInfTy ) : m128(_mm_set1_ps(pos_inf)) {} + __forceinline Vec3fa( NegInfTy ) : m128(_mm_set1_ps(neg_inf)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const float& operator []( const size_t index 
) const { assert(index < 3); return (&x)[index]; } + __forceinline float& operator []( const size_t index ) { assert(index < 3); return (&x)[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fa operator +( const Vec3fa& a ) { return a; } + __forceinline Vec3fa operator -( const Vec3fa& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000)); + return _mm_xor_ps(a.m128, mask); + } + __forceinline Vec3fa abs ( const Vec3fa& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)); + return _mm_and_ps(a.m128, mask); + } + __forceinline Vec3fa sign ( const Vec3fa& a ) { + return blendv_ps(Vec3fa(one).m128, (-Vec3fa(one)).m128, _mm_cmplt_ps (a.m128,Vec3fa(zero).m128)); + } + + __forceinline Vec3fa rcp ( const Vec3fa& a ) + { +#if defined(__AVX512VL__) + const Vec3fa r = _mm_rcp14_ps(a.m128); +#else + const Vec3fa r = _mm_rcp_ps(a.m128); +#endif + +#if defined(__AVX2__) + const Vec3fa res = _mm_mul_ps(r.m128,_mm_fnmadd_ps(r.m128, a.m128, vfloat4(2.0f))); +#else + const Vec3fa res = _mm_mul_ps(r.m128,_mm_sub_ps(vfloat4(2.0f), _mm_mul_ps(r.m128, a.m128))); + //return _mm_sub_ps(_mm_add_ps(r, r), _mm_mul_ps(_mm_mul_ps(r, r), a)); +#endif + + return res; + } + + __forceinline Vec3fa sqrt ( const Vec3fa& a ) { return _mm_sqrt_ps(a.m128); } + __forceinline Vec3fa sqr ( const Vec3fa& a ) { return _mm_mul_ps(a.m128,a.m128); } + + __forceinline Vec3fa rsqrt( const Vec3fa& a ) + { +#if defined(__AVX512VL__) + __m128 r = _mm_rsqrt14_ps(a.m128); +#else + __m128 r = _mm_rsqrt_ps(a.m128); +#endif + return _mm_add_ps(_mm_mul_ps(_mm_set1_ps(1.5f),r), _mm_mul_ps(_mm_mul_ps(_mm_mul_ps(a.m128, _mm_set1_ps(-0.5f)), r), _mm_mul_ps(r, r))); + } + + __forceinline Vec3fa zero_fix(const Vec3fa& a) { + return blendv_ps(a.m128, _mm_set1_ps(min_rcp_input), _mm_cmplt_ps (abs(a).m128, _mm_set1_ps(min_rcp_input))); + } + __forceinline Vec3fa rcp_safe(const Vec3fa& a) { + return rcp(zero_fix(a)); + } + __forceinline Vec3fa log ( const Vec3fa& a ) { + return Vec3fa(logf(a.x),logf(a.y),logf(a.z)); + } + + __forceinline Vec3fa exp ( const Vec3fa& a ) { + return Vec3fa(expf(a.x),expf(a.y),expf(a.z)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fa operator +( const Vec3fa& a, const Vec3fa& b ) { return _mm_add_ps(a.m128, b.m128); } + __forceinline Vec3fa operator -( const Vec3fa& a, const Vec3fa& b ) { return _mm_sub_ps(a.m128, b.m128); } + __forceinline Vec3fa operator *( const Vec3fa& a, const Vec3fa& b ) { return _mm_mul_ps(a.m128, b.m128); } + __forceinline Vec3fa operator *( const Vec3fa& a, const float b ) { return a * Vec3fa(b); } + __forceinline Vec3fa operator *( const float a, const Vec3fa& b ) { return Vec3fa(a) * b; } + __forceinline Vec3fa operator /( const Vec3fa& a, const Vec3fa& b ) { return _mm_div_ps(a.m128,b.m128); } + __forceinline Vec3fa operator /( const Vec3fa& a, const float b ) { return _mm_div_ps(a.m128,_mm_set1_ps(b)); } + __forceinline Vec3fa operator /( const float a, const Vec3fa& b ) { return _mm_div_ps(_mm_set1_ps(a),b.m128); } + + __forceinline Vec3fa min( const Vec3fa& a, const Vec3fa& b ) { return _mm_min_ps(a.m128,b.m128); } + __forceinline Vec3fa max( const Vec3fa& a, const Vec3fa& b ) { return 
_mm_max_ps(a.m128,b.m128); } + +#if defined(__SSE4_1__) + __forceinline Vec3fa mini(const Vec3fa& a, const Vec3fa& b) { + const vint4 ai = _mm_castps_si128(a.m128); + const vint4 bi = _mm_castps_si128(b.m128); + const vint4 ci = _mm_min_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } +#endif + +#if defined(__SSE4_1__) + __forceinline Vec3fa maxi(const Vec3fa& a, const Vec3fa& b) { + const vint4 ai = _mm_castps_si128(a.m128); + const vint4 bi = _mm_castps_si128(b.m128); + const vint4 ci = _mm_max_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } +#endif + + __forceinline Vec3fa pow ( const Vec3fa& a, const float& b ) { + return Vec3fa(powf(a.x,b),powf(a.y,b),powf(a.z,b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX2__) + __forceinline Vec3fa madd ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return _mm_fmadd_ps(a.m128,b.m128,c.m128); } + __forceinline Vec3fa msub ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return _mm_fmsub_ps(a.m128,b.m128,c.m128); } + __forceinline Vec3fa nmadd ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return _mm_fnmadd_ps(a.m128,b.m128,c.m128); } + __forceinline Vec3fa nmsub ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return _mm_fnmsub_ps(a.m128,b.m128,c.m128); } +#else + __forceinline Vec3fa madd ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return a*b+c; } + __forceinline Vec3fa msub ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return a*b-c; } + __forceinline Vec3fa nmadd ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return -a*b+c;} + __forceinline Vec3fa nmsub ( const Vec3fa& a, const Vec3fa& b, const Vec3fa& c) { return -a*b-c; } +#endif + + __forceinline Vec3fa madd ( const float a, const Vec3fa& b, const Vec3fa& c) { return madd(Vec3fa(a),b,c); } + __forceinline Vec3fa msub ( const float a, const Vec3fa& b, const Vec3fa& c) { return msub(Vec3fa(a),b,c); } + __forceinline Vec3fa nmadd ( const float a, const Vec3fa& b, const Vec3fa& c) { return nmadd(Vec3fa(a),b,c); } + __forceinline Vec3fa nmsub ( const float a, const Vec3fa& b, const Vec3fa& c) { return nmsub(Vec3fa(a),b,c); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fa& operator +=( Vec3fa& a, const Vec3fa& b ) { return a = a + b; } + __forceinline Vec3fa& operator -=( Vec3fa& a, const Vec3fa& b ) { return a = a - b; } + __forceinline Vec3fa& operator *=( Vec3fa& a, const Vec3fa& b ) { return a = a * b; } + __forceinline Vec3fa& operator *=( Vec3fa& a, const float b ) { return a = a * b; } + __forceinline Vec3fa& operator /=( Vec3fa& a, const Vec3fa& b ) { return a = a / b; } + __forceinline Vec3fa& operator /=( Vec3fa& a, const float b ) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline float reduce_add(const Vec3fa& v) { + const vfloat4 a(v.m128); + const vfloat4 b = shuffle<1>(a); + const vfloat4 c = shuffle<2>(a); + return _mm_cvtss_f32(a+b+c); + } + + __forceinline float reduce_mul(const Vec3fa& v) { return v.x*v.y*v.z; } + __forceinline float reduce_min(const Vec3fa& v) { return min(v.x,v.y,v.z); } + __forceinline float 
reduce_max(const Vec3fa& v) { return max(v.x,v.y,v.z); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator ==( const Vec3fa& a, const Vec3fa& b ) { return (_mm_movemask_ps(_mm_cmpeq_ps (a.m128, b.m128)) & 7) == 7; } + __forceinline bool operator !=( const Vec3fa& a, const Vec3fa& b ) { return (_mm_movemask_ps(_mm_cmpneq_ps(a.m128, b.m128)) & 7) != 0; } + + __forceinline Vec3ba eq_mask( const Vec3fa& a, const Vec3fa& b ) { return _mm_cmpeq_ps (a.m128, b.m128); } + __forceinline Vec3ba neq_mask(const Vec3fa& a, const Vec3fa& b ) { return _mm_cmpneq_ps(a.m128, b.m128); } + __forceinline Vec3ba lt_mask( const Vec3fa& a, const Vec3fa& b ) { return _mm_cmplt_ps (a.m128, b.m128); } + __forceinline Vec3ba le_mask( const Vec3fa& a, const Vec3fa& b ) { return _mm_cmple_ps (a.m128, b.m128); } + __forceinline Vec3ba gt_mask( const Vec3fa& a, const Vec3fa& b ) { return _mm_cmpnle_ps(a.m128, b.m128); } + __forceinline Vec3ba ge_mask( const Vec3fa& a, const Vec3fa& b ) { return _mm_cmpnlt_ps(a.m128, b.m128); } + + __forceinline bool isvalid ( const Vec3fa& v ) { + return all(gt_mask(v,Vec3fa(-FLT_LARGE)) & lt_mask(v,Vec3fa(+FLT_LARGE))); + } + + __forceinline bool is_finite ( const Vec3fa& a ) { + return all(ge_mask(a,Vec3fa(-FLT_MAX)) & le_mask(a,Vec3fa(+FLT_MAX))); + } + + __forceinline bool isvalid4 ( const Vec3fa& v ) { + return all((vfloat4(v.m128) > vfloat4(-FLT_LARGE)) & (vfloat4(v.m128) < vfloat4(+FLT_LARGE))); + } + + __forceinline bool is_finite4 ( const Vec3fa& a ) { + return all((vfloat4(a.m128) >= vfloat4(-FLT_MAX)) & (vfloat4(a.m128) <= vfloat4(+FLT_MAX))); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__SSE4_1__) + __forceinline float dot ( const Vec3fa& a, const Vec3fa& b ) { + return _mm_cvtss_f32(_mm_dp_ps(a.m128,b.m128,0x7F)); + } +#else + __forceinline float dot ( const Vec3fa& a, const Vec3fa& b ) { + return reduce_add(a*b); + } +#endif + + __forceinline Vec3fa cross ( const Vec3fa& a, const Vec3fa& b ) + { + vfloat4 a0 = vfloat4(a.m128); + vfloat4 b0 = shuffle<1,2,0,3>(vfloat4(b.m128)); + vfloat4 a1 = shuffle<1,2,0,3>(vfloat4(a.m128)); + vfloat4 b1 = vfloat4(b.m128); + return Vec3fa(shuffle<1,2,0,3>(msub(a0,b0,a1*b1))); + } + + __forceinline float sqr_length ( const Vec3fa& a ) { return dot(a,a); } + __forceinline float rcp_length ( const Vec3fa& a ) { return rsqrt(dot(a,a)); } + __forceinline float rcp_length2( const Vec3fa& a ) { return rcp(dot(a,a)); } + __forceinline float length ( const Vec3fa& a ) { return sqrt(dot(a,a)); } + __forceinline Vec3fa normalize( const Vec3fa& a ) { return a*rsqrt(dot(a,a)); } + __forceinline float distance ( const Vec3fa& a, const Vec3fa& b ) { return length(a-b); } + __forceinline float halfArea ( const Vec3fa& d ) { return madd(d.x,(d.y+d.z),d.y*d.z); } + __forceinline float area ( const Vec3fa& d ) { return 2.0f*halfArea(d); } + + __forceinline Vec3fa normalize_safe( const Vec3fa& a ) { + const float d = dot(a,a); if (unlikely(d == 0.0f)) return a; else return a*rsqrt(d); + } + + /*! 
differentiated normalization */ + __forceinline Vec3fa dnormalize(const Vec3fa& p, const Vec3fa& dp) + { + const float pp = dot(p,p); + const float pdp = dot(p,dp); + return (pp*dp-pdp*p)*rcp(pp)*rsqrt(pp); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fa select( bool s, const Vec3fa& t, const Vec3fa& f ) { + __m128 mask = s ? _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128())) : _mm_setzero_ps(); + return blendv_ps(f.m128, t.m128, mask); + } + + __forceinline Vec3fa select( const Vec3ba& s, const Vec3fa& t, const Vec3fa& f ) { + return blendv_ps(f.m128, t.m128, s); + } + + __forceinline Vec3fa lerp(const Vec3fa& v0, const Vec3fa& v1, const float t) { + return madd(1.0f-t,v0,t*v1); + } + + __forceinline int maxDim ( const Vec3fa& a ) + { + const Vec3fa b = abs(a); + if (b.x > b.y) { + if (b.x > b.z) return 0; else return 2; + } else { + if (b.y > b.z) return 1; else return 2; + } + } + + //////////////////////////////////////////////////////////////////////////////// + /// Rounding Functions + //////////////////////////////////////////////////////////////////////////////// + +#if defined (__SSE4_1__) + __forceinline Vec3fa trunc( const Vec3fa& a ) { return _mm_round_ps(a.m128, _MM_FROUND_TO_NEAREST_INT); } + __forceinline Vec3fa floor( const Vec3fa& a ) { return _mm_round_ps(a.m128, _MM_FROUND_TO_NEG_INF ); } + __forceinline Vec3fa ceil ( const Vec3fa& a ) { return _mm_round_ps(a.m128, _MM_FROUND_TO_POS_INF ); } +#else + __forceinline Vec3fa trunc( const Vec3fa& a ) { return Vec3fa(truncf(a.x),truncf(a.y),truncf(a.z)); } + __forceinline Vec3fa floor( const Vec3fa& a ) { return Vec3fa(floorf(a.x),floorf(a.y),floorf(a.z)); } + __forceinline Vec3fa ceil ( const Vec3fa& a ) { return Vec3fa(ceilf (a.x),ceilf (a.y),ceilf (a.z)); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator<<(embree_ostream cout, const Vec3fa& a) { + return cout << "(" << a.x << ", " << a.y << ", " << a.z << ")"; + } + + typedef Vec3fa Vec3fa_t; + + + //////////////////////////////////////////////////////////////////////////////// + /// SSE Vec3fx Type + //////////////////////////////////////////////////////////////////////////////// + + struct __aligned(16) Vec3fx + { + ALIGNED_STRUCT_(16); + + typedef float Scalar; + enum { N = 3 }; + union { + __m128 m128; + struct { float x,y,z; union { int a; unsigned u; float w; }; }; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fx( ) {} + __forceinline Vec3fx( const __m128 a ) : m128(a) {} + + __forceinline explicit Vec3fx(const Vec3fa& v) : m128(v.m128) {} + __forceinline operator Vec3fa () const { return Vec3fa(m128); } + + __forceinline explicit Vec3fx ( const Vec3& other ) { m128 = _mm_set_ps(0, other.z, other.y, other.x); } + //__forceinline Vec3fx& operator =( const Vec3& other ) { m128 = _mm_set_ps(0, other.z, other.y, other.x); return *this; } + + __forceinline Vec3fx ( const Vec3fx& other ) { m128 = other.m128; } + + __forceinline Vec3fx& operator =( const Vec3fx& other ) { m128 = other.m128; return *this; } + 
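Vec3fx shares Vec3fa's 16-byte SSE layout but exposes the spare fourth lane through the a/u/w union, so an integer payload (for example a primitive ID) can travel with the vector at no extra storage cost. A minimal sketch of that layout idea (not part of the vendored header; PackedVec is a hypothetical stand-in):

#include <cassert>
#include <cstdint>

struct PackedVec {
  float x, y, z;
  union { std::int32_t a; std::uint32_t u; float w; }; // spare lane, mirroring Vec3fx
};

int main() {
  static_assert(sizeof(PackedVec) == 16, "three floats plus one spare lane");
  PackedVec v{};
  v.x = 1.0f; v.y = 2.0f; v.z = 3.0f;
  v.a = 42;            // stash an integer payload alongside the position
  assert(v.a == 42);
  return 0;
}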
+ __forceinline explicit Vec3fx( const float a ) : m128(_mm_set1_ps(a)) {} + __forceinline Vec3fx( const float x, const float y, const float z) : m128(_mm_set_ps(0, z, y, x)) {} + + __forceinline Vec3fx( const Vec3fa& other, const int a1) { m128 = other.m128; a = a1; } + __forceinline Vec3fx( const Vec3fa& other, const unsigned a1) { m128 = other.m128; u = a1; } + __forceinline Vec3fx( const Vec3fa& other, const float w1) { +#if defined (__SSE4_1__) + m128 = _mm_insert_ps(other.m128, _mm_set_ss(w1),3 << 4); +#else + const vint4 mask(-1,-1,-1,0); + m128 = select(vboolf4(_mm_castsi128_ps(mask)),vfloat4(other.m128),vfloat4(w1)); +#endif + } + //__forceinline Vec3fx( const float x, const float y, const float z, const int a) : x(x), y(y), z(z), a(a) {} // not working properly! + //__forceinline Vec3fx( const float x, const float y, const float z, const unsigned a) : x(x), y(y), z(z), u(a) {} // not working properly! + __forceinline Vec3fx( const float x, const float y, const float z, const float w) : m128(_mm_set_ps(w, z, y, x)) {} + + //__forceinline explicit Vec3fx( const __m128i a ) : m128(_mm_cvtepi32_ps(a)) {} + + __forceinline explicit operator const vfloat4() const { return vfloat4(m128); } + __forceinline explicit operator const vint4() const { return vint4(_mm_cvtps_epi32(m128)); } + __forceinline explicit operator const Vec2fa() const { return Vec2fa(m128); } + __forceinline explicit operator const Vec3ia() const { return Vec3ia(_mm_cvtps_epi32(m128)); } + + //__forceinline operator const __m128&() const { return m128; } + //__forceinline operator __m128&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline Vec3fx load( const void* const a ) { + return Vec3fx(_mm_and_ps(_mm_load_ps((float*)a),_mm_castsi128_ps(_mm_set_epi32(0, -1, -1, -1)))); + } + + static __forceinline Vec3fx loadu( const void* const a ) { + return Vec3fx(_mm_loadu_ps((float*)a)); + } + + static __forceinline void storeu ( void* ptr, const Vec3fx& v ) { + _mm_storeu_ps((float*)ptr,v.m128); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fx( ZeroTy ) : m128(_mm_setzero_ps()) {} + __forceinline Vec3fx( OneTy ) : m128(_mm_set1_ps(1.0f)) {} + __forceinline Vec3fx( PosInfTy ) : m128(_mm_set1_ps(pos_inf)) {} + __forceinline Vec3fx( NegInfTy ) : m128(_mm_set1_ps(neg_inf)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const float& operator []( const size_t index ) const { assert(index < 3); return (&x)[index]; } + __forceinline float& operator []( const size_t index ) { assert(index < 3); return (&x)[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fx operator +( const Vec3fx& a ) { return a; } + __forceinline Vec3fx operator -( const Vec3fx& a ) { + const __m128 mask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000)); + return _mm_xor_ps(a.m128, mask); + } + __forceinline Vec3fx abs ( const Vec3fx& a ) { + const __m128 mask = 
_mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)); + return _mm_and_ps(a.m128, mask); + } + __forceinline Vec3fx sign ( const Vec3fx& a ) { + return blendv_ps(Vec3fx(one).m128, (-Vec3fx(one)).m128, _mm_cmplt_ps (a.m128,Vec3fx(zero).m128)); + } + + __forceinline Vec3fx rcp ( const Vec3fx& a ) + { +#if defined(__AVX512VL__) + const Vec3fx r = _mm_rcp14_ps(a.m128); +#else + const Vec3fx r = _mm_rcp_ps(a.m128); +#endif + +#if defined(__AVX2__) + const Vec3fx res = _mm_mul_ps(r.m128,_mm_fnmadd_ps(r.m128, a.m128, vfloat4(2.0f))); +#else + const Vec3fx res = _mm_mul_ps(r.m128,_mm_sub_ps(vfloat4(2.0f), _mm_mul_ps(r.m128, a.m128))); + //return _mm_sub_ps(_mm_add_ps(r, r), _mm_mul_ps(_mm_mul_ps(r, r), a)); +#endif + + return res; + } + + __forceinline Vec3fx sqrt ( const Vec3fx& a ) { return _mm_sqrt_ps(a.m128); } + __forceinline Vec3fx sqr ( const Vec3fx& a ) { return _mm_mul_ps(a.m128,a.m128); } + + __forceinline Vec3fx rsqrt( const Vec3fx& a ) + { +#if defined(__AVX512VL__) + __m128 r = _mm_rsqrt14_ps(a.m128); +#else + __m128 r = _mm_rsqrt_ps(a.m128); +#endif + return _mm_add_ps(_mm_mul_ps(_mm_set1_ps(1.5f),r), _mm_mul_ps(_mm_mul_ps(_mm_mul_ps(a.m128, _mm_set1_ps(-0.5f)), r), _mm_mul_ps(r, r))); + } + + __forceinline Vec3fx zero_fix(const Vec3fx& a) { + return blendv_ps(a.m128, _mm_set1_ps(min_rcp_input), _mm_cmplt_ps (abs(a).m128, _mm_set1_ps(min_rcp_input))); + } + __forceinline Vec3fx rcp_safe(const Vec3fx& a) { + return rcp(zero_fix(a)); + } + __forceinline Vec3fx log ( const Vec3fx& a ) { + return Vec3fx(logf(a.x),logf(a.y),logf(a.z)); + } + + __forceinline Vec3fx exp ( const Vec3fx& a ) { + return Vec3fx(expf(a.x),expf(a.y),expf(a.z)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fx operator +( const Vec3fx& a, const Vec3fx& b ) { return _mm_add_ps(a.m128, b.m128); } + __forceinline Vec3fx operator -( const Vec3fx& a, const Vec3fx& b ) { return _mm_sub_ps(a.m128, b.m128); } + __forceinline Vec3fx operator *( const Vec3fx& a, const Vec3fx& b ) { return _mm_mul_ps(a.m128, b.m128); } + __forceinline Vec3fx operator *( const Vec3fx& a, const float b ) { return a * Vec3fx(b); } + __forceinline Vec3fx operator *( const float a, const Vec3fx& b ) { return Vec3fx(a) * b; } + __forceinline Vec3fx operator /( const Vec3fx& a, const Vec3fx& b ) { return _mm_div_ps(a.m128,b.m128); } + __forceinline Vec3fx operator /( const Vec3fx& a, const float b ) { return _mm_div_ps(a.m128,_mm_set1_ps(b)); } + __forceinline Vec3fx operator /( const float a, const Vec3fx& b ) { return _mm_div_ps(_mm_set1_ps(a),b.m128); } + + __forceinline Vec3fx min( const Vec3fx& a, const Vec3fx& b ) { return _mm_min_ps(a.m128,b.m128); } + __forceinline Vec3fx max( const Vec3fx& a, const Vec3fx& b ) { return _mm_max_ps(a.m128,b.m128); } + +#if defined(__SSE4_1__) + __forceinline Vec3fx mini(const Vec3fx& a, const Vec3fx& b) { + const vint4 ai = _mm_castps_si128(a.m128); + const vint4 bi = _mm_castps_si128(b.m128); + const vint4 ci = _mm_min_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } +#endif + +#if defined(__SSE4_1__) + __forceinline Vec3fx maxi(const Vec3fx& a, const Vec3fx& b) { + const vint4 ai = _mm_castps_si128(a.m128); + const vint4 bi = _mm_castps_si128(b.m128); + const vint4 ci = _mm_max_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } +#endif + + __forceinline Vec3fx pow ( const Vec3fx& a, const float& b ) { + return 
Vec3fx(powf(a.x,b),powf(a.y,b),powf(a.z,b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX2__) + __forceinline Vec3fx madd ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return _mm_fmadd_ps(a.m128,b.m128,c.m128); } + __forceinline Vec3fx msub ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return _mm_fmsub_ps(a.m128,b.m128,c.m128); } + __forceinline Vec3fx nmadd ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return _mm_fnmadd_ps(a.m128,b.m128,c.m128); } + __forceinline Vec3fx nmsub ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return _mm_fnmsub_ps(a.m128,b.m128,c.m128); } +#else + __forceinline Vec3fx madd ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return a*b+c; } + __forceinline Vec3fx msub ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return a*b-c; } + __forceinline Vec3fx nmadd ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return -a*b+c;} + __forceinline Vec3fx nmsub ( const Vec3fx& a, const Vec3fx& b, const Vec3fx& c) { return -a*b-c; } +#endif + + __forceinline Vec3fx madd ( const float a, const Vec3fx& b, const Vec3fx& c) { return madd(Vec3fx(a),b,c); } + __forceinline Vec3fx msub ( const float a, const Vec3fx& b, const Vec3fx& c) { return msub(Vec3fx(a),b,c); } + __forceinline Vec3fx nmadd ( const float a, const Vec3fx& b, const Vec3fx& c) { return nmadd(Vec3fx(a),b,c); } + __forceinline Vec3fx nmsub ( const float a, const Vec3fx& b, const Vec3fx& c) { return nmsub(Vec3fx(a),b,c); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fx& operator +=( Vec3fx& a, const Vec3fx& b ) { return a = a + b; } + __forceinline Vec3fx& operator -=( Vec3fx& a, const Vec3fx& b ) { return a = a - b; } + __forceinline Vec3fx& operator *=( Vec3fx& a, const Vec3fx& b ) { return a = a * b; } + __forceinline Vec3fx& operator *=( Vec3fx& a, const float b ) { return a = a * b; } + __forceinline Vec3fx& operator /=( Vec3fx& a, const Vec3fx& b ) { return a = a / b; } + __forceinline Vec3fx& operator /=( Vec3fx& a, const float b ) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline float reduce_add(const Vec3fx& v) { + const vfloat4 a(v.m128); + const vfloat4 b = shuffle<1>(a); + const vfloat4 c = shuffle<2>(a); + return _mm_cvtss_f32(a+b+c); + } + + __forceinline float reduce_mul(const Vec3fx& v) { return v.x*v.y*v.z; } + __forceinline float reduce_min(const Vec3fx& v) { return min(v.x,v.y,v.z); } + __forceinline float reduce_max(const Vec3fx& v) { return max(v.x,v.y,v.z); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator ==( const Vec3fx& a, const Vec3fx& b ) { return (_mm_movemask_ps(_mm_cmpeq_ps (a.m128, b.m128)) & 7) == 7; } + __forceinline bool operator !=( const Vec3fx& a, const Vec3fx& b ) { return (_mm_movemask_ps(_mm_cmpneq_ps(a.m128, b.m128)) & 7) != 0; } + + __forceinline Vec3ba eq_mask( const Vec3fx& a, const Vec3fx& b ) { return _mm_cmpeq_ps (a.m128, 
b.m128); } + __forceinline Vec3ba neq_mask(const Vec3fx& a, const Vec3fx& b ) { return _mm_cmpneq_ps(a.m128, b.m128); } + __forceinline Vec3ba lt_mask( const Vec3fx& a, const Vec3fx& b ) { return _mm_cmplt_ps (a.m128, b.m128); } + __forceinline Vec3ba le_mask( const Vec3fx& a, const Vec3fx& b ) { return _mm_cmple_ps (a.m128, b.m128); } + __forceinline Vec3ba gt_mask( const Vec3fx& a, const Vec3fx& b ) { return _mm_cmpnle_ps(a.m128, b.m128); } + __forceinline Vec3ba ge_mask( const Vec3fx& a, const Vec3fx& b ) { return _mm_cmpnlt_ps(a.m128, b.m128); } + + __forceinline bool isvalid ( const Vec3fx& v ) { + return all(gt_mask(v,Vec3fx(-FLT_LARGE)) & lt_mask(v,Vec3fx(+FLT_LARGE))); + } + + __forceinline bool is_finite ( const Vec3fx& a ) { + return all(ge_mask(a,Vec3fx(-FLT_MAX)) & le_mask(a,Vec3fx(+FLT_MAX))); + } + + __forceinline bool isvalid4 ( const Vec3fx& v ) { + return all((vfloat4(v.m128) > vfloat4(-FLT_LARGE)) & (vfloat4(v.m128) < vfloat4(+FLT_LARGE))); + } + + __forceinline bool is_finite4 ( const Vec3fx& a ) { + return all((vfloat4(a.m128) >= vfloat4(-FLT_MAX)) & (vfloat4(a.m128) <= vfloat4(+FLT_MAX))); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__SSE4_1__) + __forceinline float dot ( const Vec3fx& a, const Vec3fx& b ) { + return _mm_cvtss_f32(_mm_dp_ps(a.m128,b.m128,0x7F)); + } +#else + __forceinline float dot ( const Vec3fx& a, const Vec3fx& b ) { + return reduce_add(a*b); + } +#endif + + __forceinline Vec3fx cross ( const Vec3fx& a, const Vec3fx& b ) + { + vfloat4 a0 = vfloat4(a.m128); + vfloat4 b0 = shuffle<1,2,0,3>(vfloat4(b.m128)); + vfloat4 a1 = shuffle<1,2,0,3>(vfloat4(a.m128)); + vfloat4 b1 = vfloat4(b.m128); + return Vec3fx(shuffle<1,2,0,3>(msub(a0,b0,a1*b1))); + } + + __forceinline float sqr_length ( const Vec3fx& a ) { return dot(a,a); } + __forceinline float rcp_length ( const Vec3fx& a ) { return rsqrt(dot(a,a)); } + __forceinline float rcp_length2( const Vec3fx& a ) { return rcp(dot(a,a)); } + __forceinline float length ( const Vec3fx& a ) { return sqrt(dot(a,a)); } + __forceinline Vec3fx normalize( const Vec3fx& a ) { return a*rsqrt(dot(a,a)); } + __forceinline float distance ( const Vec3fx& a, const Vec3fx& b ) { return length(a-b); } + __forceinline float halfArea ( const Vec3fx& d ) { return madd(d.x,(d.y+d.z),d.y*d.z); } + __forceinline float area ( const Vec3fx& d ) { return 2.0f*halfArea(d); } + + __forceinline Vec3fx normalize_safe( const Vec3fx& a ) { + const float d = dot(a,a); if (unlikely(d == 0.0f)) return a; else return a*rsqrt(d); + } + + /*! differentiated normalization */ + __forceinline Vec3fx dnormalize(const Vec3fx& p, const Vec3fx& dp) + { + const float pp = dot(p,p); + const float pdp = dot(p,dp); + return (pp*dp-pdp*p)*rcp(pp)*rsqrt(pp); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3fx select( bool s, const Vec3fx& t, const Vec3fx& f ) { + __m128 mask = s ? 
_mm_castsi128_ps(_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128())) : _mm_setzero_ps(); + return blendv_ps(f.m128, t.m128, mask); + } + + __forceinline Vec3fx select( const Vec3ba& s, const Vec3fx& t, const Vec3fx& f ) { + return blendv_ps(f.m128, t.m128, s); + } + + __forceinline Vec3fx lerp(const Vec3fx& v0, const Vec3fx& v1, const float t) { + return madd(1.0f-t,v0,t*v1); + } + + __forceinline int maxDim ( const Vec3fx& a ) + { + const Vec3fx b = abs(a); + if (b.x > b.y) { + if (b.x > b.z) return 0; else return 2; + } else { + if (b.y > b.z) return 1; else return 2; + } + } + + //////////////////////////////////////////////////////////////////////////////// + /// Rounding Functions + //////////////////////////////////////////////////////////////////////////////// + +#if defined (__SSE4_1__) + __forceinline Vec3fx trunc( const Vec3fx& a ) { return _mm_round_ps(a.m128, _MM_FROUND_TO_NEAREST_INT); } + __forceinline Vec3fx floor( const Vec3fx& a ) { return _mm_round_ps(a.m128, _MM_FROUND_TO_NEG_INF ); } + __forceinline Vec3fx ceil ( const Vec3fx& a ) { return _mm_round_ps(a.m128, _MM_FROUND_TO_POS_INF ); } +#else + __forceinline Vec3fx trunc( const Vec3fx& a ) { return Vec3fx(truncf(a.x),truncf(a.y),truncf(a.z)); } + __forceinline Vec3fx floor( const Vec3fx& a ) { return Vec3fx(floorf(a.x),floorf(a.y),floorf(a.z)); } + __forceinline Vec3fx ceil ( const Vec3fx& a ) { return Vec3fx(ceilf (a.x),ceilf (a.y),ceilf (a.z)); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator<<(embree_ostream cout, const Vec3fx& a) { + return cout << "(" << a.x << ", " << a.y << ", " << a.z << ")"; + } + + + typedef Vec3fx Vec3ff; +} diff --git a/thirdparty/embree/common/math/vec3ia.h b/thirdparty/embree/common/math/vec3ia.h new file mode 100644 index 000000000000..e1c997299400 --- /dev/null +++ b/thirdparty/embree/common/math/vec3ia.h @@ -0,0 +1,186 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/alloc.h" +#include "math.h" +#include "../simd/sse.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// SSE Vec3ia Type + //////////////////////////////////////////////////////////////////////////////// + + struct __aligned(16) Vec3ia + { + ALIGNED_STRUCT_(16); + + union { + __m128i m128; + struct { int x,y,z; }; + }; + + typedef int Scalar; + enum { N = 3 }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ia( ) {} + __forceinline Vec3ia( const __m128i a ) : m128(a) {} + __forceinline Vec3ia( const Vec3ia& other ) : m128(other.m128) {} + __forceinline Vec3ia& operator =(const Vec3ia& other) { m128 = other.m128; return *this; } + + __forceinline explicit Vec3ia( const int a ) : m128(_mm_set1_epi32(a)) {} + __forceinline Vec3ia( const int x, const int y, const int z) : m128(_mm_set_epi32(z, z, y, x)) {} + __forceinline explicit Vec3ia( const __m128 a ) : m128(_mm_cvtps_epi32(a)) {} + + __forceinline operator const __m128i&() const { return m128; } + __forceinline operator __m128i&() { return m128; } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ia( ZeroTy ) : m128(_mm_setzero_si128()) {} + __forceinline Vec3ia( OneTy ) : m128(_mm_set1_epi32(1)) {} + __forceinline Vec3ia( PosInfTy ) : m128(_mm_set1_epi32(pos_inf)) {} + __forceinline Vec3ia( NegInfTy ) : m128(_mm_set1_epi32(neg_inf)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const int& operator []( const size_t index ) const { assert(index < 3); return (&x)[index]; } + __forceinline int& operator []( const size_t index ) { assert(index < 3); return (&x)[index]; } + }; + + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ia operator +( const Vec3ia& a ) { return a; } + __forceinline Vec3ia operator -( const Vec3ia& a ) { return _mm_sub_epi32(_mm_setzero_si128(), a.m128); } +#if defined(__SSSE3__) + __forceinline Vec3ia abs ( const Vec3ia& a ) { return _mm_abs_epi32(a.m128); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ia operator +( const Vec3ia& a, const Vec3ia& b ) { return _mm_add_epi32(a.m128, b.m128); } + __forceinline Vec3ia operator +( const Vec3ia& a, const int b ) { return a+Vec3ia(b); } + __forceinline Vec3ia operator +( const int a, const Vec3ia& b ) { return Vec3ia(a)+b; } + + __forceinline Vec3ia operator -( const Vec3ia& a, const Vec3ia& b ) { return _mm_sub_epi32(a.m128, b.m128); } + __forceinline Vec3ia operator -( const Vec3ia& a, const int b ) { return a-Vec3ia(b); } + __forceinline Vec3ia operator -( const int a, const Vec3ia& b ) { return Vec3ia(a)-b; } + +#if defined(__SSE4_1__) + __forceinline Vec3ia operator *( const Vec3ia& a, const Vec3ia& b ) { return _mm_mullo_epi32(a.m128, b.m128); } + __forceinline Vec3ia operator *( const Vec3ia& a, const int b ) { return a * Vec3ia(b); } + __forceinline Vec3ia operator *( const int a, const Vec3ia& b ) { return Vec3ia(a) * b; } +#endif + + __forceinline Vec3ia operator &( const Vec3ia& a, const Vec3ia& b ) { return _mm_and_si128(a.m128, b.m128); } + __forceinline Vec3ia operator &( const Vec3ia& a, const int b ) { return a & Vec3ia(b); } + __forceinline Vec3ia operator &( const int a, const Vec3ia& b ) { return Vec3ia(a) & b; } + + __forceinline Vec3ia operator |( const Vec3ia& a, const Vec3ia& b ) { return _mm_or_si128(a.m128, b.m128); } + __forceinline Vec3ia operator |( const Vec3ia& a, const int b ) { return a | Vec3ia(b); } + __forceinline Vec3ia operator |( const int a, const Vec3ia& b ) { return Vec3ia(a) | b; } + + __forceinline Vec3ia operator ^( const Vec3ia& a, const Vec3ia& b ) { return _mm_xor_si128(a.m128, b.m128); } + __forceinline Vec3ia operator ^( const Vec3ia& a, const int b ) { return a ^ Vec3ia(b); } + __forceinline Vec3ia operator ^( const int a, const Vec3ia& b ) { return Vec3ia(a) ^ b; } + + __forceinline Vec3ia operator <<( const Vec3ia& a, const int n ) { return _mm_slli_epi32(a.m128, n); } + __forceinline Vec3ia operator >>( const Vec3ia& a, const int n ) { return _mm_srai_epi32(a.m128, n); } + + __forceinline Vec3ia sll ( const Vec3ia& a, const int b ) { return _mm_slli_epi32(a.m128, b); } + 
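sll above and the sra/srl helpers that follow map directly to _mm_slli_epi32, _mm_srai_epi32 and _mm_srli_epi32: the arithmetic right shift keeps the sign bit while the logical right shift fills with zeros. A scalar sketch of that difference (not part of the vendored header):

#include <cassert>
#include <cstdint>

int main() {
  std::int32_t x = -8;                                          // bit pattern 0xFFFFFFF8
  std::int32_t arith = x >> 1;                                  // sign-extending; defined as arithmetic since C++20,
                                                                // and arithmetic on mainstream compilers before that
  std::uint32_t logical = static_cast<std::uint32_t>(x) >> 1;   // zero-filling
  assert(arith == -4);
  assert(logical == 0x7FFFFFFCu);
  return 0;
}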
__forceinline Vec3ia sra ( const Vec3ia& a, const int b ) { return _mm_srai_epi32(a.m128, b); } + __forceinline Vec3ia srl ( const Vec3ia& a, const int b ) { return _mm_srli_epi32(a.m128, b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ia& operator +=( Vec3ia& a, const Vec3ia& b ) { return a = a + b; } + __forceinline Vec3ia& operator +=( Vec3ia& a, const int& b ) { return a = a + b; } + + __forceinline Vec3ia& operator -=( Vec3ia& a, const Vec3ia& b ) { return a = a - b; } + __forceinline Vec3ia& operator -=( Vec3ia& a, const int& b ) { return a = a - b; } + +#if defined(__SSE4_1__) + __forceinline Vec3ia& operator *=( Vec3ia& a, const Vec3ia& b ) { return a = a * b; } + __forceinline Vec3ia& operator *=( Vec3ia& a, const int& b ) { return a = a * b; } +#endif + + __forceinline Vec3ia& operator &=( Vec3ia& a, const Vec3ia& b ) { return a = a & b; } + __forceinline Vec3ia& operator &=( Vec3ia& a, const int& b ) { return a = a & b; } + + __forceinline Vec3ia& operator |=( Vec3ia& a, const Vec3ia& b ) { return a = a | b; } + __forceinline Vec3ia& operator |=( Vec3ia& a, const int& b ) { return a = a | b; } + + __forceinline Vec3ia& operator <<=( Vec3ia& a, const int& b ) { return a = a << b; } + __forceinline Vec3ia& operator >>=( Vec3ia& a, const int& b ) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int reduce_add(const Vec3ia& v) { return v.x+v.y+v.z; } + __forceinline int reduce_mul(const Vec3ia& v) { return v.x*v.y*v.z; } + __forceinline int reduce_min(const Vec3ia& v) { return min(v.x,v.y,v.z); } + __forceinline int reduce_max(const Vec3ia& v) { return max(v.x,v.y,v.z); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator ==( const Vec3ia& a, const Vec3ia& b ) { return (_mm_movemask_ps(_mm_castsi128_ps(_mm_cmpeq_epi32(a.m128, b.m128))) & 7) == 7; } + __forceinline bool operator !=( const Vec3ia& a, const Vec3ia& b ) { return (_mm_movemask_ps(_mm_castsi128_ps(_mm_cmpeq_epi32(a.m128, b.m128))) & 7) != 7; } + __forceinline bool operator < ( const Vec3ia& a, const Vec3ia& b ) { + if (a.x != b.x) return a.x < b.x; + if (a.y != b.y) return a.y < b.y; + if (a.z != b.z) return a.z < b.z; + return false; + } + + __forceinline Vec3ba eq_mask( const Vec3ia& a, const Vec3ia& b ) { return _mm_castsi128_ps(_mm_cmpeq_epi32 (a.m128, b.m128)); } + __forceinline Vec3ba lt_mask( const Vec3ia& a, const Vec3ia& b ) { return _mm_castsi128_ps(_mm_cmplt_epi32 (a.m128, b.m128)); } + __forceinline Vec3ba gt_mask( const Vec3ia& a, const Vec3ia& b ) { return _mm_castsi128_ps(_mm_cmpgt_epi32 (a.m128, b.m128)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3ia select( const Vec3ba& m, const Vec3ia& t, const Vec3ia& f ) { +#if defined(__SSE4_1__) + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(f), _mm_castsi128_ps(t), m)); +#else + return _mm_or_si128(_mm_and_si128(_mm_castps_si128(m), t), _mm_andnot_si128(_mm_castps_si128(m), f)); 
+#endif + } + +#if defined(__SSE4_1__) + __forceinline Vec3ia min( const Vec3ia& a, const Vec3ia& b ) { return _mm_min_epi32(a.m128,b.m128); } + __forceinline Vec3ia max( const Vec3ia& a, const Vec3ia& b ) { return _mm_max_epi32(a.m128,b.m128); } +#else + __forceinline Vec3ia min( const Vec3ia& a, const Vec3ia& b ) { return select(lt_mask(a,b),a,b); } + __forceinline Vec3ia max( const Vec3ia& a, const Vec3ia& b ) { return select(gt_mask(a,b),a,b); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator<<(embree_ostream cout, const Vec3ia& a) { + return cout << "(" << a.x << ", " << a.y << ", " << a.z << ")"; + } +} diff --git a/thirdparty/embree/common/math/vec4.h b/thirdparty/embree/common/math/vec4.h new file mode 100644 index 000000000000..3354b443178a --- /dev/null +++ b/thirdparty/embree/common/math/vec4.h @@ -0,0 +1,258 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "math.h" +#include "vec3.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// Generic 4D vector Class + //////////////////////////////////////////////////////////////////////////////// + + template struct Vec4 + { + enum { N = 4 }; + union { + struct { T x, y, z, w; }; +#if !(defined(__WIN32__) && _MSC_VER == 1800) // workaround for older VS 2013 compiler + T components[N]; +#endif + }; + + typedef T Scalar; + + //////////////////////////////////////////////////////////////////////////////// + /// Construction + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec4( ) {} + __forceinline explicit Vec4( const T& a ) : x(a), y(a), z(a), w(a) {} + __forceinline Vec4( const T& x, const T& y, const T& z, const T& w ) : x(x), y(y), z(z), w(w) {} + __forceinline Vec4( const Vec3& xyz, const T& w ) : x(xyz.x), y(xyz.y), z(xyz.z), w(w) {} + + __forceinline Vec4( const Vec4& other ) { x = other.x; y = other.y; z = other.z; w = other.w; } + __forceinline Vec4( const Vec3fx& other ); + + template __forceinline Vec4( const Vec4& a ) : x(T(a.x)), y(T(a.y)), z(T(a.z)), w(T(a.w)) {} + template __forceinline Vec4& operator =(const Vec4& other) { x = other.x; y = other.y; z = other.z; w = other.w; return *this; } + + __forceinline Vec4& operator =(const Vec4& other) { x = other.x; y = other.y; z = other.z; w = other.w; return *this; } + + __forceinline operator Vec3 () const { return Vec3(x,y,z); } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec4( ZeroTy ) : x(zero), y(zero), z(zero), w(zero) {} + __forceinline Vec4( OneTy ) : x(one), y(one), z(one), w(one) {} + __forceinline Vec4( PosInfTy ) : x(pos_inf), y(pos_inf), z(pos_inf), w(pos_inf) {} + __forceinline Vec4( NegInfTy ) : x(neg_inf), y(neg_inf), z(neg_inf), w(neg_inf) {} + +#if defined(__WIN32__) && (_MSC_VER == 1800) // workaround for older VS 2013 compiler + __forceinline const T& operator [](const size_t axis) const { assert(axis < 4); return (&x)[axis]; } + __forceinline T& operator [](const size_t axis) { assert(axis < 4); return (&x)[axis]; } +#else + __forceinline const T& operator [](const size_t axis ) const { assert(axis < 4); return components[axis]; } + 
__forceinline T& operator [](const size_t axis) { assert(axis < 4); return components[axis]; } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Swizzles + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Vec3 xyz() const { return Vec3(x, y, z); } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec4 operator +( const Vec4& a ) { return Vec4(+a.x, +a.y, +a.z, +a.w); } + template __forceinline Vec4 operator -( const Vec4& a ) { return Vec4(-a.x, -a.y, -a.z, -a.w); } + template __forceinline Vec4 abs ( const Vec4& a ) { return Vec4(abs (a.x), abs (a.y), abs (a.z), abs (a.w)); } + template __forceinline Vec4 rcp ( const Vec4& a ) { return Vec4(rcp (a.x), rcp (a.y), rcp (a.z), rcp (a.w)); } + template __forceinline Vec4 rsqrt ( const Vec4& a ) { return Vec4(rsqrt(a.x), rsqrt(a.y), rsqrt(a.z), rsqrt(a.w)); } + template __forceinline Vec4 sqrt ( const Vec4& a ) { return Vec4(sqrt (a.x), sqrt (a.y), sqrt (a.z), sqrt (a.w)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec4 operator +( const Vec4& a, const Vec4& b ) { return Vec4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } + template __forceinline Vec4 operator -( const Vec4& a, const Vec4& b ) { return Vec4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } + template __forceinline Vec4 operator *( const Vec4& a, const Vec4& b ) { return Vec4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } + template __forceinline Vec4 operator *( const T& a, const Vec4& b ) { return Vec4(a * b.x, a * b.y, a * b.z, a * b.w); } + template __forceinline Vec4 operator *( const Vec4& a, const T& b ) { return Vec4(a.x * b , a.y * b , a.z * b , a.w * b ); } + template __forceinline Vec4 operator /( const Vec4& a, const Vec4& b ) { return Vec4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); } + template __forceinline Vec4 operator /( const Vec4& a, const T& b ) { return Vec4(a.x / b , a.y / b , a.z / b , a.w / b ); } + template __forceinline Vec4 operator /( const T& a, const Vec4& b ) { return Vec4(a / b.x, a / b.y, a / b.z, a / b.w); } + + template __forceinline Vec4 min(const Vec4& a, const Vec4& b) { return Vec4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); } + template __forceinline Vec4 max(const Vec4& a, const Vec4& b) { return Vec4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec4 madd ( const Vec4& a, const Vec4& b, const Vec4& c) { return Vec4( madd(a.x,b.x,c.x), madd(a.y,b.y,c.y), madd(a.z,b.z,c.z), madd(a.w,b.w,c.w)); } + template __forceinline Vec4 msub ( const Vec4& a, const Vec4& b, const Vec4& c) { return Vec4( msub(a.x,b.x,c.x), msub(a.y,b.y,c.y), msub(a.z,b.z,c.z), msub(a.w,b.w,c.w)); } + template __forceinline Vec4 nmadd ( const Vec4& a, const Vec4& b, const Vec4& c) { return Vec4(nmadd(a.x,b.x,c.x),nmadd(a.y,b.y,c.y),nmadd(a.z,b.z,c.z),nmadd(a.w,b.w,c.w)); } + template __forceinline Vec4 nmsub ( const Vec4& a, const Vec4& b, const Vec4& c) { return 
Vec4(nmsub(a.x,b.x,c.x),nmsub(a.y,b.y,c.y),nmsub(a.z,b.z,c.z),nmsub(a.w,b.w,c.w)); } + + template __forceinline Vec4 madd ( const T& a, const Vec4& b, const Vec4& c) { return Vec4( madd(a,b.x,c.x), madd(a,b.y,c.y), madd(a,b.z,c.z), madd(a,b.w,c.w)); } + template __forceinline Vec4 msub ( const T& a, const Vec4& b, const Vec4& c) { return Vec4( msub(a,b.x,c.x), msub(a,b.y,c.y), msub(a,b.z,c.z), msub(a,b.w,c.w)); } + template __forceinline Vec4 nmadd ( const T& a, const Vec4& b, const Vec4& c) { return Vec4(nmadd(a,b.x,c.x),nmadd(a,b.y,c.y),nmadd(a,b.z,c.z),nmadd(a,b.w,c.w)); } + template __forceinline Vec4 nmsub ( const T& a, const Vec4& b, const Vec4& c) { return Vec4(nmsub(a,b.x,c.x),nmsub(a,b.y,c.y),nmsub(a,b.z,c.z),nmsub(a,b.w,c.w)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec4& operator +=( Vec4& a, const Vec4& b ) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } + template __forceinline Vec4& operator -=( Vec4& a, const Vec4& b ) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } + template __forceinline Vec4& operator *=( Vec4& a, const T& b ) { a.x *= b ; a.y *= b ; a.z *= b ; a.w *= b ; return a; } + template __forceinline Vec4& operator /=( Vec4& a, const T& b ) { a.x /= b ; a.y /= b ; a.z /= b ; a.w /= b ; return a; } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline T reduce_add( const Vec4& a ) { return a.x + a.y + a.z + a.w; } + template __forceinline T reduce_mul( const Vec4& a ) { return a.x * a.y * a.z * a.w; } + template __forceinline T reduce_min( const Vec4& a ) { return min(a.x, a.y, a.z, a.w); } + template __forceinline T reduce_max( const Vec4& a ) { return max(a.x, a.y, a.z, a.w); } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline bool operator ==( const Vec4& a, const Vec4& b ) { return a.x == b.x && a.y == b.y && a.z == b.z && a.w == b.w; } + template __forceinline bool operator !=( const Vec4& a, const Vec4& b ) { return a.x != b.x || a.y != b.y || a.z != b.z || a.w != b.w; } + template __forceinline bool operator < ( const Vec4& a, const Vec4& b ) { + if (a.x != b.x) return a.x < b.x; + if (a.y != b.y) return a.y < b.y; + if (a.z != b.z) return a.z < b.z; + if (a.w != b.w) return a.w < b.w; + return false; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Shift Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec4 shift_right_1( const Vec4& a ) { + return Vec4(shift_right_1(a.x),shift_right_1(a.y),shift_right_1(a.z),shift_right_1(a.w)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline T dot ( const Vec4& a, const Vec4& b ) { return madd(a.x,b.x,madd(a.y,b.y,madd(a.z,b.z,a.w*b.w))); } + + template __forceinline T length ( const Vec4& a ) { return sqrt(dot(a,a)); } + template __forceinline Vec4 normalize( const Vec4& a ) { return 
a*rsqrt(dot(a,a)); } + template __forceinline T distance ( const Vec4& a, const Vec4& b ) { return length(a-b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline Vec4 select ( bool s, const Vec4& t, const Vec4& f ) { + return Vec4(select(s,t.x,f.x),select(s,t.y,f.y),select(s,t.z,f.z),select(s,t.w,f.w)); + } + + template __forceinline Vec4 select ( const Vec4& s, const Vec4& t, const Vec4& f ) { + return Vec4(select(s.x,t.x,f.x),select(s.y,t.y,f.y),select(s.z,t.z,f.z),select(s.w,t.w,f.w)); + } + + template __forceinline Vec4 select ( const typename T::Bool& s, const Vec4& t, const Vec4& f ) { + return Vec4(select(s,t.x,f.x),select(s,t.y,f.y),select(s,t.z,f.z),select(s,t.w,f.w)); + } + + template + __forceinline Vec4 lerp(const Vec4& v0, const Vec4& v1, const T& t) { + return madd(Vec4(T(1.0f)-t),v0,t*v1); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + template __forceinline embree_ostream operator<<(embree_ostream cout, const Vec4& a) { + return cout << "(" << a.x << ", " << a.y << ", " << a.z << ", " << a.w << ")"; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Default template instantiations + //////////////////////////////////////////////////////////////////////////////// + + typedef Vec4 Vec4b; + typedef Vec4 Vec4uc; + typedef Vec4 Vec4i; + typedef Vec4 Vec4f; +} + +#include "vec3ba.h" +#include "vec3ia.h" +#include "vec3fa.h" + +//////////////////////////////////////////////////////////////////////////////// +/// SSE / AVX / MIC specializations +//////////////////////////////////////////////////////////////////////////////// + +#if defined __SSE__ +#include "../simd/sse.h" +#endif + +#if defined __AVX__ +#include "../simd/avx.h" +#endif + +#if defined __AVX512F__ +#include "../simd/avx512.h" +#endif + +namespace embree +{ + template<> __forceinline Vec4::Vec4( const Vec3fx& a ) { x = a.x; y = a.y; z = a.z; w = a.w; } + +#if defined(__AVX__) + template<> __forceinline Vec4::Vec4( const Vec3fx& a ) { + x = a.x; y = a.y; z = a.z; w = a.w; + } +#elif defined(__SSE__) + template<> __forceinline Vec4::Vec4( const Vec3fx& a ) { + const vfloat4 v = vfloat4(a.m128); x = shuffle<0,0,0,0>(v); y = shuffle<1,1,1,1>(v); z = shuffle<2,2,2,2>(v); w = shuffle<3,3,3,3>(v); + } +#endif + +#if defined(__SSE__) + __forceinline Vec4 broadcast4f( const Vec4& a, const size_t k ) { + return Vec4(vfloat4::broadcast(&a.x[k]), vfloat4::broadcast(&a.y[k]), vfloat4::broadcast(&a.z[k]), vfloat4::broadcast(&a.w[k])); + } +#endif + +#if defined(__AVX__) + template<> __forceinline Vec4::Vec4( const Vec3fx& a ) { + x = a.x; y = a.y; z = a.z; w = a.w; + } + __forceinline Vec4 broadcast4f( const Vec4& a, const size_t k ) { + return Vec4(vfloat4::broadcast(&a.x[k]), vfloat4::broadcast(&a.y[k]), vfloat4::broadcast(&a.z[k]), vfloat4::broadcast(&a.w[k])); + } + __forceinline Vec4 broadcast8f( const Vec4& a, const size_t k ) { + return Vec4(vfloat8::broadcast(&a.x[k]), vfloat8::broadcast(&a.y[k]), vfloat8::broadcast(&a.z[k]), vfloat8::broadcast(&a.w[k])); + } + __forceinline Vec4 broadcast8f( const Vec4& a, const size_t k ) { + return Vec4(vfloat8::broadcast(&a.x[k]), vfloat8::broadcast(&a.y[k]), vfloat8::broadcast(&a.z[k]), vfloat8::broadcast(&a.w[k])); + } +#endif + +#if 
defined(__AVX512F__) + template<> __forceinline Vec4::Vec4( const Vec3fx& a ) : x(a.x), y(a.y), z(a.z), w(a.w) {} +#endif +} diff --git a/thirdparty/embree/common/simd/avx.h b/thirdparty/embree/common/simd/avx.h new file mode 100644 index 000000000000..c840e418058a --- /dev/null +++ b/thirdparty/embree/common/simd/avx.h @@ -0,0 +1,34 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "sse.h" + +#if defined(__AVX512VL__) +#include "vboolf8_avx512.h" +#include "vboold4_avx512.h" +#else +#include "vboolf8_avx.h" +#include "vboold4_avx.h" +#endif + +#if defined(__AVX2__) +#include "vint8_avx2.h" +#include "vuint8_avx2.h" +#if defined(__X86_64__) +#include "vllong4_avx2.h" +#endif +#else +#include "vint8_avx.h" +#include "vuint8_avx.h" +#endif +#include "vfloat8_avx.h" +#if defined(__X86_64__) +#include "vdouble4_avx.h" +#endif + +#if defined(__AVX512F__) +#include "avx512.h" +#endif + diff --git a/thirdparty/embree/common/simd/avx512.h b/thirdparty/embree/common/simd/avx512.h new file mode 100644 index 000000000000..25414ab5b1db --- /dev/null +++ b/thirdparty/embree/common/simd/avx512.h @@ -0,0 +1,41 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../sys/intrinsics.h" +#include "../math/constants.h" +#include "../sys/alloc.h" +#include "varying.h" + +#include "vboolf16_avx512.h" +#include "vint16_avx512.h" +#include "vuint16_avx512.h" +#include "vfloat16_avx512.h" + +#include "vboold8_avx512.h" +#include "vllong8_avx512.h" +#include "vdouble8_avx512.h" + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// Prefetching + //////////////////////////////////////////////////////////////////////////////// + +#define PFHINT_L1 0 +#define PFHINT_L2 1 +#define PFHINT_NT 2 + + template + __forceinline void prefetch(const void * __restrict__ const m) + { + if (mode == PFHINT_L1) + _mm_prefetch((const char*)m,_MM_HINT_T0); + else if (mode == PFHINT_L2) + _mm_prefetch((const char*)m,_MM_HINT_T1); + else if (mode == PFHINT_NT) + _mm_prefetch((const char*)m,_MM_HINT_NTA); + } +} diff --git a/thirdparty/embree/common/simd/simd.h b/thirdparty/embree/common/simd/simd.h new file mode 100644 index 000000000000..c1351c2c8884 --- /dev/null +++ b/thirdparty/embree/common/simd/simd.h @@ -0,0 +1,110 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../math/math.h" + +/* include SSE wrapper classes */ +#if defined(__SSE__) +# include "sse.h" +#endif + +/* include AVX wrapper classes */ +#if defined(__AVX__) +# include "avx.h" +#endif + +/* include AVX512 wrapper classes */ +#if defined (__AVX512F__) +# include "avx512.h" +#endif + +namespace embree +{ + template + __forceinline vbool isfinite(const vfloat& v) + { + return (v >= vfloat(-std::numeric_limits::max())) + & (v <= vfloat( std::numeric_limits::max())); + } + + /* foreach unique */ + template + __forceinline void foreach_unique(const vbool& valid0, const vint& vi, const Closure& closure) + { + vbool valid1 = valid0; + while (any(valid1)) { + const int j = int(bsf(movemask(valid1))); + const int i = vi[j]; + const vbool valid2 = valid1 & (i == vi); + valid1 = andn(valid1, valid2); + closure(valid2, i); + } + } + + /* returns the next unique value i in vi and the corresponding valid_i mask */ + template + __forceinline int next_unique(vbool& valid, const vint& vi, /*out*/ vbool& valid_i) + 
+    {
+      assert(any(valid));
+      const int j = int(bsf(movemask(valid)));
+      const int i = vi[j];
+      valid_i = valid & (i == vi);
+      valid = andn(valid, valid_i);
+      return i;
+    }
+
+    /* foreach unique index */
+    template<typename vbool, typename vint, typename Closure>
+    __forceinline void foreach_unique_index(const vbool& valid0, const vint& vi, const Closure& closure)
+    {
+      vbool valid1 = valid0;
+      while (any(valid1)) {
+        const int j = int(bsf(movemask(valid1)));
+        const int i = vi[j];
+        const vbool valid2 = valid1 & (i == vi);
+        valid1 = andn(valid1, valid2);
+        closure(valid2, i, j);
+      }
+    }
+
+    /* returns the index of the next unique value i in vi and the corresponding valid_i mask */
+    template<typename vbool, typename vint>
+    __forceinline int next_unique_index(vbool& valid, const vint& vi, /*out*/ vbool& valid_i)
+    {
+      assert(any(valid));
+      const int j = int(bsf(movemask(valid)));
+      const int i = vi[j];
+      valid_i = valid & (i == vi);
+      valid = andn(valid, valid_i);
+      return j;
+    }
+
+    template<typename Closure>
+    __forceinline void foreach2(int x0, int x1, int y0, int y1, const Closure& closure)
+    {
+      __aligned(64) int U[2*VSIZEX];
+      __aligned(64) int V[2*VSIZEX];
+      int index = 0;
+      for (int y=y0; y<y1; y++) {
+        const bool lasty = y+1>=y1;
+        const vintx vy = y;
+        for (int x=x0; x<x1; ) {
+          const bool lastx = x+VSIZEX >= x1;
+          vintx vx = x+vintx(step);
+          vintx::storeu(&U[index], vx);
+          vintx::storeu(&V[index], vy);
+          const int dx = min(x1-x,VSIZEX);
+          index += dx;
+          x += dx;
+          if (index >= VSIZEX || (lastx && lasty)) {
+            const vboolx valid = vintx(step) < vintx(index);
+            closure(valid, vintx::load(U), vintx::load(V));
+            x-= max(0, index-VSIZEX);
+            index = 0;
+          }
+        }
+      }
+    }
+}
diff --git a/thirdparty/embree/common/simd/sse.cpp b/thirdparty/embree/common/simd/sse.cpp
new file mode 100644
index 000000000000..1732cfa42155
--- /dev/null
+++ b/thirdparty/embree/common/simd/sse.cpp
@@ -0,0 +1,34 @@
+// Copyright 2009-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "sse.h"
+
+namespace embree
+{
+  const __m128 mm_lookupmask_ps[16] = {
+    _mm_castsi128_ps(_mm_set_epi32( 0, 0, 0, 0)),
+    _mm_castsi128_ps(_mm_set_epi32( 0, 0, 0,-1)),
+    _mm_castsi128_ps(_mm_set_epi32( 0, 0,-1, 0)),
+    _mm_castsi128_ps(_mm_set_epi32( 0, 0,-1,-1)),
+    _mm_castsi128_ps(_mm_set_epi32( 0,-1, 0, 0)),
+    _mm_castsi128_ps(_mm_set_epi32( 0,-1, 0,-1)),
+    _mm_castsi128_ps(_mm_set_epi32( 0,-1,-1, 0)),
+    _mm_castsi128_ps(_mm_set_epi32( 0,-1,-1,-1)),
+    _mm_castsi128_ps(_mm_set_epi32(-1, 0, 0, 0)),
+    _mm_castsi128_ps(_mm_set_epi32(-1, 0, 0,-1)),
+    _mm_castsi128_ps(_mm_set_epi32(-1, 0,-1, 0)),
+    _mm_castsi128_ps(_mm_set_epi32(-1, 0,-1,-1)),
+    _mm_castsi128_ps(_mm_set_epi32(-1,-1, 0, 0)),
+    _mm_castsi128_ps(_mm_set_epi32(-1,-1, 0,-1)),
+    _mm_castsi128_ps(_mm_set_epi32(-1,-1,-1, 0)),
+    _mm_castsi128_ps(_mm_set_epi32(-1,-1,-1,-1))
+  };
+
+  const __m128d mm_lookupmask_pd[4] = {
+    _mm_castsi128_pd(_mm_set_epi32( 0, 0, 0, 0)),
+    _mm_castsi128_pd(_mm_set_epi32( 0, 0,-1,-1)),
+    _mm_castsi128_pd(_mm_set_epi32(-1,-1, 0, 0)),
+    _mm_castsi128_pd(_mm_set_epi32(-1,-1,-1,-1))
+  };
+
+}
diff --git a/thirdparty/embree/common/simd/sse.h b/thirdparty/embree/common/simd/sse.h
new file mode 100644
index 000000000000..67df3ec00980
--- /dev/null
+++ b/thirdparty/embree/common/simd/sse.h
@@ -0,0 +1,35 @@
+// Copyright 2009-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "../sys/platform.h"
+#include "../sys/intrinsics.h"
+#include "../sys/alloc.h"
+#include "../math/constants.h"
+#include "varying.h"
+
+namespace embree
+{
+#if defined(__SSE4_1__)
+  __forceinline __m128 blendv_ps(__m128 f, __m128 t, __m128 mask) {
+    return _mm_blendv_ps(f,t,mask);
+  }
+#else
__forceinline __m128 blendv_ps(__m128 f, __m128 t, __m128 mask) { + return _mm_or_ps(_mm_and_ps(mask, t), _mm_andnot_ps(mask, f)); + } +#endif + + extern const __m128 mm_lookupmask_ps[16]; + extern const __m128d mm_lookupmask_pd[4]; +} + +#if defined(__AVX512VL__) +#include "vboolf4_avx512.h" +#else +#include "vboolf4_sse2.h" +#endif +#include "vint4_sse2.h" +#include "vuint4_sse2.h" +#include "vfloat4_sse2.h" diff --git a/thirdparty/embree/common/simd/varying.h b/thirdparty/embree/common/simd/varying.h new file mode 100644 index 000000000000..9a46817da913 --- /dev/null +++ b/thirdparty/embree/common/simd/varying.h @@ -0,0 +1,132 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" + +namespace embree +{ + /* Varying numeric types */ + template + struct vfloat + { + union { float f[N]; int i[N]; }; + __forceinline const float& operator [](size_t index) const { assert(index < N); return f[index]; } + __forceinline float& operator [](size_t index) { assert(index < N); return f[index]; } + }; + + template + struct vdouble + { + union { double f[N]; long long i[N]; }; + __forceinline const double& operator [](size_t index) const { assert(index < N); return f[index]; } + __forceinline double& operator [](size_t index) { assert(index < N); return f[index]; } + }; + + template + struct vint + { + int i[N]; + __forceinline const int& operator [](size_t index) const { assert(index < N); return i[index]; } + __forceinline int& operator [](size_t index) { assert(index < N); return i[index]; } + }; + + template + struct vuint + { + unsigned int i[N]; + __forceinline const unsigned int& operator [](size_t index) const { assert(index < N); return i[index]; } + __forceinline unsigned int& operator [](size_t index) { assert(index < N); return i[index]; } + }; + + template + struct vllong + { + long long i[N]; + __forceinline const long long& operator [](size_t index) const { assert(index < N); return i[index]; } + __forceinline long long& operator [](size_t index) { assert(index < N); return i[index]; } + }; + + /* Varying bool types */ + template struct vboolf { int i[N]; }; // for float/int + template struct vboold { long long i[N]; }; // for double/long long + + /* Aliases to default types */ + template using vreal = vfloat; + template using vbool = vboolf; + + /* Varying size constants */ +#if defined(__AVX512VL__) // SKX + const int VSIZEX = 8; // default size + const int VSIZEL = 16; // large size +#elif defined(__AVX512F__) // KNL + const int VSIZEX = 16; + const int VSIZEL = 16; +#elif defined(__AVX__) + const int VSIZEX = 8; + const int VSIZEL = 8; +#else + const int VSIZEX = 4; + const int VSIZEL = 4; +#endif + + /* Extends varying size N to optimal or up to max(N, N2) */ + template + struct vextend + { +#if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL + /* use 16-wide SIMD calculations on KNL even for 4 and 8 wide SIMD */ + static const int size = (N2 == VSIZEX) ? 
VSIZEX : N; + #define SIMD_MODE(N) N, 16 +#else + /* calculate with same SIMD width otherwise */ + static const int size = N; + #define SIMD_MODE(N) N, N +#endif + }; + + /* 4-wide shortcuts */ + typedef vfloat<4> vfloat4; + typedef vdouble<4> vdouble4; + typedef vreal<4> vreal4; + typedef vint<4> vint4; + typedef vuint<4> vuint4; + typedef vllong<4> vllong4; + typedef vbool<4> vbool4; + typedef vboolf<4> vboolf4; + typedef vboold<4> vboold4; + + /* 8-wide shortcuts */ + typedef vfloat<8> vfloat8; + typedef vdouble<8> vdouble8; + typedef vreal<8> vreal8; + typedef vint<8> vint8; + typedef vuint<8> vuint8; + typedef vllong<8> vllong8; + typedef vbool<8> vbool8; + typedef vboolf<8> vboolf8; + typedef vboold<8> vboold8; + + /* 16-wide shortcuts */ + typedef vfloat<16> vfloat16; + typedef vdouble<16> vdouble16; + typedef vreal<16> vreal16; + typedef vint<16> vint16; + typedef vuint<16> vuint16; + typedef vllong<16> vllong16; + typedef vbool<16> vbool16; + typedef vboolf<16> vboolf16; + typedef vboold<16> vboold16; + + /* Default shortcuts */ + typedef vfloat vfloatx; + typedef vdouble vdoublex; + typedef vreal vrealx; + typedef vint vintx; + typedef vuint vuintx; + typedef vllong vllongx; + typedef vbool vboolx; + typedef vboolf vboolfx; + typedef vboold vbooldx; +} diff --git a/thirdparty/embree/common/simd/vboold4_avx.h b/thirdparty/embree/common/simd/vboold4_avx.h new file mode 100644 index 000000000000..44e423b001be --- /dev/null +++ b/thirdparty/embree/common/simd/vboold4_avx.h @@ -0,0 +1,155 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide AVX bool type for 64bit data types*/ + template<> + struct vboold<4> + { + ALIGNED_STRUCT_(32); + + typedef vboold4 Bool; + + enum { size = 4 }; // number of SIMD elements + union { // data + __m256d v; + struct { __m128d vl,vh; }; + long long i[4]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold() {} + __forceinline vboold(const vboold4& a) { v = a.v; } + __forceinline vboold4& operator =(const vboold4& a) { v = a.v; return *this; } + + __forceinline vboold(__m256d a) : v(a) {} + __forceinline vboold(__m256i a) : v(_mm256_castsi256_pd(a)) {} + + __forceinline operator const __m256() const { return _mm256_castpd_ps(v); } + __forceinline operator const __m256i() const { return _mm256_castpd_si256(v); } + __forceinline operator const __m256d() const { return v; } + + __forceinline vboold(int a) + { + assert(a >= 0 && a <= 255); +#if defined (__AVX2__) + const __m256i mask = _mm256_set_epi64x(0x8, 0x4, 0x2, 0x1); + const __m256i b = _mm256_set1_epi64x(a); + const __m256i c = _mm256_and_si256(b,mask); + v = _mm256_castsi256_pd(_mm256_cmpeq_epi64(c,mask)); +#else + vl = mm_lookupmask_pd[a & 0x3]; + vh = mm_lookupmask_pd[a >> 2]; +#endif + } + + __forceinline vboold(__m128d a, __m128d b) : vl(a), vh(b) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold(FalseTy) : v(_mm256_setzero_pd()) {} + __forceinline vboold(TrueTy) : v(_mm256_cmp_pd(_mm256_setzero_pd(), _mm256_setzero_pd(), _CMP_EQ_OQ)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { assert(index < 4); return (_mm256_movemask_pd(v) >> index) & 1; } + __forceinline long long& operator [](size_t index) { assert(index < 4); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 operator !(const vboold4& a) { return _mm256_xor_pd(a, vboold4(embree::True)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 operator &(const vboold4& a, const vboold4& b) { return _mm256_and_pd(a, b); } + __forceinline vboold4 operator |(const vboold4& a, const vboold4& b) { return _mm256_or_pd (a, b); } + __forceinline vboold4 operator ^(const vboold4& a, const vboold4& b) { return _mm256_xor_pd(a, b); } + + __forceinline vboold4 andn(const vboold4& a, const vboold4& b) { return _mm256_andnot_pd(b, a); } + + __forceinline vboold4& operator &=(vboold4& a, const vboold4& b) { return a = a & b; } + __forceinline vboold4& operator |=(vboold4& a, const vboold4& b) { return a = a | b; } + __forceinline vboold4& operator ^=(vboold4& a, const vboold4& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 operator !=(const vboold4& a, const vboold4& b) { return _mm256_xor_pd(a, b); } + __forceinline vboold4 operator ==(const vboold4& a, const vboold4& b) { return _mm256_xor_pd(_mm256_xor_pd(a,b),vboold4(embree::True)); } + + __forceinline vboold4 select(const vboold4& mask, const vboold4& t, const vboold4& f) { + return _mm256_blendv_pd(f, t, mask); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 unpacklo(const vboold4& a, const vboold4& b) { return _mm256_unpacklo_pd(a, b); } + __forceinline vboold4 unpackhi(const vboold4& a, const vboold4& b) { return _mm256_unpackhi_pd(a, b); } + + +#if defined(__AVX2__) + template + __forceinline vboold4 shuffle(const vboold4& v) { + return _mm256_permute4x64_pd(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template + __forceinline vboold4 shuffle(const vboold4& v) { + return _mm256_permute4x64_pd(v, _MM_SHUFFLE(i, i, i, i)); + } +#endif + + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool reduce_and(const vboold4& a) { return _mm256_movemask_pd(a) == (unsigned int)0xf; } + __forceinline bool reduce_or (const vboold4& a) { return !_mm256_testz_pd(a,a); } + + __forceinline bool all (const vboold4& a) { return _mm256_movemask_pd(a) == (unsigned int)0xf; } + __forceinline bool any (const vboold4& a) { return !_mm256_testz_pd(a,a); } + __forceinline bool none(const vboold4& a) { return _mm256_testz_pd(a,a) != 0; } + + __forceinline bool all (const vboold4& valid, const vboold4& b) { return all((!valid) | b); } + __forceinline bool any 
(const vboold4& valid, const vboold4& b) { return any(valid & b); } + __forceinline bool none(const vboold4& valid, const vboold4& b) { return none(valid & b); } + + __forceinline unsigned int movemask(const vboold4& a) { return _mm256_movemask_pd(a); } + __forceinline size_t popcnt (const vboold4& a) { return popcnt((size_t)_mm256_movemask_pd(a)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboold4& a, size_t index) { return a[index]; } + __forceinline void set (vboold4& a, size_t index) { a[index] = -1; } + __forceinline void clear(vboold4& a, size_t index) { a[index] = 0; } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboold4& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " + << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboold4_avx512.h b/thirdparty/embree/common/simd/vboold4_avx512.h new file mode 100644 index 000000000000..4fe730d713d6 --- /dev/null +++ b/thirdparty/embree/common/simd/vboold4_avx512.h @@ -0,0 +1,140 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide AVX-512 bool type */ + template<> + struct vboold<4> + { + typedef vboold4 Bool; + typedef vint4 Int; + + enum { size = 4 }; // number of SIMD elements + __mmask8 v; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold() {} + __forceinline vboold(const vboold4& t) { v = t.v; } + __forceinline vboold4& operator =(const vboold4& f) { v = f.v; return *this; } + + __forceinline vboold(const __mmask8 &t) { v = t; } + __forceinline operator __mmask8() const { return v; } + + __forceinline vboold(bool b) { v = b ? 
0xf : 0x0; } + __forceinline vboold(int t) { v = (__mmask8)t; } + __forceinline vboold(unsigned int t) { v = (__mmask8)t; } + + /* return int8 mask */ + __forceinline __m128i mask8() const { + return _mm_movm_epi8(v); + } + + /* return int32 mask */ + __forceinline __m128i mask32() const { + return _mm_movm_epi32(v); + } + + /* return int64 mask */ + __forceinline __m256i mask64() const { + return _mm256_movm_epi64(v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold(FalseTy) : v(0x0) {} + __forceinline vboold(TrueTy) : v(0xf) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { + assert(index < 4); return (mm512_mask2int(v) >> index) & 1; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 operator !(const vboold4& a) { return _mm512_kandn(a, 0xf); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 operator &(const vboold4& a, const vboold4& b) { return _mm512_kand(a, b); } + __forceinline vboold4 operator |(const vboold4& a, const vboold4& b) { return _mm512_kor(a, b); } + __forceinline vboold4 operator ^(const vboold4& a, const vboold4& b) { return _mm512_kxor(a, b); } + + __forceinline vboold4 andn(const vboold4& a, const vboold4& b) { return _mm512_kandn(b, a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4& operator &=(vboold4& a, const vboold4& b) { return a = a & b; } + __forceinline vboold4& operator |=(vboold4& a, const vboold4& b) { return a = a | b; } + __forceinline vboold4& operator ^=(vboold4& a, const vboold4& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold4 operator !=(const vboold4& a, const vboold4& b) { return _mm512_kxor(a, b); } + __forceinline vboold4 operator ==(const vboold4& a, const vboold4& b) { return _mm512_kand(_mm512_kxnor(a, b), 0xf); } + + __forceinline vboold4 select(const vboold4& s, const vboold4& a, const vboold4& b) { + return _mm512_kor(_mm512_kand(s, a), _mm512_kandn(s, b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int all (const vboold4& a) { return a.v == 0xf; } + __forceinline int any (const vboold4& a) { return _mm512_kortestz(a, a) == 0; } + __forceinline int none(const vboold4& a) { return _mm512_kortestz(a, a) != 0; } + + __forceinline int all (const vboold4& valid, const vboold4& b) { return all((!valid) | b); } + __forceinline int any (const vboold4& valid, const vboold4& b) { return any(valid & b); } + 
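+  // Illustrative note (not part of the upstream Embree sources): the masked
+  // reductions take a 'valid' lane mask and test 'b' only on the active lanes,
+  // e.g. with valid = vboold4(0x3) (lanes 0-1 active) and b = vboold4(0xd):
+  //   all (valid, b) == false  // lane 1 is active but false in b
+  //   any (valid, b) == true   // lane 0 is active and true in b
+  //   none(valid, b) == false  // at least one active lane is set in b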
__forceinline int none(const vboold4& valid, const vboold4& b) { return none(valid & b); } + + __forceinline size_t movemask(const vboold4& a) { return _mm512_kmov(a); } + __forceinline size_t popcnt (const vboold4& a) { return popcnt(a.v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Conversion Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline unsigned int toInt(const vboold4& a) { return mm512_mask2int(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboold4& a, size_t index) { assert(index < 4); return (toInt(a) >> index) & 1; } + __forceinline void set(vboold4& a, size_t index) { assert(index < 4); a |= 1 << index; } + __forceinline void clear(vboold4& a, size_t index) { assert(index < 4); a = andn(a, 1 << index); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboold4& a) + { + cout << "<"; + for (size_t i=0; i<4; i++) { + if ((a.v >> i) & 1) cout << "1"; else cout << "0"; + } + return cout << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboold8_avx512.h b/thirdparty/embree/common/simd/vboold8_avx512.h new file mode 100644 index 000000000000..fdf3f00de544 --- /dev/null +++ b/thirdparty/embree/common/simd/vboold8_avx512.h @@ -0,0 +1,148 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX-512 bool type */ + template<> + struct vboold<8> + { + typedef vboold8 Bool; + typedef vint8 Int; + + enum { size = 8 }; // number of SIMD elements + __mmask8 v; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold() {} + __forceinline vboold(const vboold8& t) { v = t.v; } + __forceinline vboold8& operator =(const vboold8& f) { v = f.v; return *this; } + + __forceinline vboold(const __mmask8& t) { v = t; } + __forceinline operator __mmask8() const { return v; } + + __forceinline vboold(bool b) { v = b ? 
0xff : 0x00; } + __forceinline vboold(int t) { v = (__mmask8)t; } + __forceinline vboold(unsigned int t) { v = (__mmask8)t; } + + /* return int8 mask */ + __forceinline __m128i mask8() const { +#if defined(__AVX512BW__) + return _mm_movm_epi8(v); +#else + const __m512i f = _mm512_set1_epi64(0); + const __m512i t = _mm512_set1_epi64(-1); + const __m512i m = _mm512_mask_or_epi64(f,v,t,t); + return _mm512_cvtepi64_epi8(m); +#endif + } + + /* return int64 mask */ + __forceinline __m512i mask64() const { +#if defined(__AVX512DQ__) + return _mm512_movm_epi64(v); +#else + const __m512i f = _mm512_set1_epi64(0); + const __m512i t = _mm512_set1_epi64(-1); + return _mm512_mask_or_epi64(f,v,t,t); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold(FalseTy) : v(0x00) {} + __forceinline vboold(TrueTy) : v(0xff) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { + assert(index < 8); return (mm512_mask2int(v) >> index) & 1; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8 operator !(const vboold8& a) { return _mm512_knot(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8 operator &(const vboold8& a, const vboold8& b) { return _mm512_kand(a, b); } + __forceinline vboold8 operator |(const vboold8& a, const vboold8& b) { return _mm512_kor(a, b); } + __forceinline vboold8 operator ^(const vboold8& a, const vboold8& b) { return _mm512_kxor(a, b); } + + __forceinline vboold8 andn(const vboold8& a, const vboold8& b) { return _mm512_kandn(b, a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8& operator &=(vboold8& a, const vboold8& b) { return a = a & b; } + __forceinline vboold8& operator |=(vboold8& a, const vboold8& b) { return a = a | b; } + __forceinline vboold8& operator ^=(vboold8& a, const vboold8& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8 operator !=(const vboold8& a, const vboold8& b) { return _mm512_kxor(a, b); } + __forceinline vboold8 operator ==(const vboold8& a, const vboold8& b) { return _mm512_kxnor(a, b); } + + __forceinline vboold8 select(const vboold8& s, const vboold8& a, const vboold8& b) { + return _mm512_kor(_mm512_kand(s, a), _mm512_kandn(s, b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int all (const vboold8& a) { return a.v == 0xff; } + __forceinline int any (const vboold8& a) { return _mm512_kortestz(a, a) == 0; } + __forceinline 
int none(const vboold8& a) { return _mm512_kortestz(a, a) != 0; } + + __forceinline int all (const vboold8& valid, const vboold8& b) { return all((!valid) | b); } + __forceinline int any (const vboold8& valid, const vboold8& b) { return any(valid & b); } + __forceinline int none(const vboold8& valid, const vboold8& b) { return none(valid & b); } + + __forceinline size_t movemask(const vboold8& a) { return _mm512_kmov(a); } + __forceinline size_t popcnt (const vboold8& a) { return popcnt(a.v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Conversion Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline unsigned int toInt(const vboold8& a) { return mm512_mask2int(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboold8& a, size_t index) { assert(index < 8); return (toInt(a) >> index) & 1; } + __forceinline void set(vboold8& a, size_t index) { assert(index < 8); a |= 1 << index; } + __forceinline void clear(vboold8& a, size_t index) { assert(index < 8); a = andn(a, 1 << index); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboold8& a) + { + cout << "<"; + for (size_t i=0; i<8; i++) { + if ((a.v >> i) & 1) cout << "1"; else cout << "0"; + } + return cout << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboolf16_avx512.h b/thirdparty/embree/common/simd/vboolf16_avx512.h new file mode 100644 index 000000000000..238cdc8eb920 --- /dev/null +++ b/thirdparty/embree/common/simd/vboolf16_avx512.h @@ -0,0 +1,150 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 16-wide AVX-512 bool type */ + template<> + struct vboolf<16> + { + typedef vboolf16 Bool; + typedef vint16 Int; + typedef vfloat16 Float; + + enum { size = 16 }; // number of SIMD elements + __mmask16 v; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf() {} + __forceinline vboolf(const vboolf16& t) { v = t.v; } + __forceinline vboolf16& operator =(const vboolf16& f) { v = f.v; return *this; } + + __forceinline vboolf(const __mmask16& t) { v = t; } + __forceinline operator __mmask16() const { return v; } + + __forceinline vboolf(bool b) { v = b ? 
0xFFFF : 0x0000; } + __forceinline vboolf(int t) { v = (__mmask16)t; } + __forceinline vboolf(unsigned int t) { v = (__mmask16)t; } + + /* return int8 mask */ + __forceinline __m128i mask8() const { +#if defined(__AVX512BW__) + return _mm_movm_epi8(v); +#else + const __m512i f = _mm512_set1_epi32(0); + const __m512i t = _mm512_set1_epi32(-1); + const __m512i m = _mm512_mask_or_epi32(f,v,t,t); + return _mm512_cvtepi32_epi8(m); +#endif + } + + /* return int32 mask */ + __forceinline __m512i mask32() const { +#if defined(__AVX512DQ__) + return _mm512_movm_epi32(v); +#else + const __m512i f = _mm512_set1_epi32(0); + const __m512i t = _mm512_set1_epi32(-1); + return _mm512_mask_or_epi32(f,v,t,t); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf(FalseTy) : v(0x0000) {} + __forceinline vboolf(TrueTy) : v(0xffff) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { + assert(index < 16); return (mm512_mask2int(v) >> index) & 1; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 operator !(const vboolf16& a) { return _mm512_knot(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 operator &(const vboolf16& a, const vboolf16& b) { return _mm512_kand(a,b); } + __forceinline vboolf16 operator |(const vboolf16& a, const vboolf16& b) { return _mm512_kor(a,b); } + __forceinline vboolf16 operator ^(const vboolf16& a, const vboolf16& b) { return _mm512_kxor(a,b); } + + __forceinline vboolf16 andn(const vboolf16& a, const vboolf16& b) { return _mm512_kandn(b,a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16& operator &=(vboolf16& a, const vboolf16& b) { return a = a & b; } + __forceinline vboolf16& operator |=(vboolf16& a, const vboolf16& b) { return a = a | b; } + __forceinline vboolf16& operator ^=(vboolf16& a, const vboolf16& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 operator !=(const vboolf16& a, const vboolf16& b) { return _mm512_kxor(a, b); } + __forceinline vboolf16 operator ==(const vboolf16& a, const vboolf16& b) { return _mm512_kxnor(a, b); } + + __forceinline vboolf16 select(const vboolf16& s, const vboolf16& a, const vboolf16& b) { + return _mm512_kor(_mm512_kand(s,a),_mm512_kandn(s,b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int all (const vboolf16& a) { return _mm512_kortestc(a,a) != 0; } + __forceinline int any (const vboolf16& a) { 
return _mm512_kortestz(a,a) == 0; } + __forceinline int none(const vboolf16& a) { return _mm512_kortestz(a,a) != 0; } + + __forceinline int all (const vboolf16& valid, const vboolf16& b) { return all((!valid) | b); } + __forceinline int any (const vboolf16& valid, const vboolf16& b) { return any(valid & b); } + __forceinline int none(const vboolf16& valid, const vboolf16& b) { return none(valid & b); } + + __forceinline size_t movemask(const vboolf16& a) { return _mm512_kmov(a); } + __forceinline size_t popcnt (const vboolf16& a) { return popcnt(a.v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Convertion Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline unsigned int toInt (const vboolf16& a) { return mm512_mask2int(a); } + __forceinline vboolf16 toMask(const int& a) { return mm512_int2mask(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboolf16& a, size_t index) { assert(index < 16); return (toInt(a) >> index) & 1; } + __forceinline void set(vboolf16& a, size_t index) { assert(index < 16); a |= 1 << index; } + __forceinline void clear(vboolf16& a, size_t index) { assert(index < 16); a = andn(a, 1 << index); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboolf16& a) + { + cout << "<"; + for (size_t i=0; i<16; i++) { + if ((a.v >> i) & 1) cout << "1"; else cout << "0"; + } + return cout << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboolf4_avx512.h b/thirdparty/embree/common/simd/vboolf4_avx512.h new file mode 100644 index 000000000000..2ae4c4470e2d --- /dev/null +++ b/thirdparty/embree/common/simd/vboolf4_avx512.h @@ -0,0 +1,143 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide AVX-512 bool type */ + template<> + struct vboolf<4> + { + typedef vboolf4 Bool; + typedef vint4 Int; + + enum { size = 4 }; // number of SIMD elements + __mmask8 v; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf() {} + __forceinline vboolf(const vboolf4& t) { v = t.v; } + __forceinline vboolf4& operator =(const vboolf4& f) { v = f.v; return *this; } + + __forceinline vboolf(const __mmask8 &t) { v = t; } + __forceinline operator __mmask8() const { return v; } + + __forceinline vboolf(bool b) { v = b ? 
0xf : 0x0; } + __forceinline vboolf(int t) { v = (__mmask8)t; } + __forceinline vboolf(unsigned int t) { v = (__mmask8)t; } + + __forceinline vboolf(bool a, bool b, bool c, bool d) + : v((__mmask8)((int(d) << 3) | (int(c) << 2) | (int(b) << 1) | int(a))) {} + + /* return int8 mask */ + __forceinline __m128i mask8() const { + return _mm_movm_epi8(v); + } + + /* return int32 mask */ + __forceinline __m128i mask32() const { + return _mm_movm_epi32(v); + } + + /* return int64 mask */ + __forceinline __m256i mask64() const { + return _mm256_movm_epi64(v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf(FalseTy) : v(0x0) {} + __forceinline vboolf(TrueTy) : v(0xf) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { + assert(index < 4); return (mm512_mask2int(v) >> index) & 1; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 operator !(const vboolf4& a) { return _mm512_kandn(a, 0xf); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 operator &(const vboolf4& a, const vboolf4& b) { return _mm512_kand(a, b); } + __forceinline vboolf4 operator |(const vboolf4& a, const vboolf4& b) { return _mm512_kor(a, b); } + __forceinline vboolf4 operator ^(const vboolf4& a, const vboolf4& b) { return _mm512_kxor(a, b); } + + __forceinline vboolf4 andn(const vboolf4& a, const vboolf4& b) { return _mm512_kandn(b, a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4& operator &=(vboolf4& a, const vboolf4& b) { return a = a & b; } + __forceinline vboolf4& operator |=(vboolf4& a, const vboolf4& b) { return a = a | b; } + __forceinline vboolf4& operator ^=(vboolf4& a, const vboolf4& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 operator !=(const vboolf4& a, const vboolf4& b) { return _mm512_kxor(a, b); } + __forceinline vboolf4 operator ==(const vboolf4& a, const vboolf4& b) { return _mm512_kand(_mm512_kxnor(a, b), 0xf); } + + __forceinline vboolf4 select(const vboolf4& s, const vboolf4& a, const vboolf4& b) { + return _mm512_kor(_mm512_kand(s, a), _mm512_kandn(s, b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int all (const vboolf4& a) { return a.v == 0xf; } + __forceinline int any (const vboolf4& a) { return _mm512_kortestz(a, a) == 0; } + __forceinline int none(const vboolf4& a) { return _mm512_kortestz(a, a) != 0; } + + __forceinline int all (const vboolf4& valid, const 
vboolf4& b) { return all((!valid) | b); } + __forceinline int any (const vboolf4& valid, const vboolf4& b) { return any(valid & b); } + __forceinline int none(const vboolf4& valid, const vboolf4& b) { return none(valid & b); } + + __forceinline size_t movemask(const vboolf4& a) { return _mm512_kmov(a); } + __forceinline size_t popcnt (const vboolf4& a) { return popcnt(a.v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Conversion Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline unsigned int toInt(const vboolf4& a) { return mm512_mask2int(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboolf4& a, size_t index) { assert(index < 4); return (toInt(a) >> index) & 1; } + __forceinline void set(vboolf4& a, size_t index) { assert(index < 4); a |= 1 << index; } + __forceinline void clear(vboolf4& a, size_t index) { assert(index < 4); a = andn(a, 1 << index); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboolf4& a) + { + cout << "<"; + for (size_t i=0; i<4; i++) { + if ((a.v >> i) & 1) cout << "1"; else cout << "0"; + } + return cout << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboolf4_sse2.h b/thirdparty/embree/common/simd/vboolf4_sse2.h new file mode 100644 index 000000000000..afec10fd499e --- /dev/null +++ b/thirdparty/embree/common/simd/vboolf4_sse2.h @@ -0,0 +1,173 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide SSE bool type */ + template<> + struct vboolf<4> + { + ALIGNED_STRUCT_(16); + + typedef vboolf4 Bool; + typedef vint4 Int; + typedef vfloat4 Float; + + enum { size = 4 }; // number of SIMD elements + union { __m128 v; int i[4]; }; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf() {} + __forceinline vboolf(const vboolf4& other) { v = other.v; } + __forceinline vboolf4& operator =(const vboolf4& other) { v = other.v; return *this; } + + __forceinline vboolf(__m128 input) : v(input) {} + __forceinline operator const __m128&() const { return v; } + __forceinline operator const __m128i() const { return _mm_castps_si128(v); } + __forceinline operator const __m128d() const { return _mm_castps_pd(v); } + + __forceinline vboolf(bool a) + : v(mm_lookupmask_ps[(size_t(a) << 3) | (size_t(a) << 2) | (size_t(a) << 1) | size_t(a)]) {} + __forceinline vboolf(bool a, bool b) + : v(mm_lookupmask_ps[(size_t(b) << 3) | (size_t(a) << 2) | (size_t(b) << 1) | size_t(a)]) {} + __forceinline vboolf(bool a, bool b, bool c, bool d) + : v(mm_lookupmask_ps[(size_t(d) << 3) | (size_t(c) << 2) | (size_t(b) << 1) | size_t(a)]) {} + __forceinline vboolf(int mask) { assert(mask >= 0 && mask < 16); v = mm_lookupmask_ps[mask]; } + __forceinline vboolf(unsigned int mask) { assert(mask < 16); v = mm_lookupmask_ps[mask]; } + + /* return int32 mask */ + __forceinline __m128i mask32() const { + return _mm_castps_si128(v); + } 
+ + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf(FalseTy) : v(_mm_setzero_ps()) {} + __forceinline vboolf(TrueTy) : v(_mm_castsi128_ps(_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128()))) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { assert(index < 4); return (_mm_movemask_ps(v) >> index) & 1; } + __forceinline int& operator [](size_t index) { assert(index < 4); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 operator !(const vboolf4& a) { return _mm_xor_ps(a, vboolf4(embree::True)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 operator &(const vboolf4& a, const vboolf4& b) { return _mm_and_ps(a, b); } + __forceinline vboolf4 operator |(const vboolf4& a, const vboolf4& b) { return _mm_or_ps (a, b); } + __forceinline vboolf4 operator ^(const vboolf4& a, const vboolf4& b) { return _mm_xor_ps(a, b); } + + __forceinline vboolf4 andn(const vboolf4& a, const vboolf4& b) { return _mm_andnot_ps(b, a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4& operator &=(vboolf4& a, const vboolf4& b) { return a = a & b; } + __forceinline vboolf4& operator |=(vboolf4& a, const vboolf4& b) { return a = a | b; } + __forceinline vboolf4& operator ^=(vboolf4& a, const vboolf4& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 operator !=(const vboolf4& a, const vboolf4& b) { return _mm_xor_ps(a, b); } + __forceinline vboolf4 operator ==(const vboolf4& a, const vboolf4& b) { return _mm_castsi128_ps(_mm_cmpeq_epi32(a, b)); } + + __forceinline vboolf4 select(const vboolf4& m, const vboolf4& t, const vboolf4& f) { +#if defined(__SSE4_1__) + return _mm_blendv_ps(f, t, m); +#else + return _mm_or_ps(_mm_and_ps(m, t), _mm_andnot_ps(m, f)); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf4 unpacklo(const vboolf4& a, const vboolf4& b) { return _mm_unpacklo_ps(a, b); } + __forceinline vboolf4 unpackhi(const vboolf4& a, const vboolf4& b) { return _mm_unpackhi_ps(a, b); } + + template<int i0, int i1, int i2, int i3> + __forceinline vboolf4 shuffle(const vboolf4& v) { + return _mm_castsi128_ps(_mm_shuffle_epi32(v, _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vboolf4 shuffle(const vboolf4& a, const vboolf4& b) { + return _mm_shuffle_ps(a, b, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<int i> + __forceinline vboolf4 shuffle(const vboolf4& v) { + return
shuffle(v); + } + +#if defined(__SSE3__) + template<> __forceinline vboolf4 shuffle<0, 0, 2, 2>(const vboolf4& v) { return _mm_moveldup_ps(v); } + template<> __forceinline vboolf4 shuffle<1, 1, 3, 3>(const vboolf4& v) { return _mm_movehdup_ps(v); } + template<> __forceinline vboolf4 shuffle<0, 1, 0, 1>(const vboolf4& v) { return _mm_castpd_ps(_mm_movedup_pd(v)); } +#endif + +#if defined(__SSE4_1__) + template __forceinline vboolf4 insert(const vboolf4& a, const vboolf4& b) { return _mm_insert_ps(a, b, (dst << 4) | (src << 6) | clr); } + template __forceinline vboolf4 insert(const vboolf4& a, const vboolf4& b) { return insert(a, b); } + template __forceinline vboolf4 insert(const vboolf4& a, const bool b) { return insert(a, vboolf4(b)); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool reduce_and(const vboolf4& a) { return _mm_movemask_ps(a) == 0xf; } + __forceinline bool reduce_or (const vboolf4& a) { return _mm_movemask_ps(a) != 0x0; } + + __forceinline bool all (const vboolf4& b) { return _mm_movemask_ps(b) == 0xf; } + __forceinline bool any (const vboolf4& b) { return _mm_movemask_ps(b) != 0x0; } + __forceinline bool none(const vboolf4& b) { return _mm_movemask_ps(b) == 0x0; } + + __forceinline bool all (const vboolf4& valid, const vboolf4& b) { return all((!valid) | b); } + __forceinline bool any (const vboolf4& valid, const vboolf4& b) { return any(valid & b); } + __forceinline bool none(const vboolf4& valid, const vboolf4& b) { return none(valid & b); } + + __forceinline size_t movemask(const vboolf4& a) { return _mm_movemask_ps(a); } +#if defined(__SSE4_2__) + __forceinline size_t popcnt(const vboolf4& a) { return popcnt((size_t)_mm_movemask_ps(a)); } +#else + __forceinline size_t popcnt(const vboolf4& a) { return bool(a[0])+bool(a[1])+bool(a[2])+bool(a[3]); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboolf4& a, size_t index) { return a[index]; } + __forceinline void set(vboolf4& a, size_t index) { a[index] = -1; } + __forceinline void clear(vboolf4& a, size_t index) { a[index] = 0; } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboolf4& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboolf8_avx.h b/thirdparty/embree/common/simd/vboolf8_avx.h new file mode 100644 index 000000000000..5d7c0d68c1b0 --- /dev/null +++ b/thirdparty/embree/common/simd/vboolf8_avx.h @@ -0,0 +1,186 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX bool type */ + template<> + struct vboolf<8> + { + ALIGNED_STRUCT_(32); + + typedef vboolf8 Bool; + typedef vint8 Int; + typedef vfloat8 Float; + + enum { size = 8 }; // number of SIMD elements + union { // data + __m256 v; + struct { __m128 vl,vh; }; + int i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf() {} + __forceinline vboolf(const vboolf8& a) { v = a.v; } + __forceinline vboolf8& operator =(const vboolf8& a) { v = a.v; return *this; } + + __forceinline vboolf(__m256 a) : v(a) {} + __forceinline operator const __m256&() const { return v; } + __forceinline operator const __m256i() const { return _mm256_castps_si256(v); } + __forceinline operator const __m256d() const { return _mm256_castps_pd(v); } + + __forceinline vboolf(int a) + { + assert(a >= 0 && a <= 255); +#if defined (__AVX2__) + const __m256i mask = _mm256_set_epi32(0x80, 0x40, 0x20, 0x10, 0x8, 0x4, 0x2, 0x1); + const __m256i b = _mm256_set1_epi32(a); + const __m256i c = _mm256_and_si256(b,mask); + v = _mm256_castsi256_ps(_mm256_cmpeq_epi32(c,mask)); +#else + vl = mm_lookupmask_ps[a & 0xF]; + vh = mm_lookupmask_ps[a >> 4]; +#endif + } + + __forceinline vboolf(const vboolf4& a) : v(_mm256_insertf128_ps(_mm256_castps128_ps256(a),a,1)) {} + __forceinline vboolf(const vboolf4& a, const vboolf4& b) : v(_mm256_insertf128_ps(_mm256_castps128_ps256(a),b,1)) {} + __forceinline vboolf(__m128 a, __m128 b) : vl(a), vh(b) {} + + __forceinline vboolf(bool a) : v(vboolf8(vboolf4(a), vboolf4(a))) {} + __forceinline vboolf(bool a, bool b) : v(vboolf8(vboolf4(a), vboolf4(b))) {} + __forceinline vboolf(bool a, bool b, bool c, bool d) : v(vboolf8(vboolf4(a,b), vboolf4(c,d))) {} + __forceinline vboolf(bool a, bool b, bool c, bool d, bool e, bool f, bool g, bool h) : v(vboolf8(vboolf4(a,b,c,d), vboolf4(e,f,g,h))) {} + + /* return int32 mask */ + __forceinline __m256i mask32() const { + return _mm256_castps_si256(v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf(FalseTy) : v(_mm256_setzero_ps()) {} + __forceinline vboolf(TrueTy) : v(_mm256_cmp_ps(_mm256_setzero_ps(), _mm256_setzero_ps(), _CMP_EQ_OQ)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { assert(index < 8); return (_mm256_movemask_ps(v) >> index) & 1; } + __forceinline int& operator [](size_t index) { assert(index < 8); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator !(const vboolf8& a) { return _mm256_xor_ps(a, vboolf8(embree::True)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator &(const vboolf8& a, const vboolf8& b) { return _mm256_and_ps(a, b); } + __forceinline vboolf8 operator |(const vboolf8& a, const vboolf8& b) { return _mm256_or_ps (a, b); } + __forceinline vboolf8 operator ^(const vboolf8& a, const vboolf8& b) { return _mm256_xor_ps(a, b); } + + __forceinline vboolf8 andn(const vboolf8& a, const vboolf8& b) { return _mm256_andnot_ps(b, a); } + + __forceinline vboolf8& operator &=(vboolf8& a, const vboolf8& b) { return a = a & b; } + __forceinline vboolf8& operator |=(vboolf8& a, const vboolf8& b) { return a = a | b; } + __forceinline 
vboolf8& operator ^=(vboolf8& a, const vboolf8& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator !=(const vboolf8& a, const vboolf8& b) { return _mm256_xor_ps(a, b); } + __forceinline vboolf8 operator ==(const vboolf8& a, const vboolf8& b) { return _mm256_xor_ps(_mm256_xor_ps(a,b),vboolf8(embree::True)); } + + __forceinline vboolf8 select(const vboolf8& mask, const vboolf8& t, const vboolf8& f) { + return _mm256_blendv_ps(f, t, mask); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 unpacklo(const vboolf8& a, const vboolf8& b) { return _mm256_unpacklo_ps(a, b); } + __forceinline vboolf8 unpackhi(const vboolf8& a, const vboolf8& b) { return _mm256_unpackhi_ps(a, b); } + + template<int i> + __forceinline vboolf8 shuffle(const vboolf8& v) { + return _mm256_permute_ps(v, _MM_SHUFFLE(i, i, i, i)); + } + + template<int i0, int i1> + __forceinline vboolf8 shuffle4(const vboolf8& v) { + return _mm256_permute2f128_ps(v, v, (i1 << 4) | (i0 << 0)); + } + + template<int i0, int i1> + __forceinline vboolf8 shuffle4(const vboolf8& a, const vboolf8& b) { + return _mm256_permute2f128_ps(a, b, (i1 << 4) | (i0 << 0)); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vboolf8 shuffle(const vboolf8& v) { + return _mm256_permute_ps(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vboolf8 shuffle(const vboolf8& a, const vboolf8& b) { + return _mm256_shuffle_ps(a, b, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<> __forceinline vboolf8 shuffle<0, 0, 2, 2>(const vboolf8& v) { return _mm256_moveldup_ps(v); } + template<> __forceinline vboolf8 shuffle<1, 1, 3, 3>(const vboolf8& v) { return _mm256_movehdup_ps(v); } + template<> __forceinline vboolf8 shuffle<0, 1, 0, 1>(const vboolf8& v) { return _mm256_castpd_ps(_mm256_movedup_pd(_mm256_castps_pd(v))); } + + template<int i> __forceinline vboolf8 insert4(const vboolf8& a, const vboolf4& b) { return _mm256_insertf128_ps(a, b, i); } + template<int i> __forceinline vboolf4 extract4 (const vboolf8& a) { return _mm256_extractf128_ps(a, i); } + template<> __forceinline vboolf4 extract4<0>(const vboolf8& a) { return _mm256_castps256_ps128(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool reduce_and(const vboolf8& a) { return _mm256_movemask_ps(a) == (unsigned int)0xff; } + __forceinline bool reduce_or (const vboolf8& a) { return !_mm256_testz_ps(a,a); } + + __forceinline bool all (const vboolf8& a) { return _mm256_movemask_ps(a) == (unsigned int)0xff; } + __forceinline bool any (const vboolf8& a) { return !_mm256_testz_ps(a,a); } + __forceinline bool none(const vboolf8& a) { return _mm256_testz_ps(a,a) != 0; } + + __forceinline bool all (const vboolf8& valid, const vboolf8& b) { return all((!valid) | b); } + __forceinline bool any (const vboolf8& valid, const vboolf8& b) { return any(valid & b); } + __forceinline bool none(const vboolf8& valid, const vboolf8& b) { return none(valid & b); } + + __forceinline unsigned int movemask(const vboolf8& a) { return _mm256_movemask_ps(a); } + __forceinline size_t popcnt (const vboolf8& a) { return
popcnt((size_t)_mm256_movemask_ps(a)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboolf8& a, size_t index) { return a[index]; } + __forceinline void set(vboolf8& a, size_t index) { a[index] = -1; } + __forceinline void clear(vboolf8& a, size_t index) { a[index] = 0; } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboolf8& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " + << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vboolf8_avx512.h b/thirdparty/embree/common/simd/vboolf8_avx512.h new file mode 100644 index 000000000000..2a52b554c79c --- /dev/null +++ b/thirdparty/embree/common/simd/vboolf8_avx512.h @@ -0,0 +1,143 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX-512 bool type */ + template<> + struct vboolf<8> + { + typedef vboolf8 Bool; + typedef vint8 Int; + + enum { size = 8 }; // number of SIMD elements + __mmask8 v; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf() {} + __forceinline vboolf(const vboolf8& t) { v = t.v; } + __forceinline vboolf8& operator =(const vboolf8& f) { v = f.v; return *this; } + + __forceinline vboolf(const __mmask8 &t) { v = t; } + __forceinline operator __mmask8() const { return v; } + + __forceinline vboolf(bool b) { v = b ? 
0xff : 0x00; } + __forceinline vboolf(int t) { v = (__mmask8)t; } + __forceinline vboolf(unsigned int t) { v = (__mmask8)t; } + + __forceinline vboolf(bool a, bool b, bool c, bool d, bool e, bool f, bool g, bool h) + : v((__mmask8)((int(h) << 7) | (int(g) << 6) | (int(f) << 5) | (int(e) << 4) | (int(d) << 3) | (int(c) << 2) | (int(b) << 1) | int(a))) {} + + /* return int8 mask */ + __forceinline __m128i mask8() const { + return _mm_movm_epi8(v); + } + + /* return int32 mask */ + __forceinline __m256i mask32() const { + return _mm256_movm_epi32(v); + } + + /* return int64 mask */ + __forceinline __m512i mask64() const { + return _mm512_movm_epi64(v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf(FalseTy) : v(0x00) {} + __forceinline vboolf(TrueTy) : v(0xff) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool operator [](size_t index) const { + assert(index < 8); return (mm512_mask2int(v) >> index) & 1; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator !(const vboolf8& a) { return _mm512_knot(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator &(const vboolf8& a, const vboolf8& b) { return _mm512_kand(a, b); } + __forceinline vboolf8 operator |(const vboolf8& a, const vboolf8& b) { return _mm512_kor(a, b); } + __forceinline vboolf8 operator ^(const vboolf8& a, const vboolf8& b) { return _mm512_kxor(a, b); } + + __forceinline vboolf8 andn(const vboolf8& a, const vboolf8& b) { return _mm512_kandn(b, a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8& operator &=(vboolf8& a, const vboolf8& b) { return a = a & b; } + __forceinline vboolf8& operator |=(vboolf8& a, const vboolf8& b) { return a = a | b; } + __forceinline vboolf8& operator ^=(vboolf8& a, const vboolf8& b) { return a = a ^ b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator !=(const vboolf8& a, const vboolf8& b) { return _mm512_kxor(a, b); } + __forceinline vboolf8 operator ==(const vboolf8& a, const vboolf8& b) { return _mm512_kxnor(a, b); } + + __forceinline vboolf8 select(const vboolf8& s, const vboolf8& a, const vboolf8& b) { + return _mm512_kor(_mm512_kand(s, a), _mm512_kandn(s, b)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reduction Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int all (const vboolf8& a) { return a.v == 0xff; } + __forceinline int any (const vboolf8& a) { return _mm512_kortestz(a, a) == 0; } + __forceinline int none(const vboolf8& a) { return 
_mm512_kortestz(a, a) != 0; } + + __forceinline int all (const vboolf8& valid, const vboolf8& b) { return all((!valid) | b); } + __forceinline int any (const vboolf8& valid, const vboolf8& b) { return any(valid & b); } + __forceinline int none(const vboolf8& valid, const vboolf8& b) { return none(valid & b); } + + __forceinline size_t movemask(const vboolf8& a) { return _mm512_kmov(a); } + __forceinline size_t popcnt (const vboolf8& a) { return popcnt(a.v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Conversion Operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline unsigned int toInt(const vboolf8& a) { return mm512_mask2int(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Get/Set Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline bool get(const vboolf8& a, size_t index) { assert(index < 8); return (toInt(a) >> index) & 1; } + __forceinline void set(vboolf8& a, size_t index) { assert(index < 8); a |= 1 << index; } + __forceinline void clear(vboolf8& a, size_t index) { assert(index < 8); a = andn(a, 1 << index); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vboolf8& a) + { + cout << "<"; + for (size_t i=0; i<8; i++) { + if ((a.v >> i) & 1) cout << "1"; else cout << "0"; + } + return cout << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vdouble4_avx.h b/thirdparty/embree/common/simd/vdouble4_avx.h new file mode 100644 index 000000000000..eedb04aafbe6 --- /dev/null +++ b/thirdparty/embree/common/simd/vdouble4_avx.h @@ -0,0 +1,317 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide AVX 64-bit double type */ + template<> + struct vdouble<4> + { + ALIGNED_STRUCT_(32); + + typedef vboold4 Bool; + + enum { size = 4 }; // number of SIMD elements + union { // data + __m256d v; + double i[4]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble() {} + __forceinline vdouble(const vdouble4& t) { v = t.v; } + __forceinline vdouble4& operator =(const vdouble4& f) { v = f.v; return *this; } + + __forceinline vdouble(const __m256d& t) { v = t; } + __forceinline operator __m256d() const { return v; } + + __forceinline vdouble(double i) { + v = _mm256_set1_pd(i); + } + + __forceinline vdouble(double a, double b, double c, double d) { + v = _mm256_set_pd(d,c,b,a); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble(ZeroTy) : v(_mm256_setzero_pd()) {} + __forceinline vdouble(OneTy) : v(_mm256_set1_pd(1)) {} + __forceinline vdouble(StepTy) : v(_mm256_set_pd(3.0,2.0,1.0,0.0)) {} + __forceinline vdouble(ReverseStepTy) : v(_mm256_setr_pd(3.0,2.0,1.0,0.0)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + 
+ static __forceinline void store_nt(double *__restrict__ ptr, const vdouble4& a) { + _mm256_stream_pd(ptr, a); + } + + static __forceinline vdouble4 loadu(const double* addr) { + return _mm256_loadu_pd(addr); + } + + static __forceinline vdouble4 load(const vdouble4* addr) { + return _mm256_load_pd((double*)addr); + } + + static __forceinline vdouble4 load(const double* addr) { + return _mm256_load_pd(addr); + } + + static __forceinline void store(double* ptr, const vdouble4& v) { + _mm256_store_pd(ptr, v); + } + + static __forceinline void storeu(double* ptr, const vdouble4& v) { + _mm256_storeu_pd(ptr, v); + } + + static __forceinline vdouble4 broadcast(const void* a) { return _mm256_set1_pd(*(double*)a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline double& operator [](size_t index) { assert(index < 4); return i[index]; } + __forceinline const double& operator [](size_t index) const { assert(index < 4); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX2__) + __forceinline vdouble4 asDouble(const vllong4& a) { return _mm256_castsi256_pd(a); } + __forceinline vllong4 asLLong (const vdouble4& a) { return _mm256_castpd_si256(a); } +#endif + + __forceinline vdouble4 operator +(const vdouble4& a) { return a; } + __forceinline vdouble4 operator -(const vdouble4& a) { return _mm256_sub_pd(_mm256_setzero_pd(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble4 operator +(const vdouble4& a, const vdouble4& b) { return _mm256_add_pd(a, b); } + __forceinline vdouble4 operator +(const vdouble4& a, double b) { return a + vdouble4(b); } + __forceinline vdouble4 operator +(double a, const vdouble4& b) { return vdouble4(a) + b; } + + __forceinline vdouble4 operator -(const vdouble4& a, const vdouble4& b) { return _mm256_sub_pd(a, b); } + __forceinline vdouble4 operator -(const vdouble4& a, double b) { return a - vdouble4(b); } + __forceinline vdouble4 operator -(double a, const vdouble4& b) { return vdouble4(a) - b; } + + __forceinline vdouble4 operator *(const vdouble4& a, const vdouble4& b) { return _mm256_mul_pd(a, b); } + __forceinline vdouble4 operator *(const vdouble4& a, double b) { return a * vdouble4(b); } + __forceinline vdouble4 operator *(double a, const vdouble4& b) { return vdouble4(a) * b; } + + __forceinline vdouble4 operator &(const vdouble4& a, const vdouble4& b) { return _mm256_and_pd(a, b); } + __forceinline vdouble4 operator &(const vdouble4& a, double b) { return a & vdouble4(b); } + __forceinline vdouble4 operator &(double a, const vdouble4& b) { return vdouble4(a) & b; } + + __forceinline vdouble4 operator |(const vdouble4& a, const vdouble4& b) { return _mm256_or_pd(a, b); } + __forceinline vdouble4 operator |(const vdouble4& a, double b) { return a | vdouble4(b); } + __forceinline vdouble4 operator |(double a, const vdouble4& b) { return vdouble4(a) | b; } + + __forceinline vdouble4 operator ^(const vdouble4& a, const vdouble4& b) { return _mm256_xor_pd(a, b); } + __forceinline vdouble4 operator ^(const vdouble4& a, double b) { return a ^ vdouble4(b); } + 
__forceinline vdouble4 operator ^(double a, const vdouble4& b) { return vdouble4(a) ^ b; } + + __forceinline vdouble4 min(const vdouble4& a, const vdouble4& b) { return _mm256_min_pd(a, b); } + __forceinline vdouble4 min(const vdouble4& a, double b) { return min(a,vdouble4(b)); } + __forceinline vdouble4 min(double a, const vdouble4& b) { return min(vdouble4(a),b); } + + __forceinline vdouble4 max(const vdouble4& a, const vdouble4& b) { return _mm256_max_pd(a, b); } + __forceinline vdouble4 max(const vdouble4& a, double b) { return max(a,vdouble4(b)); } + __forceinline vdouble4 max(double a, const vdouble4& b) { return max(vdouble4(a),b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__FMA__) + __forceinline vdouble4 madd (const vdouble4& a, const vdouble4& b, const vdouble4& c) { return _mm256_fmadd_pd(a,b,c); } + __forceinline vdouble4 msub (const vdouble4& a, const vdouble4& b, const vdouble4& c) { return _mm256_fmsub_pd(a,b,c); } + __forceinline vdouble4 nmadd(const vdouble4& a, const vdouble4& b, const vdouble4& c) { return _mm256_fnmadd_pd(a,b,c); } + __forceinline vdouble4 nmsub(const vdouble4& a, const vdouble4& b, const vdouble4& c) { return _mm256_fnmsub_pd(a,b,c); } +#else + __forceinline vdouble4 madd (const vdouble4& a, const vdouble4& b, const vdouble4& c) { return a*b+c; } + __forceinline vdouble4 msub (const vdouble4& a, const vdouble4& b, const vdouble4& c) { return a*b-c; } + __forceinline vdouble4 nmadd(const vdouble4& a, const vdouble4& b, const vdouble4& c) { return -a*b+c;} + __forceinline vdouble4 nmsub(const vdouble4& a, const vdouble4& b, const vdouble4& c) { return -a*b-c; } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble4& operator +=(vdouble4& a, const vdouble4& b) { return a = a + b; } + __forceinline vdouble4& operator +=(vdouble4& a, double b) { return a = a + b; } + + __forceinline vdouble4& operator -=(vdouble4& a, const vdouble4& b) { return a = a - b; } + __forceinline vdouble4& operator -=(vdouble4& a, double b) { return a = a - b; } + + __forceinline vdouble4& operator *=(vdouble4& a, const vdouble4& b) { return a = a * b; } + __forceinline vdouble4& operator *=(vdouble4& a, double b) { return a = a * b; } + + __forceinline vdouble4& operator &=(vdouble4& a, const vdouble4& b) { return a = a & b; } + __forceinline vdouble4& operator &=(vdouble4& a, double b) { return a = a & b; } + + __forceinline vdouble4& operator |=(vdouble4& a, const vdouble4& b) { return a = a | b; } + __forceinline vdouble4& operator |=(vdouble4& a, double b) { return a = a | b; } + + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboold4 operator ==(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd_mask(a, b, _MM_CMPINT_EQ); } + __forceinline vboold4 operator !=(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd_mask(a, b, _MM_CMPINT_NE); } + __forceinline vboold4 operator < (const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd_mask(a, b, _MM_CMPINT_LT); } + __forceinline vboold4 operator >=(const vdouble4& a, const 
vdouble4& b) { return _mm256_cmp_pd_mask(a, b, _MM_CMPINT_GE); } + __forceinline vboold4 operator > (const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd_mask(a, b, _MM_CMPINT_GT); } + __forceinline vboold4 operator <=(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd_mask(a, b, _MM_CMPINT_LE); } +#else + __forceinline vboold4 operator ==(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd(a, b, _CMP_EQ_OQ); } + __forceinline vboold4 operator !=(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd(a, b, _CMP_NEQ_UQ); } + __forceinline vboold4 operator < (const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd(a, b, _CMP_LT_OS); } + __forceinline vboold4 operator >=(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd(a, b, _CMP_NLT_US); } + __forceinline vboold4 operator > (const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd(a, b, _CMP_NLE_US); } + __forceinline vboold4 operator <=(const vdouble4& a, const vdouble4& b) { return _mm256_cmp_pd(a, b, _CMP_LE_OS); } +#endif + + __forceinline vboold4 operator ==(const vdouble4& a, double b) { return a == vdouble4(b); } + __forceinline vboold4 operator ==(double a, const vdouble4& b) { return vdouble4(a) == b; } + + __forceinline vboold4 operator !=(const vdouble4& a, double b) { return a != vdouble4(b); } + __forceinline vboold4 operator !=(double a, const vdouble4& b) { return vdouble4(a) != b; } + + __forceinline vboold4 operator < (const vdouble4& a, double b) { return a < vdouble4(b); } + __forceinline vboold4 operator < (double a, const vdouble4& b) { return vdouble4(a) < b; } + + __forceinline vboold4 operator >=(const vdouble4& a, double b) { return a >= vdouble4(b); } + __forceinline vboold4 operator >=(double a, const vdouble4& b) { return vdouble4(a) >= b; } + + __forceinline vboold4 operator > (const vdouble4& a, double b) { return a > vdouble4(b); } + __forceinline vboold4 operator > (double a, const vdouble4& b) { return vdouble4(a) > b; } + + __forceinline vboold4 operator <=(const vdouble4& a, double b) { return a <= vdouble4(b); } + __forceinline vboold4 operator <=(double a, const vdouble4& b) { return vdouble4(a) <= b; } + + __forceinline vboold4 eq(const vdouble4& a, const vdouble4& b) { return a == b; } + __forceinline vboold4 ne(const vdouble4& a, const vdouble4& b) { return a != b; } + __forceinline vboold4 lt(const vdouble4& a, const vdouble4& b) { return a < b; } + __forceinline vboold4 ge(const vdouble4& a, const vdouble4& b) { return a >= b; } + __forceinline vboold4 gt(const vdouble4& a, const vdouble4& b) { return a > b; } + __forceinline vboold4 le(const vdouble4& a, const vdouble4& b) { return a <= b; } + +#if defined(__AVX512VL__) + __forceinline vboold4 eq(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return _mm256_mask_cmp_pd_mask(mask, a, b, _MM_CMPINT_EQ); } + __forceinline vboold4 ne(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return _mm256_mask_cmp_pd_mask(mask, a, b, _MM_CMPINT_NE); } + __forceinline vboold4 lt(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return _mm256_mask_cmp_pd_mask(mask, a, b, _MM_CMPINT_LT); } + __forceinline vboold4 ge(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return _mm256_mask_cmp_pd_mask(mask, a, b, _MM_CMPINT_GE); } + __forceinline vboold4 gt(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return _mm256_mask_cmp_pd_mask(mask, a, b, _MM_CMPINT_GT); } + __forceinline vboold4 le(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return 
_mm256_mask_cmp_pd_mask(mask, a, b, _MM_CMPINT_LE); } +#else + __forceinline vboold4 eq(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return mask & (a == b); } + __forceinline vboold4 ne(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return mask & (a != b); } + __forceinline vboold4 lt(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return mask & (a < b); } + __forceinline vboold4 ge(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return mask & (a >= b); } + __forceinline vboold4 gt(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return mask & (a > b); } + __forceinline vboold4 le(const vboold4& mask, const vdouble4& a, const vdouble4& b) { return mask & (a <= b); } +#endif + + __forceinline vdouble4 select(const vboold4& m, const vdouble4& t, const vdouble4& f) { +#if defined(__AVX512VL__) + return _mm256_mask_blend_pd(m, f, t); +#else + return _mm256_blendv_pd(f, t, m); +#endif + } + + __forceinline void xchg(const vboold4& m, vdouble4& a, vdouble4& b) { + const vdouble4 c = a; a = select(m,b,a); b = select(m,c,b); + } + + __forceinline vboold4 test(const vdouble4& a, const vdouble4& b) { +#if defined(__AVX512VL__) + return _mm256_test_epi64_mask(_mm256_castpd_si256(a),_mm256_castpd_si256(b)); +#else + return _mm256_testz_si256(_mm256_castpd_si256(a),_mm256_castpd_si256(b)); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + template<int i0, int i1> + __forceinline vdouble4 shuffle(const vdouble4& v) { + return _mm256_permute_pd(v, (i1 << 3) | (i0 << 2) | (i1 << 1) | i0); + } + + template<int i> + __forceinline vdouble4 shuffle(const vdouble4& v) { + return shuffle<i,i>(v); + } + + template<int i0, int i1> + __forceinline vdouble4 shuffle2(const vdouble4& v) { + return _mm256_permute2f128_pd(v, v, (i1 << 4) | i0); + } + + __forceinline double toScalar(const vdouble4& v) { + return _mm_cvtsd_f64(_mm256_castpd256_pd128(v)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble4 vreduce_min2(const vdouble4& x) { return min(x, shuffle<1,0>(x)); } + __forceinline vdouble4 vreduce_min (const vdouble4& y) { const vdouble4 x = vreduce_min2(y); return min(x, shuffle2<1,0>(x)); } + + __forceinline vdouble4 vreduce_max2(const vdouble4& x) { return max(x,shuffle<1,0>(x)); } + __forceinline vdouble4 vreduce_max (const vdouble4& y) { const vdouble4 x = vreduce_max2(y); return max(x, shuffle2<1,0>(x)); } + + __forceinline vdouble4 vreduce_and2(const vdouble4& x) { return x & shuffle<1,0>(x); } + __forceinline vdouble4 vreduce_and (const vdouble4& y) { const vdouble4 x = vreduce_and2(y); return x & shuffle2<1,0>(x); } + + __forceinline vdouble4 vreduce_or2(const vdouble4& x) { return x | shuffle<1,0>(x); } + __forceinline vdouble4 vreduce_or (const vdouble4& y) { const vdouble4 x = vreduce_or2(y); return x | shuffle2<1,0>(x); } + + __forceinline vdouble4 vreduce_add2(const vdouble4& x) { return x + shuffle<1,0>(x); } + __forceinline vdouble4 vreduce_add (const vdouble4& y) { const vdouble4 x = vreduce_add2(y); return x + shuffle2<1,0>(x); } + + __forceinline double reduce_add(const vdouble4& a) { return toScalar(vreduce_add(a)); } + __forceinline double reduce_min(const vdouble4& a) { return toScalar(vreduce_min(a)); } + __forceinline double
reduce_max(const vdouble4& a) { return toScalar(vreduce_max(a)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Memory load and store operations + //////////////////////////////////////////////////////////////////////////////// + + + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vdouble4& v) + { + cout << "<" << v[0]; + for (size_t i=1; i<4; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git a/thirdparty/embree/common/simd/vdouble8_avx512.h b/thirdparty/embree/common/simd/vdouble8_avx512.h new file mode 100644 index 000000000000..4eec7d2f6a0f --- /dev/null +++ b/thirdparty/embree/common/simd/vdouble8_avx512.h @@ -0,0 +1,356 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX-512 64-bit double type */ + template<> + struct vdouble<8> + { + ALIGNED_STRUCT_(64); + + typedef vboold8 Bool; + + enum { size = 8 }; // number of SIMD elements + union { // data + __m512d v; + double i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble() {} + __forceinline vdouble(const vdouble8& t) { v = t.v; } + __forceinline vdouble8& operator =(const vdouble8& f) { v = f.v; return *this; } + + __forceinline vdouble(const __m512d& t) { v = t; } + __forceinline operator __m512d() const { return v; } + __forceinline operator __m256d() const { return _mm512_castpd512_pd256(v); } + + __forceinline vdouble(double i) { + v = _mm512_set1_pd(i); + } + + __forceinline vdouble(double a, double b, double c, double d) { + v = _mm512_set4_pd(d,c,b,a); + } + + __forceinline vdouble(double a0, double a1, double a2, double a3, + double a4, double a5, double a6, double a7) + { + v = _mm512_set_pd(a7,a6,a5,a4,a3,a2,a1,a0); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble(ZeroTy) : v(_mm512_setzero_pd()) {} + __forceinline vdouble(OneTy) : v(_mm512_set1_pd(1)) {} + __forceinline vdouble(StepTy) : v(_mm512_set_pd(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)) {} + __forceinline vdouble(ReverseStepTy) : v(_mm512_setr_pd(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline void store_nt(void *__restrict__ ptr, const vdouble8& a) { + _mm512_stream_pd((double*)ptr, a); + } + + static __forceinline vdouble8 loadu(const void* addr) { + return _mm512_loadu_pd((double*)addr); + } + + static __forceinline vdouble8 load(const vdouble8* addr) { + return _mm512_load_pd((double*)addr); + } + + static __forceinline vdouble8 load(const double* addr) { + return _mm512_load_pd(addr); + } + + static __forceinline void store(void* ptr, const vdouble8& v) { + _mm512_store_pd(ptr, v); + } + + static __forceinline void storeu(void* ptr, const vdouble8& v) { + _mm512_storeu_pd(ptr, v); + } + + static __forceinline void storeu(const vboold8& mask, 
double* ptr, const vdouble8& f) { + _mm512_mask_storeu_pd(ptr, mask, f); + } + + static __forceinline void store(const vboold8& mask, void* addr, const vdouble8& v2) { + _mm512_mask_store_pd(addr, mask, v2); + } + + /* pass by value to avoid compiler generating inefficient code */ + static __forceinline void storeu_compact(const vboold8 mask,void * addr, const vdouble8& reg) { + _mm512_mask_compressstoreu_pd(addr, mask, reg); + } + + static __forceinline vdouble8 compact64bit(const vboold8& mask, vdouble8& v) { + return _mm512_mask_compress_pd(v, mask, v); + } + + static __forceinline vdouble8 compact(const vboold8& mask, vdouble8& v) { + return _mm512_mask_compress_pd(v, mask, v); + } + + static __forceinline vdouble8 compact(const vboold8& mask, const vdouble8& a, vdouble8& b) { + return _mm512_mask_compress_pd(a, mask, b); + } + + static __forceinline vdouble8 broadcast(const void* a) { return _mm512_set1_pd(*(double*)a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline double& operator [](size_t index) { assert(index < 8); return i[index]; } + __forceinline const double& operator [](size_t index) const { assert(index < 8); return i[index]; } + + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble8 asDouble(const vllong8& a) { return _mm512_castsi512_pd(a); } + __forceinline vllong8 asLLong (const vdouble8& a) { return _mm512_castpd_si512(a); } + + __forceinline vdouble8 operator +(const vdouble8& a) { return a; } + __forceinline vdouble8 operator -(const vdouble8& a) { return _mm512_sub_pd(_mm512_setzero_pd(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble8 operator +(const vdouble8& a, const vdouble8& b) { return _mm512_add_pd(a, b); } + __forceinline vdouble8 operator +(const vdouble8& a, double b) { return a + vdouble8(b); } + __forceinline vdouble8 operator +(double a, const vdouble8& b) { return vdouble8(a) + b; } + + __forceinline vdouble8 operator -(const vdouble8& a, const vdouble8& b) { return _mm512_sub_pd(a, b); } + __forceinline vdouble8 operator -(const vdouble8& a, double b) { return a - vdouble8(b); } + __forceinline vdouble8 operator -(double a, const vdouble8& b) { return vdouble8(a) - b; } + + __forceinline vdouble8 operator *(const vdouble8& a, const vdouble8& b) { return _mm512_mul_pd(a, b); } + __forceinline vdouble8 operator *(const vdouble8& a, double b) { return a * vdouble8(b); } + __forceinline vdouble8 operator *(double a, const vdouble8& b) { return vdouble8(a) * b; } + + __forceinline vdouble8 operator &(const vdouble8& a, const vdouble8& b) { return _mm512_and_pd(a, b); } + __forceinline vdouble8 operator &(const vdouble8& a, double b) { return a & vdouble8(b); } + __forceinline vdouble8 operator &(double a, const vdouble8& b) { return vdouble8(a) & b; } + + __forceinline vdouble8 operator |(const vdouble8& a, const vdouble8& b) { return _mm512_or_pd(a, b); } + __forceinline vdouble8 operator |(const vdouble8& a, double b) { return a | vdouble8(b); } + __forceinline vdouble8 operator |(double a, const vdouble8& b) { return vdouble8(a) | b; } + + __forceinline 
vdouble8 operator ^(const vdouble8& a, const vdouble8& b) { return _mm512_xor_pd(a, b); } + __forceinline vdouble8 operator ^(const vdouble8& a, double b) { return a ^ vdouble8(b); } + __forceinline vdouble8 operator ^(double a, const vdouble8& b) { return vdouble8(a) ^ b; } + + __forceinline vdouble8 operator <<(const vdouble8& a, const unsigned int n) { return _mm512_castsi512_pd(_mm512_slli_epi64(_mm512_castpd_si512(a), n)); } + __forceinline vdouble8 operator >>(const vdouble8& a, const unsigned int n) { return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), n)); } + + __forceinline vdouble8 operator <<(const vdouble8& a, const vllong8& n) { return _mm512_castsi512_pd(_mm512_sllv_epi64(_mm512_castpd_si512(a), n)); } + __forceinline vdouble8 operator >>(const vdouble8& a, const vllong8& n) { return _mm512_castsi512_pd(_mm512_srav_epi64(_mm512_castpd_si512(a), n)); } + + __forceinline vdouble8 sll (const vdouble8& a, const unsigned int b) { return _mm512_castsi512_pd(_mm512_slli_epi64(_mm512_castpd_si512(a), b)); } + __forceinline vdouble8 sra (const vdouble8& a, const unsigned int b) { return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), b)); } + __forceinline vdouble8 srl (const vdouble8& a, const unsigned int b) { return _mm512_castsi512_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), b)); } + + __forceinline vdouble8 min(const vdouble8& a, const vdouble8& b) { return _mm512_min_pd(a, b); } + __forceinline vdouble8 min(const vdouble8& a, double b) { return min(a,vdouble8(b)); } + __forceinline vdouble8 min(double a, const vdouble8& b) { return min(vdouble8(a),b); } + + __forceinline vdouble8 max(const vdouble8& a, const vdouble8& b) { return _mm512_max_pd(a, b); } + __forceinline vdouble8 max(const vdouble8& a, double b) { return max(a,vdouble8(b)); } + __forceinline vdouble8 max(double a, const vdouble8& b) { return max(vdouble8(a),b); } + + __forceinline vdouble8 mask_add(const vboold8& mask, vdouble8& c, const vdouble8& a, const vdouble8& b) { return _mm512_mask_add_pd(c,mask,a,b); } + __forceinline vdouble8 mask_sub(const vboold8& mask, vdouble8& c, const vdouble8& a, const vdouble8& b) { return _mm512_mask_sub_pd(c,mask,a,b); } + + __forceinline vdouble8 mask_and(const vboold8& m,vdouble8& c, const vdouble8& a, const vdouble8& b) { return _mm512_mask_and_pd(c,m,a,b); } + __forceinline vdouble8 mask_or (const vboold8& m,vdouble8& c, const vdouble8& a, const vdouble8& b) { return _mm512_mask_or_pd(c,m,a,b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble8 madd (const vdouble8& a, const vdouble8& b, const vdouble8& c) { return _mm512_fmadd_pd(a,b,c); } + __forceinline vdouble8 msub (const vdouble8& a, const vdouble8& b, const vdouble8& c) { return _mm512_fmsub_pd(a,b,c); } + __forceinline vdouble8 nmadd(const vdouble8& a, const vdouble8& b, const vdouble8& c) { return _mm512_fnmadd_pd(a,b,c); } + __forceinline vdouble8 nmsub(const vdouble8& a, const vdouble8& b, const vdouble8& c) { return _mm512_fnmsub_pd(a,b,c); } + + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble8& operator +=(vdouble8& a, const vdouble8& b) { return a = a + b; } + __forceinline vdouble8& operator +=(vdouble8& a, double b) { return a = a + b; } + + 
__forceinline vdouble8& operator -=(vdouble8& a, const vdouble8& b) { return a = a - b; } + __forceinline vdouble8& operator -=(vdouble8& a, double b) { return a = a - b; } + + __forceinline vdouble8& operator *=(vdouble8& a, const vdouble8& b) { return a = a * b; } + __forceinline vdouble8& operator *=(vdouble8& a, double b) { return a = a * b; } + + __forceinline vdouble8& operator &=(vdouble8& a, const vdouble8& b) { return a = a & b; } + __forceinline vdouble8& operator &=(vdouble8& a, double b) { return a = a & b; } + + __forceinline vdouble8& operator |=(vdouble8& a, const vdouble8& b) { return a = a | b; } + __forceinline vdouble8& operator |=(vdouble8& a, double b) { return a = a | b; } + + __forceinline vdouble8& operator <<=(vdouble8& a, const double b) { return a = a << b; } + __forceinline vdouble8& operator >>=(vdouble8& a, const double b) { return a = a >> b; } + + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8 operator ==(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboold8 operator ==(const vdouble8& a, double b) { return a == vdouble8(b); } + __forceinline vboold8 operator ==(double a, const vdouble8& b) { return vdouble8(a) == b; } + + __forceinline vboold8 operator !=(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboold8 operator !=(const vdouble8& a, double b) { return a != vdouble8(b); } + __forceinline vboold8 operator !=(double a, const vdouble8& b) { return vdouble8(a) != b; } + + __forceinline vboold8 operator < (const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboold8 operator < (const vdouble8& a, double b) { return a < vdouble8(b); } + __forceinline vboold8 operator < (double a, const vdouble8& b) { return vdouble8(a) < b; } + + __forceinline vboold8 operator >=(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboold8 operator >=(const vdouble8& a, double b) { return a >= vdouble8(b); } + __forceinline vboold8 operator >=(double a, const vdouble8& b) { return vdouble8(a) >= b; } + + __forceinline vboold8 operator > (const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboold8 operator > (const vdouble8& a, double b) { return a > vdouble8(b); } + __forceinline vboold8 operator > (double a, const vdouble8& b) { return vdouble8(a) > b; } + + __forceinline vboold8 operator <=(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboold8 operator <=(const vdouble8& a, double b) { return a <= vdouble8(b); } + __forceinline vboold8 operator <=(double a, const vdouble8& b) { return vdouble8(a) <= b; } + + __forceinline vboold8 eq(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboold8 ne(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboold8 lt(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboold8 ge(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboold8 gt(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_GT); } + 
__forceinline vboold8 le(const vdouble8& a, const vdouble8& b) { return _mm512_cmp_pd_mask(a,b,_MM_CMPINT_LE); } + + __forceinline vboold8 eq(const vboold8 mask, const vdouble8& a, const vdouble8& b) { return _mm512_mask_cmp_pd_mask(mask,a,b,_MM_CMPINT_EQ); } + __forceinline vboold8 ne(const vboold8 mask, const vdouble8& a, const vdouble8& b) { return _mm512_mask_cmp_pd_mask(mask,a,b,_MM_CMPINT_NE); } + __forceinline vboold8 lt(const vboold8 mask, const vdouble8& a, const vdouble8& b) { return _mm512_mask_cmp_pd_mask(mask,a,b,_MM_CMPINT_LT); } + __forceinline vboold8 ge(const vboold8 mask, const vdouble8& a, const vdouble8& b) { return _mm512_mask_cmp_pd_mask(mask,a,b,_MM_CMPINT_GE); } + __forceinline vboold8 gt(const vboold8 mask, const vdouble8& a, const vdouble8& b) { return _mm512_mask_cmp_pd_mask(mask,a,b,_MM_CMPINT_GT); } + __forceinline vboold8 le(const vboold8 mask, const vdouble8& a, const vdouble8& b) { return _mm512_mask_cmp_pd_mask(mask,a,b,_MM_CMPINT_LE); } + + __forceinline vdouble8 select(const vboold8& m, const vdouble8& t, const vdouble8& f) { + return _mm512_mask_or_pd(f,m,t,t); + } + + __forceinline void xchg(const vboold8& m, vdouble8& a, vdouble8& b) { + const vdouble8 c = a; a = select(m,b,a); b = select(m,c,b); + } + + __forceinline vboold8 test(const vboold8& m, const vdouble8& a, const vdouble8& b) { + return _mm512_mask_test_epi64_mask(m,_mm512_castpd_si512(a),_mm512_castpd_si512(b)); + } + + __forceinline vboold8 test(const vdouble8& a, const vdouble8& b) { + return _mm512_test_epi64_mask(_mm512_castpd_si512(a),_mm512_castpd_si512(b)); + } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + template<int i0, int i1> + __forceinline vdouble8 shuffle(const vdouble8& v) { + return _mm512_permute_pd(v, (i1 << 7) | (i0 << 6) | (i1 << 5) | (i0 << 4) | (i1 << 3) | (i0 << 2) | (i1 << 1) | i0); + } + + template<int i> + __forceinline vdouble8 shuffle(const vdouble8& v) { + return shuffle<i,i>(v); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vdouble8 shuffle(const vdouble8& v) { + return _mm512_permutex_pd(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<int i0, int i1> + __forceinline vdouble8 shuffle4(const vdouble8& v) { + return _mm512_shuffle_f64x2(v, v, _MM_SHUFFLE(i1*2+1, i1*2, i0*2+1, i0*2)); + } + + template<int i> + __forceinline vdouble8 shuffle4(const vdouble8& v) { + return shuffle4<i,i>(v); + } + + template<int i> + __forceinline vdouble8 align_shift_right(const vdouble8& a, const vdouble8& b) { + return _mm512_castsi512_pd(_mm512_alignr_epi64(_mm512_castpd_si512(a), _mm512_castpd_si512(b), i)); + } + + __forceinline double toScalar(const vdouble8& v) { + return _mm_cvtsd_f64(_mm512_castpd512_pd128(v)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble8 vreduce_add2(vdouble8 x) { return x + shuffle<1,0,3,2>(x); } + __forceinline vdouble8 vreduce_add4(vdouble8 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); } + __forceinline vdouble8 vreduce_add (vdouble8 x) { x = vreduce_add4(x); return x + shuffle4<1,0>(x); } + + __forceinline vdouble8 vreduce_min2(vdouble8 x) { return min(x, shuffle<1,0,3,2>(x)); } + __forceinline vdouble8 vreduce_min4(vdouble8 x) { x = vreduce_min2(x); return min(x, shuffle<2,3,0,1>(x)); } + __forceinline vdouble8 vreduce_min (vdouble8 x) { x = vreduce_min4(x); return min(x,
shuffle4<1,0>(x)); } + + __forceinline vdouble8 vreduce_max2(vdouble8 x) { return max(x, shuffle<1,0,3,2>(x)); } + __forceinline vdouble8 vreduce_max4(vdouble8 x) { x = vreduce_max2(x); return max(x, shuffle<2,3,0,1>(x)); } + __forceinline vdouble8 vreduce_max (vdouble8 x) { x = vreduce_max4(x); return max(x, shuffle4<1,0>(x)); } + + __forceinline double reduce_add(const vdouble8& v) { return toScalar(vreduce_add(v)); } + __forceinline double reduce_min(const vdouble8& v) { return toScalar(vreduce_min(v)); } + __forceinline double reduce_max(const vdouble8& v) { return toScalar(vreduce_max(v)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Memory load and store operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vdouble8 permute(const vdouble8& v, const vllong8& index) { + return _mm512_permutexvar_pd(index, v); + } + + __forceinline vdouble8 reverse(const vdouble8& a) { + return permute(a, vllong8(reverse_step)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vdouble8& v) + { + cout << "<" << v[0]; + for (size_t i=1; i<8; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git a/thirdparty/embree/common/simd/vfloat16_avx512.h b/thirdparty/embree/common/simd/vfloat16_avx512.h new file mode 100644 index 000000000000..aed2419b779b --- /dev/null +++ b/thirdparty/embree/common/simd/vfloat16_avx512.h @@ -0,0 +1,771 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 16-wide AVX-512 float type */ + template<> + struct vfloat<16> + { + ALIGNED_STRUCT_(64); + + typedef vboolf16 Bool; + typedef vint16 Int; + typedef vfloat16 Float; + + enum { size = 16 }; // number of SIMD elements + union { // data + __m512 v; + float f[16]; + int i[16]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat() {} + __forceinline vfloat(const vfloat16& t) { v = t; } + __forceinline vfloat16& operator =(const vfloat16& f) { v = f.v; return *this; } + + __forceinline vfloat(const __m512& t) { v = t; } + __forceinline operator __m512() const { return v; } + __forceinline operator __m256() const { return _mm512_castps512_ps256(v); } + __forceinline operator __m128() const { return _mm512_castps512_ps128(v); } + + __forceinline vfloat(float f) { + v = _mm512_set1_ps(f); + } + + __forceinline vfloat(float a, float b, float c, float d) { + v = _mm512_set4_ps(a, b, c, d); + } + + __forceinline vfloat(const vfloat4& i) { + v = _mm512_broadcast_f32x4(i); + } + + __forceinline vfloat(const vfloat4& a, const vfloat4& b, const vfloat4& c, const vfloat4& d) { + v = _mm512_castps128_ps512(a); + v = _mm512_insertf32x4(v, b, 1); + v = _mm512_insertf32x4(v, c, 2); + v = _mm512_insertf32x4(v, d, 3); + } + + __forceinline vfloat(const vboolf16& mask, const vfloat4& a, const vfloat4& b) { + v = _mm512_broadcast_f32x4(a); + v = _mm512_mask_broadcast_f32x4(v,mask,b); + } + + __forceinline vfloat(const vfloat8& i) { + v = _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castps_pd(i))); + } + + __forceinline vfloat(const vfloat8& a, const 
vfloat8& b) { + v = _mm512_castps256_ps512(a); +#if defined(__AVX512DQ__) + v = _mm512_insertf32x8(v, b, 1); +#else + v = _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castps_pd(v), _mm256_castps_pd(b), 1)); +#endif + } + + /* WARNING: due to f64x4 the mask is considered as an 8bit mask */ + __forceinline vfloat(const vboolf16& mask, const vfloat8& a, const vfloat8& b) { + __m512d aa = _mm512_broadcast_f64x4(_mm256_castps_pd(a)); + aa = _mm512_mask_broadcast_f64x4(aa,mask,_mm256_castps_pd(b)); + v = _mm512_castpd_ps(aa); + } + + __forceinline explicit vfloat(const vint16& a) { + v = _mm512_cvtepi32_ps(a); + } + + __forceinline explicit vfloat(const vuint16& a) { + v = _mm512_cvtepu32_ps(a); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat(ZeroTy) : v(_mm512_setzero_ps()) {} + __forceinline vfloat(OneTy) : v(_mm512_set1_ps(1.0f)) {} + __forceinline vfloat(PosInfTy) : v(_mm512_set1_ps(pos_inf)) {} + __forceinline vfloat(NegInfTy) : v(_mm512_set1_ps(neg_inf)) {} + __forceinline vfloat(StepTy) : v(_mm512_set_ps(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {} + __forceinline vfloat(NaNTy) : v(_mm512_set1_ps(nan)) {} + __forceinline vfloat(UndefinedTy) : v(_mm512_undefined_ps()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vfloat16 load (const void* ptr) { return _mm512_load_ps((float*)ptr); } + static __forceinline vfloat16 loadu(const void* ptr) { return _mm512_loadu_ps((float*)ptr); } + + static __forceinline vfloat16 load (const vboolf16& mask, const void* ptr) { return _mm512_mask_load_ps (_mm512_setzero_ps(),mask,(float*)ptr); } + static __forceinline vfloat16 loadu(const vboolf16& mask, const void* ptr) { return _mm512_mask_loadu_ps(_mm512_setzero_ps(),mask,(float*)ptr); } + + static __forceinline void store (void* ptr, const vfloat16& v) { _mm512_store_ps ((float*)ptr,v); } + static __forceinline void storeu(void* ptr, const vfloat16& v) { _mm512_storeu_ps((float*)ptr,v); } + + static __forceinline void store (const vboolf16& mask, void* ptr, const vfloat16& v) { _mm512_mask_store_ps ((float*)ptr,mask,v); } + static __forceinline void storeu(const vboolf16& mask, void* ptr, const vfloat16& v) { _mm512_mask_storeu_ps((float*)ptr,mask,v); } + + static __forceinline void store_nt(void* __restrict__ ptr, const vfloat16& a) { + _mm512_stream_ps((float*)ptr,a); + } + + static __forceinline vfloat16 broadcast(const float* f) { + return _mm512_set1_ps(*f); + } + + static __forceinline vfloat16 compact(const vboolf16& mask, vfloat16 &v) { + return _mm512_mask_compress_ps(v, mask, v); + } + static __forceinline vfloat16 compact(const vboolf16& mask, vfloat16 &a, const vfloat16& b) { + return _mm512_mask_compress_ps(a, mask, b); + } + + static __forceinline vfloat16 expand(const vboolf16& mask, const vfloat16& a, vfloat16& b) { + return _mm512_mask_expand_ps(b, mask, a); + } + + static __forceinline vfloat16 loadu_compact(const vboolf16& mask, const void* ptr) { + return _mm512_mask_expandloadu_ps(_mm512_setzero_ps(), mask, (float*)ptr); + } + + static __forceinline void storeu_compact(const vboolf16& mask, float *addr, const vfloat16 reg) { + _mm512_mask_compressstoreu_ps(addr, mask, reg); + } + + static __forceinline void storeu_compact_single(const vboolf16& mask, float * 
addr, const vfloat16& reg) { + //_mm512_mask_compressstoreu_ps(addr,mask,reg); + *addr = mm512_cvtss_f32(_mm512_mask_compress_ps(reg, mask, reg)); + } + + template<int scale = 4> + static __forceinline vfloat16 gather(const float* ptr, const vint16& index) { + return _mm512_i32gather_ps(index, ptr, scale); + } + + template<int scale = 4> + static __forceinline vfloat16 gather(const vboolf16& mask, const float* ptr, const vint16& index) { + vfloat16 r = zero; + return _mm512_mask_i32gather_ps(r, mask, index, ptr, scale); + } + + template<int scale = 4> + static __forceinline void scatter(float* ptr, const vint16& index, const vfloat16& v) { + _mm512_i32scatter_ps(ptr, index, v, scale); + } + + template<int scale = 4> + static __forceinline void scatter(const vboolf16& mask, float* ptr, const vint16& index, const vfloat16& v) { + _mm512_mask_i32scatter_ps(ptr, mask, index, v, scale); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline float& operator [](size_t index) { assert(index < 16); return f[index]; } + __forceinline const float& operator [](size_t index) const { assert(index < 16); return f[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 asFloat(const vint16& a) { return _mm512_castsi512_ps(a); } + __forceinline vint16 asInt (const vfloat16& a) { return _mm512_castps_si512(a); } + __forceinline vuint16 asUInt (const vfloat16& a) { return _mm512_castps_si512(a); } + + __forceinline vint16 toInt (const vfloat16& a) { return vint16(a); } + __forceinline vfloat16 toFloat(const vint16& a) { return vfloat16(a); } + + __forceinline vfloat16 operator +(const vfloat16& a) { return a; } + __forceinline vfloat16 operator -(const vfloat16& a) { return _mm512_mul_ps(a,vfloat16(-1)); } + + __forceinline vfloat16 abs (const vfloat16& a) { return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a),_mm512_set1_epi32(0x7FFFFFFF))); } + __forceinline vfloat16 signmsk(const vfloat16& a) { return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a),_mm512_set1_epi32(0x80000000))); } + + __forceinline vfloat16 rcp(const vfloat16& a) { +#if defined(__AVX512ER__) + return _mm512_rcp28_ps(a); +#else + const vfloat16 r = _mm512_rcp14_ps(a); + return _mm512_mul_ps(r, _mm512_fnmadd_ps(r, a, vfloat16(2.0f))); +#endif + } + + __forceinline vfloat16 sqr (const vfloat16& a) { return _mm512_mul_ps(a,a); } + __forceinline vfloat16 sqrt(const vfloat16& a) { return _mm512_sqrt_ps(a); } + + __forceinline vfloat16 rsqrt(const vfloat16& a) + { +#if defined(__AVX512VL__) + const vfloat16 r = _mm512_rsqrt14_ps(a); + return _mm512_fmadd_ps(_mm512_set1_ps(1.5f), r, + _mm512_mul_ps(_mm512_mul_ps(_mm512_mul_ps(a, _mm512_set1_ps(-0.5f)), r), _mm512_mul_ps(r, r))); +#else + return _mm512_rsqrt28_ps(a); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 operator +(const vfloat16& a, const vfloat16& b) { return _mm512_add_ps(a, b); } + __forceinline vfloat16 operator +(const vfloat16& a, float b) { return a + vfloat16(b); } + __forceinline vfloat16 operator +(float a, const vfloat16& b) { return vfloat16(a) + b; } + + __forceinline vfloat16 operator -(const vfloat16&
a, const vfloat16& b) { return _mm512_sub_ps(a, b); } + __forceinline vfloat16 operator -(const vfloat16& a, float b) { return a - vfloat16(b); } + __forceinline vfloat16 operator -(float a, const vfloat16& b) { return vfloat16(a) - b; } + + __forceinline vfloat16 operator *(const vfloat16& a, const vfloat16& b) { return _mm512_mul_ps(a, b); } + __forceinline vfloat16 operator *(const vfloat16& a, float b) { return a * vfloat16(b); } + __forceinline vfloat16 operator *(float a, const vfloat16& b) { return vfloat16(a) * b; } + + __forceinline vfloat16 operator /(const vfloat16& a, const vfloat16& b) { return _mm512_div_ps(a,b); } + __forceinline vfloat16 operator /(const vfloat16& a, float b) { return a/vfloat16(b); } + __forceinline vfloat16 operator /(float a, const vfloat16& b) { return vfloat16(a)/b; } + + __forceinline vfloat16 operator &(const vfloat16& a, const vfloat16& b) { return _mm512_and_ps(a,b); } + __forceinline vfloat16 operator |(const vfloat16& a, const vfloat16& b) { return _mm512_or_ps(a,b); } + __forceinline vfloat16 operator ^(const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a),_mm512_castps_si512(b))); + } + + __forceinline vfloat16 min(const vfloat16& a, const vfloat16& b) { + return _mm512_min_ps(a,b); + } + __forceinline vfloat16 min(const vfloat16& a, float b) { + return _mm512_min_ps(a,vfloat16(b)); + } + __forceinline vfloat16 min(const float& a, const vfloat16& b) { + return _mm512_min_ps(vfloat16(a),b); + } + + __forceinline vfloat16 max(const vfloat16& a, const vfloat16& b) { + return _mm512_max_ps(a,b); + } + __forceinline vfloat16 max(const vfloat16& a, float b) { + return _mm512_max_ps(a,vfloat16(b)); + } + __forceinline vfloat16 max(const float& a, const vfloat16& b) { + return _mm512_max_ps(vfloat16(a),b); + } + + __forceinline vfloat16 mask_add(const vboolf16& mask, const vfloat16& c, const vfloat16& a, const vfloat16& b) { return _mm512_mask_add_ps (c,mask,a,b); } + __forceinline vfloat16 mask_min(const vboolf16& mask, const vfloat16& c, const vfloat16& a, const vfloat16& b) { + return _mm512_mask_min_ps(c,mask,a,b); + }; + __forceinline vfloat16 mask_max(const vboolf16& mask, const vfloat16& c, const vfloat16& a, const vfloat16& b) { + return _mm512_mask_max_ps(c,mask,a,b); + }; + + __forceinline vfloat16 mini(const vfloat16& a, const vfloat16& b) { +#if !defined(__AVX512ER__) // SKX + const vint16 ai = _mm512_castps_si512(a); + const vint16 bi = _mm512_castps_si512(b); + const vint16 ci = _mm512_min_epi32(ai,bi); + return _mm512_castsi512_ps(ci); +#else // KNL + return min(a,b); +#endif + } + + __forceinline vfloat16 maxi(const vfloat16& a, const vfloat16& b) { +#if !defined(__AVX512ER__) // SKX + const vint16 ai = _mm512_castps_si512(a); + const vint16 bi = _mm512_castps_si512(b); + const vint16 ci = _mm512_max_epi32(ai,bi); + return _mm512_castsi512_ps(ci); +#else // KNL + return max(a,b); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 madd (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_ps(a,b,c); } + __forceinline vfloat16 msub (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(a,b,c); } + __forceinline vfloat16 nmadd(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmadd_ps(a,b,c); } + __forceinline vfloat16 nmsub(const 
vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmsub_ps(a,b,c); } + + __forceinline vfloat16 mask_msub(const vboolf16& mask,const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_fmsub_ps(a,mask,b,c); } + + __forceinline vfloat16 madd231 (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_ps(c,b,a); } + __forceinline vfloat16 msub213 (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(a,b,c); } + __forceinline vfloat16 msub231 (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(c,b,a); } + __forceinline vfloat16 msubr231(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmadd_ps(c,b,a); } + + + //////////////////////////////////////////////////////////////////////////////// + /// Operators with rounding + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 madd_round_down(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_round_ps(a,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 madd_round_up (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_round_ps(a,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 mul_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_mul_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 mul_round_up (const vfloat16& a, const vfloat16& b) { return _mm512_mul_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 add_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_add_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 add_round_up (const vfloat16& a, const vfloat16& b) { return _mm512_add_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 sub_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_sub_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 sub_round_up (const vfloat16& a, const vfloat16& b) { return _mm512_sub_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 div_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_div_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 div_round_up (const vfloat16& a, const vfloat16& b) { return _mm512_div_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 mask_msub_round_down(const vboolf16& mask,const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_fmsub_round_ps(a,mask,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 mask_msub_round_up (const vboolf16& mask,const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_fmsub_round_ps(a,mask,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 mask_mul_round_down(const vboolf16& mask,const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_mul_round_ps(a,mask,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 mask_mul_round_up (const vboolf16& mask,const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_mul_round_ps(a,mask,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + __forceinline vfloat16 mask_sub_round_down(const vboolf16& mask,const vfloat16& a, const vfloat16& b, const 
vfloat16& c) { return _mm512_mask_sub_round_ps(a,mask,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); } + __forceinline vfloat16 mask_sub_round_up (const vboolf16& mask,const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_sub_round_ps(a,mask,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); } + + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16& operator +=(vfloat16& a, const vfloat16& b) { return a = a + b; } + __forceinline vfloat16& operator +=(vfloat16& a, float b) { return a = a + b; } + + __forceinline vfloat16& operator -=(vfloat16& a, const vfloat16& b) { return a = a - b; } + __forceinline vfloat16& operator -=(vfloat16& a, float b) { return a = a - b; } + + __forceinline vfloat16& operator *=(vfloat16& a, const vfloat16& b) { return a = a * b; } + __forceinline vfloat16& operator *=(vfloat16& a, float b) { return a = a * b; } + + __forceinline vfloat16& operator /=(vfloat16& a, const vfloat16& b) { return a = a / b; } + __forceinline vfloat16& operator /=(vfloat16& a, float b) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 operator ==(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 operator ==(const vfloat16& a, float b) { return a == vfloat16(b); } + __forceinline vboolf16 operator ==(float a, const vfloat16& b) { return vfloat16(a) == b; } + + __forceinline vboolf16 operator !=(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 operator !=(const vfloat16& a, float b) { return a != vfloat16(b); } + __forceinline vboolf16 operator !=(float a, const vfloat16& b) { return vfloat16(a) != b; } + + __forceinline vboolf16 operator < (const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 operator < (const vfloat16& a, float b) { return a < vfloat16(b); } + __forceinline vboolf16 operator < (float a, const vfloat16& b) { return vfloat16(a) < b; } + + __forceinline vboolf16 operator >=(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 operator >=(const vfloat16& a, float b) { return a >= vfloat16(b); } + __forceinline vboolf16 operator >=(float a, const vfloat16& b) { return vfloat16(a) >= b; } + + __forceinline vboolf16 operator > (const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 operator > (const vfloat16& a, float b) { return a > vfloat16(b); } + __forceinline vboolf16 operator > (float a, const vfloat16& b) { return vfloat16(a) > b; } + + __forceinline vboolf16 operator <=(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 operator <=(const vfloat16& a, float b) { return a <= vfloat16(b); } + __forceinline vboolf16 operator <=(float a, const vfloat16& b) { return vfloat16(a) <= b; } + + __forceinline vboolf16 eq(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 ne(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_NE); } + 
__forceinline vboolf16 lt(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 ge(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 gt(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 le(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LE); } + + __forceinline vboolf16 eq(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 ne(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 lt(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 ge(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 gt(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 le(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_LE); } + + __forceinline vfloat16 select(const vboolf16& s, const vfloat16& t, const vfloat16& f) { + return _mm512_mask_blend_ps(s, f, t); + } + + __forceinline vfloat16 lerp(const vfloat16& a, const vfloat16& b, const vfloat16& t) { + return madd(t,b-a,a); + } + + __forceinline void xchg(vboolf16 m, vfloat16& a, vfloat16& b) + { + vfloat16 c = a; + a = select(m,b,a); + b = select(m,c,b); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Rounding Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 floor(const vfloat16& a) { + return _mm512_floor_ps(a); + } + __forceinline vfloat16 ceil (const vfloat16& a) { + return _mm512_ceil_ps(a); + } + __forceinline vfloat16 round (const vfloat16& a) { + return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + __forceinline vint16 floori (const vfloat16& a) { + return _mm512_cvt_roundps_epi32(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 unpacklo(const vfloat16& a, const vfloat16& b) { return _mm512_unpacklo_ps(a, b); } + __forceinline vfloat16 unpackhi(const vfloat16& a, const vfloat16& b) { return _mm512_unpackhi_ps(a, b); } + + template<int i> + __forceinline vfloat16 shuffle(const vfloat16& v) { + return _mm512_permute_ps(v, _MM_SHUFFLE(i, i, i, i)); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vfloat16 shuffle(const vfloat16& v) { + return _mm512_permute_ps(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<int i> + __forceinline vfloat16 shuffle4(const vfloat16& v) { + return _mm512_shuffle_f32x4(v, v ,_MM_SHUFFLE(i, i, i, i)); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vfloat16 shuffle4(const vfloat16& v) { + return _mm512_shuffle_f32x4(v, v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + __forceinline vfloat16 interleave_even(const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_mask_shuffle_epi32(_mm512_castps_si512(a), mm512_int2mask(0xaaaa), _mm512_castps_si512(b),
(_MM_PERM_ENUM)0xb1)); + } + + __forceinline vfloat16 interleave_odd(const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_mask_shuffle_epi32(_mm512_castps_si512(b), mm512_int2mask(0x5555), _mm512_castps_si512(a), (_MM_PERM_ENUM)0xb1)); + } + + __forceinline vfloat16 interleave2_even(const vfloat16& a, const vfloat16& b) { + /* mask should be 8-bit but is 16-bit to reuse for interleave_even */ + return _mm512_castsi512_ps(_mm512_mask_permutex_epi64(_mm512_castps_si512(a), mm512_int2mask(0xaaaa), _mm512_castps_si512(b), (_MM_PERM_ENUM)0xb1)); + } + + __forceinline vfloat16 interleave2_odd(const vfloat16& a, const vfloat16& b) { + /* mask should be 8-bit but is 16-bit to reuse for interleave_odd */ + return _mm512_castsi512_ps(_mm512_mask_permutex_epi64(_mm512_castps_si512(b), mm512_int2mask(0x5555), _mm512_castps_si512(a), (_MM_PERM_ENUM)0xb1)); + } + + __forceinline vfloat16 interleave4_even(const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_mask_permutex_epi64(_mm512_castps_si512(a), mm512_int2mask(0xcc), _mm512_castps_si512(b), (_MM_PERM_ENUM)0x4e)); + } + + __forceinline vfloat16 interleave4_odd(const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_mask_permutex_epi64(_mm512_castps_si512(b), mm512_int2mask(0x33), _mm512_castps_si512(a), (_MM_PERM_ENUM)0x4e)); + } + + __forceinline vfloat16 permute(vfloat16 v, __m512i index) { + return _mm512_castsi512_ps(_mm512_permutexvar_epi32(index, _mm512_castps_si512(v))); + } + + __forceinline vfloat16 reverse(const vfloat16& v) { + return permute(v,_mm512_setr_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)); + } + + template<int i> + __forceinline vfloat16 align_shift_right(const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_alignr_epi32(_mm512_castps_si512(a),_mm512_castps_si512(b),i)); + }; + + template<int i> + __forceinline vfloat16 mask_align_shift_right(const vboolf16& mask, vfloat16& c, const vfloat16& a, const vfloat16& b) { + return _mm512_castsi512_ps(_mm512_mask_alignr_epi32(_mm512_castps_si512(c),mask,_mm512_castps_si512(a),_mm512_castps_si512(b),i)); + }; + + __forceinline vfloat16 shift_left_1(const vfloat16& a) { + vfloat16 z = zero; + return mask_align_shift_right<15>(0xfffe,z,a,a); + } + + __forceinline vfloat16 shift_right_1(const vfloat16& x) { + return align_shift_right<1>(zero,x); + } + + __forceinline float toScalar(const vfloat16& v) { return mm512_cvtss_f32(v); } + + + template<int i> __forceinline vfloat16 insert4(const vfloat16& a, const vfloat4& b) { return _mm512_insertf32x4(a, b, i); } + + template<int N, int i> + vfloat<N> extractN(const vfloat16& v); + + template<> __forceinline vfloat4 extractN<4,0>(const vfloat16& v) { return _mm512_castps512_ps128(v); } + template<> __forceinline vfloat4 extractN<4,1>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 1); } + template<> __forceinline vfloat4 extractN<4,2>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 2); } + template<> __forceinline vfloat4 extractN<4,3>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 3); } + + template<> __forceinline vfloat8 extractN<8,0>(const vfloat16& v) { return _mm512_castps512_ps256(v); } + template<> __forceinline vfloat8 extractN<8,1>(const vfloat16& v) { return _mm512_extractf32x8_ps(v, 1); } + + template<int i> __forceinline vfloat4 extract4 (const vfloat16& v) { return _mm512_extractf32x4_ps(v, i); } + template<> __forceinline vfloat4 extract4<0>(const vfloat16& v) { return _mm512_castps512_ps128(v); } + + template<int i> __forceinline vfloat8 extract8 (const vfloat16& v) {
return _mm512_extractf32x8_ps(v, i); } + template<> __forceinline vfloat8 extract8<0>(const vfloat16& v) { return _mm512_castps512_ps256(v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Transpose + //////////////////////////////////////////////////////////////////////////////// + + __forceinline void transpose(const vfloat16& r0, const vfloat16& r1, const vfloat16& r2, const vfloat16& r3, + vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3) + { +#if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL + vfloat16 a0a1_c0c1 = interleave_even(r0, r1); + vfloat16 a2a3_c2c3 = interleave_even(r2, r3); + vfloat16 b0b1_d0d1 = interleave_odd (r0, r1); + vfloat16 b2b3_d2d3 = interleave_odd (r2, r3); + + c0 = interleave2_even(a0a1_c0c1, a2a3_c2c3); + c1 = interleave2_even(b0b1_d0d1, b2b3_d2d3); + c2 = interleave2_odd (a0a1_c0c1, a2a3_c2c3); + c3 = interleave2_odd (b0b1_d0d1, b2b3_d2d3); +#else + vfloat16 a0a2_b0b2 = unpacklo(r0, r2); + vfloat16 c0c2_d0d2 = unpackhi(r0, r2); + vfloat16 a1a3_b1b3 = unpacklo(r1, r3); + vfloat16 c1c3_d1d3 = unpackhi(r1, r3); + + c0 = unpacklo(a0a2_b0b2, a1a3_b1b3); + c1 = unpackhi(a0a2_b0b2, a1a3_b1b3); + c2 = unpacklo(c0c2_d0d2, c1c3_d1d3); + c3 = unpackhi(c0c2_d0d2, c1c3_d1d3); +#endif + } + + __forceinline void transpose(const vfloat4& r0, const vfloat4& r1, const vfloat4& r2, const vfloat4& r3, + const vfloat4& r4, const vfloat4& r5, const vfloat4& r6, const vfloat4& r7, + const vfloat4& r8, const vfloat4& r9, const vfloat4& r10, const vfloat4& r11, + const vfloat4& r12, const vfloat4& r13, const vfloat4& r14, const vfloat4& r15, + vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3) + { + return transpose(vfloat16(r0, r4, r8, r12), vfloat16(r1, r5, r9, r13), vfloat16(r2, r6, r10, r14), vfloat16(r3, r7, r11, r15), + c0, c1, c2, c3); + } + + __forceinline void transpose(const vfloat16& r0, const vfloat16& r1, const vfloat16& r2, const vfloat16& r3, + const vfloat16& r4, const vfloat16& r5, const vfloat16& r6, const vfloat16& r7, + vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3, + vfloat16& c4, vfloat16& c5, vfloat16& c6, vfloat16& c7) + { + vfloat16 a0a1a2a3_e0e1e2e3, b0b1b2b3_f0f1f2f3, c0c1c2c3_g0g1g2g3, d0d1d2d3_h0h1h2h3; + transpose(r0, r1, r2, r3, a0a1a2a3_e0e1e2e3, b0b1b2b3_f0f1f2f3, c0c1c2c3_g0g1g2g3, d0d1d2d3_h0h1h2h3); + + vfloat16 a4a5a6a7_e4e5e6e7, b4b5b6b7_f4f5f6f7, c4c5c6c7_g4g5g6g7, d4d5d6d7_h4h5h6h7; + transpose(r4, r5, r6, r7, a4a5a6a7_e4e5e6e7, b4b5b6b7_f4f5f6f7, c4c5c6c7_g4g5g6g7, d4d5d6d7_h4h5h6h7); + + c0 = interleave4_even(a0a1a2a3_e0e1e2e3, a4a5a6a7_e4e5e6e7); + c1 = interleave4_even(b0b1b2b3_f0f1f2f3, b4b5b6b7_f4f5f6f7); + c2 = interleave4_even(c0c1c2c3_g0g1g2g3, c4c5c6c7_g4g5g6g7); + c3 = interleave4_even(d0d1d2d3_h0h1h2h3, d4d5d6d7_h4h5h6h7); + c4 = interleave4_odd (a0a1a2a3_e0e1e2e3, a4a5a6a7_e4e5e6e7); + c5 = interleave4_odd (b0b1b2b3_f0f1f2f3, b4b5b6b7_f4f5f6f7); + c6 = interleave4_odd (c0c1c2c3_g0g1g2g3, c4c5c6c7_g4g5g6g7); + c7 = interleave4_odd (d0d1d2d3_h0h1h2h3, d4d5d6d7_h4h5h6h7); + } + + __forceinline void transpose(const vfloat8& r0, const vfloat8& r1, const vfloat8& r2, const vfloat8& r3, + const vfloat8& r4, const vfloat8& r5, const vfloat8& r6, const vfloat8& r7, + const vfloat8& r8, const vfloat8& r9, const vfloat8& r10, const vfloat8& r11, + const vfloat8& r12, const vfloat8& r13, const vfloat8& r14, const vfloat8& r15, + vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3, + vfloat16& c4, vfloat16& c5, vfloat16& c6, vfloat16& c7) + { + return 
transpose(vfloat16(r0, r8), vfloat16(r1, r9), vfloat16(r2, r10), vfloat16(r3, r11), + vfloat16(r4, r12), vfloat16(r5, r13), vfloat16(r6, r14), vfloat16(r7, r15), + c0, c1, c2, c3, c4, c5, c6, c7); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 vreduce_add2(vfloat16 x) { return x + shuffle<1,0,3,2>(x); } + __forceinline vfloat16 vreduce_add4(vfloat16 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); } + __forceinline vfloat16 vreduce_add8(vfloat16 x) { x = vreduce_add4(x); return x + shuffle4<1,0,3,2>(x); } + __forceinline vfloat16 vreduce_add (vfloat16 x) { x = vreduce_add8(x); return x + shuffle4<2,3,0,1>(x); } + + __forceinline vfloat16 vreduce_min2(vfloat16 x) { return min(x, shuffle<1,0,3,2>(x)); } + __forceinline vfloat16 vreduce_min4(vfloat16 x) { x = vreduce_min2(x); return min(x, shuffle<2,3,0,1>(x)); } + __forceinline vfloat16 vreduce_min8(vfloat16 x) { x = vreduce_min4(x); return min(x, shuffle4<1,0,3,2>(x)); } + __forceinline vfloat16 vreduce_min (vfloat16 x) { x = vreduce_min8(x); return min(x, shuffle4<2,3,0,1>(x)); } + + __forceinline vfloat16 vreduce_max2(vfloat16 x) { return max(x, shuffle<1,0,3,2>(x)); } + __forceinline vfloat16 vreduce_max4(vfloat16 x) { x = vreduce_max2(x); return max(x, shuffle<2,3,0,1>(x)); } + __forceinline vfloat16 vreduce_max8(vfloat16 x) { x = vreduce_max4(x); return max(x, shuffle4<1,0,3,2>(x)); } + __forceinline vfloat16 vreduce_max (vfloat16 x) { x = vreduce_max8(x); return max(x, shuffle4<2,3,0,1>(x)); } + + __forceinline float reduce_add(const vfloat16& v) { return toScalar(vreduce_add(v)); } + __forceinline float reduce_min(const vfloat16& v) { return toScalar(vreduce_min(v)); } + __forceinline float reduce_max(const vfloat16& v) { return toScalar(vreduce_max(v)); } + + __forceinline size_t select_min(const vfloat16& v) { + return bsf(_mm512_kmov(_mm512_cmp_epi32_mask(_mm512_castps_si512(v),_mm512_castps_si512(vreduce_min(v)),_MM_CMPINT_EQ))); + } + + __forceinline size_t select_max(const vfloat16& v) { + return bsf(_mm512_kmov(_mm512_cmp_epi32_mask(_mm512_castps_si512(v),_mm512_castps_si512(vreduce_max(v)),_MM_CMPINT_EQ))); + } + + __forceinline size_t select_min(const vboolf16& valid, const vfloat16& v) + { + const vfloat16 a = select(valid,v,vfloat16(pos_inf)); + const vbool16 valid_min = valid & (a == vreduce_min(a)); + return bsf(movemask(any(valid_min) ? valid_min : valid)); + } + + __forceinline size_t select_max(const vboolf16& valid, const vfloat16& v) + { + const vfloat16 a = select(valid,v,vfloat16(neg_inf)); + const vbool16 valid_max = valid & (a == vreduce_max(a)); + return bsf(movemask(any(valid_max) ? 
valid_max : valid)); + } + + __forceinline vfloat16 prefix_sum(const vfloat16& a) + { + const vfloat16 z(zero); + vfloat16 v = a; + v = v + align_shift_right<16-1>(v,z); + v = v + align_shift_right<16-2>(v,z); + v = v + align_shift_right<16-4>(v,z); + v = v + align_shift_right<16-8>(v,z); + return v; + } + + __forceinline vfloat16 reverse_prefix_sum(const vfloat16& a) + { + const vfloat16 z(zero); + vfloat16 v = a; + v = v + align_shift_right<1>(z,v); + v = v + align_shift_right<2>(z,v); + v = v + align_shift_right<4>(z,v); + v = v + align_shift_right<8>(z,v); + return v; + } + + __forceinline vfloat16 prefix_min(const vfloat16& a) + { + const vfloat16 z(pos_inf); + vfloat16 v = a; + v = min(v,align_shift_right<16-1>(v,z)); + v = min(v,align_shift_right<16-2>(v,z)); + v = min(v,align_shift_right<16-4>(v,z)); + v = min(v,align_shift_right<16-8>(v,z)); + return v; + } + + __forceinline vfloat16 prefix_max(const vfloat16& a) + { + const vfloat16 z(neg_inf); + vfloat16 v = a; + v = max(v,align_shift_right<16-1>(v,z)); + v = max(v,align_shift_right<16-2>(v,z)); + v = max(v,align_shift_right<16-4>(v,z)); + v = max(v,align_shift_right<16-8>(v,z)); + return v; + } + + + __forceinline vfloat16 reverse_prefix_min(const vfloat16& a) + { + const vfloat16 z(pos_inf); + vfloat16 v = a; + v = min(v,align_shift_right<1>(z,v)); + v = min(v,align_shift_right<2>(z,v)); + v = min(v,align_shift_right<4>(z,v)); + v = min(v,align_shift_right<8>(z,v)); + return v; + } + + __forceinline vfloat16 reverse_prefix_max(const vfloat16& a) + { + const vfloat16 z(neg_inf); + vfloat16 v = a; + v = max(v,align_shift_right<1>(z,v)); + v = max(v,align_shift_right<2>(z,v)); + v = max(v,align_shift_right<4>(z,v)); + v = max(v,align_shift_right<8>(z,v)); + return v; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Memory load and store operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat16 loadAOS4to16f(const float& x, const float& y, const float& z) + { + vfloat16 f = zero; + f = select(0x1111,vfloat16::broadcast(&x),f); + f = select(0x2222,vfloat16::broadcast(&y),f); + f = select(0x4444,vfloat16::broadcast(&z),f); + return f; + } + + __forceinline vfloat16 loadAOS4to16f(unsigned int index, + const vfloat16& x, + const vfloat16& y, + const vfloat16& z) + { + vfloat16 f = zero; + f = select(0x1111,vfloat16::broadcast((float*)&x + index),f); + f = select(0x2222,vfloat16::broadcast((float*)&y + index),f); + f = select(0x4444,vfloat16::broadcast((float*)&z + index),f); + return f; + } + + __forceinline vfloat16 loadAOS4to16f(unsigned int index, + const vfloat16& x, + const vfloat16& y, + const vfloat16& z, + const vfloat16& fill) + { + vfloat16 f = fill; + f = select(0x1111,vfloat16::broadcast((float*)&x + index),f); + f = select(0x2222,vfloat16::broadcast((float*)&y + index),f); + f = select(0x4444,vfloat16::broadcast((float*)&z + index),f); + return f; + } + + __forceinline vfloat16 rcp_safe(const vfloat16& a) { + return rcp(select(a != vfloat16(zero), a, vfloat16(min_rcp_input))); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vfloat16& v) + { + cout << "<" << v[0]; + for (int i=1; i<16; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git 
a/thirdparty/embree/common/simd/vfloat4_sse2.h b/thirdparty/embree/common/simd/vfloat4_sse2.h new file mode 100644 index 000000000000..96f984cebd5e --- /dev/null +++ b/thirdparty/embree/common/simd/vfloat4_sse2.h @@ -0,0 +1,708 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide SSE float type */ + template<> + struct vfloat<4> + { + ALIGNED_STRUCT_(16); + + typedef vboolf4 Bool; + typedef vint4 Int; + typedef vfloat4 Float; + + enum { size = 4 }; // number of SIMD elements + union { __m128 v; float f[4]; int i[4]; }; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat() {} + __forceinline vfloat(const vfloat4& other) { v = other.v; } + __forceinline vfloat4& operator =(const vfloat4& other) { v = other.v; return *this; } + + __forceinline vfloat(__m128 a) : v(a) {} + __forceinline operator const __m128&() const { return v; } + __forceinline operator __m128&() { return v; } + + __forceinline vfloat(float a) : v(_mm_set1_ps(a)) {} + __forceinline vfloat(float a, float b, float c, float d) : v(_mm_set_ps(d, c, b, a)) {} + + __forceinline explicit vfloat(const vint4& a) : v(_mm_cvtepi32_ps(a)) {} + __forceinline explicit vfloat(const vuint4& x) { + const __m128i a = _mm_and_si128(x,_mm_set1_epi32(0x7FFFFFFF)); + const __m128i b = _mm_and_si128(_mm_srai_epi32(x,31),_mm_set1_epi32(0x4F000000)); //0x4F000000 = 2^31 + const __m128 af = _mm_cvtepi32_ps(a); + const __m128 bf = _mm_castsi128_ps(b); + v = _mm_add_ps(af,bf); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat(ZeroTy) : v(_mm_setzero_ps()) {} + __forceinline vfloat(OneTy) : v(_mm_set1_ps(1.0f)) {} + __forceinline vfloat(PosInfTy) : v(_mm_set1_ps(pos_inf)) {} + __forceinline vfloat(NegInfTy) : v(_mm_set1_ps(neg_inf)) {} + __forceinline vfloat(StepTy) : v(_mm_set_ps(3.0f, 2.0f, 1.0f, 0.0f)) {} + __forceinline vfloat(NaNTy) : v(_mm_set1_ps(nan)) {} + __forceinline vfloat(UndefinedTy) : v(_mm_undefined_ps()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vfloat4 load (const void* a) { return _mm_load_ps((float*)a); } + static __forceinline vfloat4 loadu(const void* a) { return _mm_loadu_ps((float*)a); } + + static __forceinline void store (void* ptr, const vfloat4& v) { _mm_store_ps((float*)ptr,v); } + static __forceinline void storeu(void* ptr, const vfloat4& v) { _mm_storeu_ps((float*)ptr,v); } + +#if defined(__AVX512VL__) + + static __forceinline vfloat4 compact(const vboolf4& mask, vfloat4 &v) { + return _mm_mask_compress_ps(v, mask, v); + } + static __forceinline vfloat4 compact(const vboolf4& mask, vfloat4 &a, const vfloat4& b) { + return _mm_mask_compress_ps(a, mask, b); + } + + static __forceinline vfloat4 load (const vboolf4& mask, const void* ptr) { return _mm_mask_load_ps (_mm_setzero_ps(),mask,(float*)ptr); } + static __forceinline vfloat4 loadu(const vboolf4& mask, const void* ptr) { return _mm_mask_loadu_ps(_mm_setzero_ps(),mask,(float*)ptr); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const 
vfloat4& v) { _mm_mask_store_ps ((float*)ptr,mask,v); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vfloat4& v) { _mm_mask_storeu_ps((float*)ptr,mask,v); } +#elif defined(__AVX__) + static __forceinline vfloat4 load (const vboolf4& mask, const void* ptr) { return _mm_maskload_ps((float*)ptr,mask); } + static __forceinline vfloat4 loadu(const vboolf4& mask, const void* ptr) { return _mm_maskload_ps((float*)ptr,mask); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vfloat4& v) { _mm_maskstore_ps((float*)ptr,(__m128i)mask,v); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vfloat4& v) { _mm_maskstore_ps((float*)ptr,(__m128i)mask,v); } +#else + static __forceinline vfloat4 load (const vboolf4& mask, const void* ptr) { return _mm_and_ps(_mm_load_ps ((float*)ptr),mask); } + static __forceinline vfloat4 loadu(const vboolf4& mask, const void* ptr) { return _mm_and_ps(_mm_loadu_ps((float*)ptr),mask); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vfloat4& v) { store (ptr,select(mask,v,load (ptr))); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vfloat4& v) { storeu(ptr,select(mask,v,loadu(ptr))); } +#endif + +#if defined(__AVX__) + static __forceinline vfloat4 broadcast(const void* a) { return _mm_broadcast_ss((float*)a); } +#else + static __forceinline vfloat4 broadcast(const void* a) { return _mm_set1_ps(*(float*)a); } +#endif + + static __forceinline vfloat4 load_nt (const float* ptr) { +#if defined (__SSE4_1__) + return _mm_castsi128_ps(_mm_stream_load_si128((__m128i*)ptr)); +#else + return _mm_load_ps(ptr); +#endif + } + +#if defined(__SSE4_1__) + static __forceinline vfloat4 load(const char* ptr) { + return _mm_cvtepi32_ps(_mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)ptr))); + } +#else + static __forceinline vfloat4 load(const char* ptr) { + return vfloat4(ptr[0],ptr[1],ptr[2],ptr[3]); + } +#endif + +#if defined(__SSE4_1__) + static __forceinline vfloat4 load(const unsigned char* ptr) { + return _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_loadu_si128((__m128i*)ptr))); + } +#else + static __forceinline vfloat4 load(const unsigned char* ptr) { + //return _mm_cvtpu8_ps(*(__m64*)ptr); // don't enable, will use MMX instructions + return vfloat4(ptr[0],ptr[1],ptr[2],ptr[3]); + } +#endif + +#if defined(__SSE4_1__) + static __forceinline vfloat4 load(const short* ptr) { + return _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_loadu_si128((__m128i*)ptr))); + } +#else + static __forceinline vfloat4 load(const short* ptr) { + return vfloat4(ptr[0],ptr[1],ptr[2],ptr[3]); + } +#endif + + static __forceinline vfloat4 load(const unsigned short* ptr) { + return _mm_mul_ps(vfloat4(vint4::load(ptr)),vfloat4(1.0f/65535.0f)); + } + + static __forceinline void store_nt(void* ptr, const vfloat4& v) + { +#if defined (__SSE4_1__) + _mm_stream_ps((float*)ptr,v); +#else + _mm_store_ps((float*)ptr,v); +#endif + } + + template<int scale = 4> + static __forceinline vfloat4 gather(const float* ptr, const vint4& index) { +#if defined(__AVX2__) + return _mm_i32gather_ps(ptr, index, scale); +#else + return vfloat4( + *(float*)(((char*)ptr)+scale*index[0]), + *(float*)(((char*)ptr)+scale*index[1]), + *(float*)(((char*)ptr)+scale*index[2]), + *(float*)(((char*)ptr)+scale*index[3])); +#endif + } + + template<int scale = 4> + static __forceinline vfloat4 gather(const vboolf4& mask, const float* ptr, const vint4& index) { + vfloat4 r = zero; +#if defined(__AVX512VL__) + return _mm_mmask_i32gather_ps(r, mask, index, ptr, scale); +#elif
defined(__AVX2__) + return _mm_mask_i32gather_ps(r, ptr, index, mask, scale); +#else + if (likely(mask[0])) r[0] = *(float*)(((char*)ptr)+scale*index[0]); + if (likely(mask[1])) r[1] = *(float*)(((char*)ptr)+scale*index[1]); + if (likely(mask[2])) r[2] = *(float*)(((char*)ptr)+scale*index[2]); + if (likely(mask[3])) r[3] = *(float*)(((char*)ptr)+scale*index[3]); + return r; +#endif + } + + template<int scale = 4> + static __forceinline void scatter(void* ptr, const vint4& index, const vfloat4& v) + { +#if defined(__AVX512VL__) + _mm_i32scatter_ps((float*)ptr, index, v, scale); +#else + *(float*)(((char*)ptr)+scale*index[0]) = v[0]; + *(float*)(((char*)ptr)+scale*index[1]) = v[1]; + *(float*)(((char*)ptr)+scale*index[2]) = v[2]; + *(float*)(((char*)ptr)+scale*index[3]) = v[3]; +#endif + } + + template<int scale = 4> + static __forceinline void scatter(const vboolf4& mask, void* ptr, const vint4& index, const vfloat4& v) + { +#if defined(__AVX512VL__) + _mm_mask_i32scatter_ps((float*)ptr ,mask, index, v, scale); +#else + if (likely(mask[0])) *(float*)(((char*)ptr)+scale*index[0]) = v[0]; + if (likely(mask[1])) *(float*)(((char*)ptr)+scale*index[1]) = v[1]; + if (likely(mask[2])) *(float*)(((char*)ptr)+scale*index[2]) = v[2]; + if (likely(mask[3])) *(float*)(((char*)ptr)+scale*index[3]) = v[3]; +#endif + } + + static __forceinline void store(const vboolf4& mask, char* ptr, const vint4& ofs, const vfloat4& v) { + scatter<1>(mask,ptr,ofs,v); + } + static __forceinline void store(const vboolf4& mask, float* ptr, const vint4& ofs, const vfloat4& v) { + scatter<4>(mask,ptr,ofs,v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const float& operator [](size_t index) const { assert(index < 4); return f[index]; } + __forceinline float& operator [](size_t index) { assert(index < 4); return f[index]; } + + friend __forceinline vfloat4 select(const vboolf4& m, const vfloat4& t, const vfloat4& f) { +#if defined(__AVX512VL__) + return _mm_mask_blend_ps(m, f, t); +#elif defined(__SSE4_1__) + return _mm_blendv_ps(f, t, m); +#else + return _mm_or_ps(_mm_and_ps(m, t), _mm_andnot_ps(m, f)); +#endif + } + }; + + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat4 asFloat(const vint4& a) { return _mm_castsi128_ps(a); } + __forceinline vint4 asInt (const vfloat4& a) { return _mm_castps_si128(a); } + __forceinline vuint4 asUInt (const vfloat4& a) { return _mm_castps_si128(a); } + + __forceinline vint4 toInt (const vfloat4& a) { return vint4(a); } + __forceinline vfloat4 toFloat(const vint4& a) { return vfloat4(a); } + + __forceinline vfloat4 operator +(const vfloat4& a) { return a; } + __forceinline vfloat4 operator -(const vfloat4& a) { return _mm_xor_ps(a, _mm_castsi128_ps(_mm_set1_epi32(0x80000000))); } + + __forceinline vfloat4 abs(const vfloat4& a) { return _mm_and_ps(a, _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff))); } +#if defined(__AVX512VL__) + __forceinline vfloat4 sign(const vfloat4& a) { return _mm_mask_blend_ps(_mm_cmp_ps_mask(a, vfloat4(zero), _CMP_LT_OQ), vfloat4(one), -vfloat4(one)); } +#else + __forceinline vfloat4 sign(const vfloat4& a) { return blendv_ps(vfloat4(one), -vfloat4(one), _mm_cmplt_ps(a, vfloat4(zero))); } +#endif + __forceinline vfloat4 signmsk(const vfloat4& a) { return
_mm_and_ps(a,_mm_castsi128_ps(_mm_set1_epi32(0x80000000))); } + + __forceinline vfloat4 rcp(const vfloat4& a) + { +#if defined(__AVX512VL__) + const vfloat4 r = _mm_rcp14_ps(a); +#else + const vfloat4 r = _mm_rcp_ps(a); +#endif + +#if defined(__AVX2__) + return _mm_mul_ps(r,_mm_fnmadd_ps(r, a, vfloat4(2.0f))); +#else + return _mm_mul_ps(r,_mm_sub_ps(vfloat4(2.0f), _mm_mul_ps(r, a))); +#endif + } + __forceinline vfloat4 sqr (const vfloat4& a) { return _mm_mul_ps(a,a); } + __forceinline vfloat4 sqrt(const vfloat4& a) { return _mm_sqrt_ps(a); } + + __forceinline vfloat4 rsqrt(const vfloat4& a) + { +#if defined(__AVX512VL__) + const vfloat4 r = _mm_rsqrt14_ps(a); +#else + const vfloat4 r = _mm_rsqrt_ps(a); +#endif + +#if defined(__AVX2__) + return _mm_fmadd_ps(_mm_set1_ps(1.5f), r, + _mm_mul_ps(_mm_mul_ps(_mm_mul_ps(a, _mm_set1_ps(-0.5f)), r), _mm_mul_ps(r, r))); +#else + return _mm_add_ps(_mm_mul_ps(_mm_set1_ps(1.5f), r), + _mm_mul_ps(_mm_mul_ps(_mm_mul_ps(a, _mm_set1_ps(-0.5f)), r), _mm_mul_ps(r, r))); +#endif + } + + __forceinline vboolf4 isnan(const vfloat4& a) { + const vfloat4 b = _mm_and_ps(a, _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff))); +#if defined(__AVX512VL__) + return _mm_cmp_epi32_mask(_mm_castps_si128(b), _mm_set1_epi32(0x7f800000), _MM_CMPINT_GT); +#else + return _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_castps_si128(b), _mm_set1_epi32(0x7f800000))); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat4 operator +(const vfloat4& a, const vfloat4& b) { return _mm_add_ps(a, b); } + __forceinline vfloat4 operator +(const vfloat4& a, float b) { return a + vfloat4(b); } + __forceinline vfloat4 operator +(float a, const vfloat4& b) { return vfloat4(a) + b; } + + __forceinline vfloat4 operator -(const vfloat4& a, const vfloat4& b) { return _mm_sub_ps(a, b); } + __forceinline vfloat4 operator -(const vfloat4& a, float b) { return a - vfloat4(b); } + __forceinline vfloat4 operator -(float a, const vfloat4& b) { return vfloat4(a) - b; } + + __forceinline vfloat4 operator *(const vfloat4& a, const vfloat4& b) { return _mm_mul_ps(a, b); } + __forceinline vfloat4 operator *(const vfloat4& a, float b) { return a * vfloat4(b); } + __forceinline vfloat4 operator *(float a, const vfloat4& b) { return vfloat4(a) * b; } + + __forceinline vfloat4 operator /(const vfloat4& a, const vfloat4& b) { return _mm_div_ps(a,b); } + __forceinline vfloat4 operator /(const vfloat4& a, float b) { return a/vfloat4(b); } + __forceinline vfloat4 operator /(float a, const vfloat4& b) { return vfloat4(a)/b; } + + __forceinline vfloat4 operator &(const vfloat4& a, const vfloat4& b) { return _mm_and_ps(a,b); } + __forceinline vfloat4 operator |(const vfloat4& a, const vfloat4& b) { return _mm_or_ps(a,b); } + __forceinline vfloat4 operator ^(const vfloat4& a, const vfloat4& b) { return _mm_xor_ps(a,b); } + __forceinline vfloat4 operator ^(const vfloat4& a, const vint4& b) { return _mm_xor_ps(a,_mm_castsi128_ps(b)); } + + __forceinline vfloat4 min(const vfloat4& a, const vfloat4& b) { return _mm_min_ps(a,b); } + __forceinline vfloat4 min(const vfloat4& a, float b) { return _mm_min_ps(a,vfloat4(b)); } + __forceinline vfloat4 min(float a, const vfloat4& b) { return _mm_min_ps(vfloat4(a),b); } + + __forceinline vfloat4 max(const vfloat4& a, const vfloat4& b) { return _mm_max_ps(a,b); } + __forceinline vfloat4 max(const vfloat4& a, float b) { return 
_mm_max_ps(a,vfloat4(b)); } + __forceinline vfloat4 max(float a, const vfloat4& b) { return _mm_max_ps(vfloat4(a),b); } + +#if defined(__SSE4_1__) + __forceinline vfloat4 mini(const vfloat4& a, const vfloat4& b) { + const vint4 ai = _mm_castps_si128(a); + const vint4 bi = _mm_castps_si128(b); + const vint4 ci = _mm_min_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } + + __forceinline vfloat4 maxi(const vfloat4& a, const vfloat4& b) { + const vint4 ai = _mm_castps_si128(a); + const vint4 bi = _mm_castps_si128(b); + const vint4 ci = _mm_max_epi32(ai,bi); + return _mm_castsi128_ps(ci); + } + + __forceinline vfloat4 minui(const vfloat4& a, const vfloat4& b) { + const vint4 ai = _mm_castps_si128(a); + const vint4 bi = _mm_castps_si128(b); + const vint4 ci = _mm_min_epu32(ai,bi); + return _mm_castsi128_ps(ci); + } + + __forceinline vfloat4 maxui(const vfloat4& a, const vfloat4& b) { + const vint4 ai = _mm_castps_si128(a); + const vint4 bi = _mm_castps_si128(b); + const vint4 ci = _mm_max_epu32(ai,bi); + return _mm_castsi128_ps(ci); + } +#else + __forceinline vfloat4 mini(const vfloat4& a, const vfloat4& b) { + return min(a,b); + } + + __forceinline vfloat4 maxi(const vfloat4& a, const vfloat4& b) { + return max(a,b); + } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX2__) + __forceinline vfloat4 madd (const vfloat4& a, const vfloat4& b, const vfloat4& c) { return _mm_fmadd_ps(a,b,c); } + __forceinline vfloat4 msub (const vfloat4& a, const vfloat4& b, const vfloat4& c) { return _mm_fmsub_ps(a,b,c); } + __forceinline vfloat4 nmadd(const vfloat4& a, const vfloat4& b, const vfloat4& c) { return _mm_fnmadd_ps(a,b,c); } + __forceinline vfloat4 nmsub(const vfloat4& a, const vfloat4& b, const vfloat4& c) { return _mm_fnmsub_ps(a,b,c); } +#else + __forceinline vfloat4 madd (const vfloat4& a, const vfloat4& b, const vfloat4& c) { return a*b+c; } + __forceinline vfloat4 msub (const vfloat4& a, const vfloat4& b, const vfloat4& c) { return a*b-c; } + __forceinline vfloat4 nmadd(const vfloat4& a, const vfloat4& b, const vfloat4& c) { return -a*b+c;} + __forceinline vfloat4 nmsub(const vfloat4& a, const vfloat4& b, const vfloat4& c) { return -a*b-c; } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat4& operator +=(vfloat4& a, const vfloat4& b) { return a = a + b; } + __forceinline vfloat4& operator +=(vfloat4& a, float b) { return a = a + b; } + + __forceinline vfloat4& operator -=(vfloat4& a, const vfloat4& b) { return a = a - b; } + __forceinline vfloat4& operator -=(vfloat4& a, float b) { return a = a - b; } + + __forceinline vfloat4& operator *=(vfloat4& a, const vfloat4& b) { return a = a * b; } + __forceinline vfloat4& operator *=(vfloat4& a, float b) { return a = a * b; } + + __forceinline vfloat4& operator /=(vfloat4& a, const vfloat4& b) { return a = a / b; } + __forceinline vfloat4& operator /=(vfloat4& a, float b) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf4 operator ==(const vfloat4& a, const vfloat4& b) { return 
_mm_cmp_ps_mask(a, b, _MM_CMPINT_EQ); } + __forceinline vboolf4 operator !=(const vfloat4& a, const vfloat4& b) { return _mm_cmp_ps_mask(a, b, _MM_CMPINT_NE); } + __forceinline vboolf4 operator < (const vfloat4& a, const vfloat4& b) { return _mm_cmp_ps_mask(a, b, _MM_CMPINT_LT); } + __forceinline vboolf4 operator >=(const vfloat4& a, const vfloat4& b) { return _mm_cmp_ps_mask(a, b, _MM_CMPINT_GE); } + __forceinline vboolf4 operator > (const vfloat4& a, const vfloat4& b) { return _mm_cmp_ps_mask(a, b, _MM_CMPINT_GT); } + __forceinline vboolf4 operator <=(const vfloat4& a, const vfloat4& b) { return _mm_cmp_ps_mask(a, b, _MM_CMPINT_LE); } +#else + __forceinline vboolf4 operator ==(const vfloat4& a, const vfloat4& b) { return _mm_cmpeq_ps (a, b); } + __forceinline vboolf4 operator !=(const vfloat4& a, const vfloat4& b) { return _mm_cmpneq_ps(a, b); } + __forceinline vboolf4 operator < (const vfloat4& a, const vfloat4& b) { return _mm_cmplt_ps (a, b); } + __forceinline vboolf4 operator >=(const vfloat4& a, const vfloat4& b) { return _mm_cmpnlt_ps(a, b); } + __forceinline vboolf4 operator > (const vfloat4& a, const vfloat4& b) { return _mm_cmpnle_ps(a, b); } + __forceinline vboolf4 operator <=(const vfloat4& a, const vfloat4& b) { return _mm_cmple_ps (a, b); } +#endif + + __forceinline vboolf4 operator ==(const vfloat4& a, float b) { return a == vfloat4(b); } + __forceinline vboolf4 operator ==(float a, const vfloat4& b) { return vfloat4(a) == b; } + + __forceinline vboolf4 operator !=(const vfloat4& a, float b) { return a != vfloat4(b); } + __forceinline vboolf4 operator !=(float a, const vfloat4& b) { return vfloat4(a) != b; } + + __forceinline vboolf4 operator < (const vfloat4& a, float b) { return a < vfloat4(b); } + __forceinline vboolf4 operator < (float a, const vfloat4& b) { return vfloat4(a) < b; } + + __forceinline vboolf4 operator >=(const vfloat4& a, float b) { return a >= vfloat4(b); } + __forceinline vboolf4 operator >=(float a, const vfloat4& b) { return vfloat4(a) >= b; } + + __forceinline vboolf4 operator > (const vfloat4& a, float b) { return a > vfloat4(b); } + __forceinline vboolf4 operator > (float a, const vfloat4& b) { return vfloat4(a) > b; } + + __forceinline vboolf4 operator <=(const vfloat4& a, float b) { return a <= vfloat4(b); } + __forceinline vboolf4 operator <=(float a, const vfloat4& b) { return vfloat4(a) <= b; } + + __forceinline vboolf4 eq(const vfloat4& a, const vfloat4& b) { return a == b; } + __forceinline vboolf4 ne(const vfloat4& a, const vfloat4& b) { return a != b; } + __forceinline vboolf4 lt(const vfloat4& a, const vfloat4& b) { return a < b; } + __forceinline vboolf4 ge(const vfloat4& a, const vfloat4& b) { return a >= b; } + __forceinline vboolf4 gt(const vfloat4& a, const vfloat4& b) { return a > b; } + __forceinline vboolf4 le(const vfloat4& a, const vfloat4& b) { return a <= b; } + +#if defined(__AVX512VL__) + __forceinline vboolf4 eq(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return _mm_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_EQ); } + __forceinline vboolf4 ne(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return _mm_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_NE); } + __forceinline vboolf4 lt(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return _mm_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_LT); } + __forceinline vboolf4 ge(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return _mm_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_GE); } + __forceinline vboolf4 gt(const vboolf4& mask, const vfloat4& a, const 
vfloat4& b) { return _mm_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_GT); } + __forceinline vboolf4 le(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return _mm_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_LE); } +#else + __forceinline vboolf4 eq(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return mask & (a == b); } + __forceinline vboolf4 ne(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return mask & (a != b); } + __forceinline vboolf4 lt(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return mask & (a < b); } + __forceinline vboolf4 ge(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return mask & (a >= b); } + __forceinline vboolf4 gt(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return mask & (a > b); } + __forceinline vboolf4 le(const vboolf4& mask, const vfloat4& a, const vfloat4& b) { return mask & (a <= b); } +#endif + + template + __forceinline vfloat4 select(const vfloat4& t, const vfloat4& f) + { +#if defined(__SSE4_1__) + return _mm_blend_ps(f, t, mask); +#else + return select(vboolf4(mask), t, f); +#endif + } + + __forceinline vfloat4 lerp(const vfloat4& a, const vfloat4& b, const vfloat4& t) { + return madd(t,b-a,a); + } + + __forceinline bool isvalid(const vfloat4& v) { + return all((v > vfloat4(-FLT_LARGE)) & (v < vfloat4(+FLT_LARGE))); + } + + __forceinline bool is_finite(const vfloat4& a) { + return all((a >= vfloat4(-FLT_MAX)) & (a <= vfloat4(+FLT_MAX))); + } + + __forceinline bool is_finite(const vboolf4& valid, const vfloat4& a) { + return all(valid, (a >= vfloat4(-FLT_MAX)) & (a <= vfloat4(+FLT_MAX))); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Rounding Functions + //////////////////////////////////////////////////////////////////////////////// + +#if defined (__SSE4_1__) + __forceinline vfloat4 floor(const vfloat4& a) { return _mm_round_ps(a, _MM_FROUND_TO_NEG_INF ); } + __forceinline vfloat4 ceil (const vfloat4& a) { return _mm_round_ps(a, _MM_FROUND_TO_POS_INF ); } + __forceinline vfloat4 trunc(const vfloat4& a) { return _mm_round_ps(a, _MM_FROUND_TO_ZERO ); } + __forceinline vfloat4 round(const vfloat4& a) { return _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT); } +#else + __forceinline vfloat4 floor(const vfloat4& a) { return vfloat4(floorf(a[0]),floorf(a[1]),floorf(a[2]),floorf(a[3])); } + __forceinline vfloat4 ceil (const vfloat4& a) { return vfloat4(ceilf (a[0]),ceilf (a[1]),ceilf (a[2]),ceilf (a[3])); } + __forceinline vfloat4 trunc(const vfloat4& a) { return vfloat4(truncf(a[0]),truncf(a[1]),truncf(a[2]),truncf(a[3])); } + __forceinline vfloat4 round(const vfloat4& a) { return vfloat4(roundf(a[0]),roundf(a[1]),roundf(a[2]),roundf(a[3])); } +#endif + __forceinline vfloat4 frac(const vfloat4& a) { return a-floor(a); } + + __forceinline vint4 floori(const vfloat4& a) { +#if defined(__SSE4_1__) + return vint4(floor(a)); +#else + return vint4(a-vfloat4(0.5f)); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat4 unpacklo(const vfloat4& a, const vfloat4& b) { return _mm_unpacklo_ps(a, b); } + __forceinline vfloat4 unpackhi(const vfloat4& a, const vfloat4& b) { return _mm_unpackhi_ps(a, b); } + + template + __forceinline vfloat4 shuffle(const vfloat4& v) { + return _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + 
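The four-index shuffle above picks one source lane per result lane. A minimal scalar reference (illustrative only, not part of the patch) of what it computes, and of why _MM_SHUFFLE takes its indices in reverse order:

#include <array>

// Reference model: result lane k is v[ik]. _MM_SHUFFLE(i3, i2, i1, i0) packs the
// indices back-to-front, so the intrinsic above agrees with this lane for lane.
template <int i0, int i1, int i2, int i3>
std::array<float, 4> shuffle_ref(const std::array<float, 4>& v) {
  return { v[i0], v[i1], v[i2], v[i3] };
}
// e.g. shuffle_ref<1,0,3,2>({a,b,c,d}) == {b,a,d,c} -- the pair swap that the
// reductions and the sorting network later in this header build on.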
template + __forceinline vfloat4 shuffle(const vfloat4& a, const vfloat4& b) { + return _mm_shuffle_ps(a, b, _MM_SHUFFLE(i3, i2, i1, i0)); + } + +#if defined (__SSSE3__) + __forceinline vfloat4 shuffle8(const vfloat4& a, const vint4& shuf) { + return _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a), shuf)); + } +#endif + +#if defined(__SSE3__) + template<> __forceinline vfloat4 shuffle<0, 0, 2, 2>(const vfloat4& v) { return _mm_moveldup_ps(v); } + template<> __forceinline vfloat4 shuffle<1, 1, 3, 3>(const vfloat4& v) { return _mm_movehdup_ps(v); } + template<> __forceinline vfloat4 shuffle<0, 1, 0, 1>(const vfloat4& v) { return _mm_castpd_ps(_mm_movedup_pd(_mm_castps_pd(v))); } +#endif + + template + __forceinline vfloat4 shuffle(const vfloat4& v) { + return shuffle(v); + } + +#if defined (__SSE4_1__) && !defined(__GNUC__) + template __forceinline float extract(const vfloat4& a) { return _mm_cvtss_f32(_mm_extract_ps(a,i)); } +#else + template __forceinline float extract(const vfloat4& a) { return _mm_cvtss_f32(shuffle(a)); } +#endif + template<> __forceinline float extract<0>(const vfloat4& a) { return _mm_cvtss_f32(a); } + +#if defined (__SSE4_1__) + template __forceinline vfloat4 insert(const vfloat4& a, const vfloat4& b) { return _mm_insert_ps(a, b, (dst << 4) | (src << 6) | clr); } + template __forceinline vfloat4 insert(const vfloat4& a, const vfloat4& b) { return insert(a, b); } + template __forceinline vfloat4 insert(const vfloat4& a, const float b) { return insert(a, _mm_set_ss(b)); } +#else + template __forceinline vfloat4 insert(const vfloat4& a, const vfloat4& b) { vfloat4 c = a; c[dst&3] = b[src&3]; return c; } + template __forceinline vfloat4 insert(const vfloat4& a, float b) { vfloat4 c = a; c[dst&3] = b; return c; } +#endif + + __forceinline float toScalar(const vfloat4& v) { return _mm_cvtss_f32(v); } + + __forceinline vfloat4 broadcast4f(const vfloat4& a, size_t k) { + return vfloat4::broadcast(&a[k]); + } + + __forceinline vfloat4 shift_right_1(const vfloat4& x) { + return _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(x), 4)); + } + +#if defined (__AVX2__) + __forceinline vfloat4 permute(const vfloat4 &a, const __m128i &index) { + return _mm_permutevar_ps(a,index); + } + + __forceinline vfloat4 broadcast1f(const void* a) { return _mm_broadcast_ss((float*)a); } + +#endif + +#if defined(__AVX512VL__) + template + __forceinline vfloat4 align_shift_right(const vfloat4& a, const vfloat4& b) { + return _mm_castsi128_ps(_mm_alignr_epi32(_mm_castps_si128(a), _mm_castps_si128(b), i)); + } +#endif + + + //////////////////////////////////////////////////////////////////////////////// + /// Sorting Network + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat4 sort_ascending(const vfloat4& v) + { + const vfloat4 a0 = v; + const vfloat4 b0 = shuffle<1,0,3,2>(a0); + const vfloat4 c0 = min(a0,b0); + const vfloat4 d0 = max(a0,b0); + const vfloat4 a1 = select<0x5 /* 0b0101 */>(c0,d0); + const vfloat4 b1 = shuffle<2,3,0,1>(a1); + const vfloat4 c1 = min(a1,b1); + const vfloat4 d1 = max(a1,b1); + const vfloat4 a2 = select<0x3 /* 0b0011 */>(c1,d1); + const vfloat4 b2 = shuffle<0,2,1,3>(a2); + const vfloat4 c2 = min(a2,b2); + const vfloat4 d2 = max(a2,b2); + const vfloat4 a3 = select<0x2 /* 0b0010 */>(c2,d2); + return a3; + } + + __forceinline vfloat4 sort_descending(const vfloat4& v) + { + const vfloat4 a0 = v; + const vfloat4 b0 = shuffle<1,0,3,2>(a0); + const vfloat4 c0 = max(a0,b0); + const vfloat4 d0 = min(a0,b0); + const 
vfloat4 a1 = select<0x5 /* 0b0101 */>(c0,d0); + const vfloat4 b1 = shuffle<2,3,0,1>(a1); + const vfloat4 c1 = max(a1,b1); + const vfloat4 d1 = min(a1,b1); + const vfloat4 a2 = select<0x3 /* 0b0011 */>(c1,d1); + const vfloat4 b2 = shuffle<0,2,1,3>(a2); + const vfloat4 c2 = max(a2,b2); + const vfloat4 d2 = min(a2,b2); + const vfloat4 a3 = select<0x2 /* 0b0010 */>(c2,d2); + return a3; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Transpose + //////////////////////////////////////////////////////////////////////////////// + + __forceinline void transpose(const vfloat4& r0, const vfloat4& r1, const vfloat4& r2, const vfloat4& r3, vfloat4& c0, vfloat4& c1, vfloat4& c2, vfloat4& c3) + { + vfloat4 l02 = unpacklo(r0,r2); + vfloat4 h02 = unpackhi(r0,r2); + vfloat4 l13 = unpacklo(r1,r3); + vfloat4 h13 = unpackhi(r1,r3); + c0 = unpacklo(l02,l13); + c1 = unpackhi(l02,l13); + c2 = unpacklo(h02,h13); + c3 = unpackhi(h02,h13); + } + + __forceinline void transpose(const vfloat4& r0, const vfloat4& r1, const vfloat4& r2, const vfloat4& r3, vfloat4& c0, vfloat4& c1, vfloat4& c2) + { + vfloat4 l02 = unpacklo(r0,r2); + vfloat4 h02 = unpackhi(r0,r2); + vfloat4 l13 = unpacklo(r1,r3); + vfloat4 h13 = unpackhi(r1,r3); + c0 = unpacklo(l02,l13); + c1 = unpackhi(l02,l13); + c2 = unpacklo(h02,h13); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat4 vreduce_min(const vfloat4& v) { vfloat4 h = min(shuffle<1,0,3,2>(v),v); return min(shuffle<2,3,0,1>(h),h); } + __forceinline vfloat4 vreduce_max(const vfloat4& v) { vfloat4 h = max(shuffle<1,0,3,2>(v),v); return max(shuffle<2,3,0,1>(h),h); } + __forceinline vfloat4 vreduce_add(const vfloat4& v) { vfloat4 h = shuffle<1,0,3,2>(v) + v ; return shuffle<2,3,0,1>(h) + h ; } + + __forceinline float reduce_min(const vfloat4& v) { return _mm_cvtss_f32(vreduce_min(v)); } + __forceinline float reduce_max(const vfloat4& v) { return _mm_cvtss_f32(vreduce_max(v)); } + __forceinline float reduce_add(const vfloat4& v) { return _mm_cvtss_f32(vreduce_add(v)); } + + __forceinline size_t select_min(const vboolf4& valid, const vfloat4& v) + { + const vfloat4 a = select(valid,v,vfloat4(pos_inf)); + const vbool4 valid_min = valid & (a == vreduce_min(a)); + return bsf(movemask(any(valid_min) ? valid_min : valid)); + } + __forceinline size_t select_max(const vboolf4& valid, const vfloat4& v) + { + const vfloat4 a = select(valid,v,vfloat4(neg_inf)); + const vbool4 valid_max = valid & (a == vreduce_max(a)); + return bsf(movemask(any(valid_max) ? 
valid_max : valid)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline float dot(const vfloat4& a, const vfloat4& b) { + return reduce_add(a*b); + } + + __forceinline vfloat4 cross(const vfloat4& a, const vfloat4& b) + { + const vfloat4 a0 = a; + const vfloat4 b0 = shuffle<1,2,0,3>(b); + const vfloat4 a1 = shuffle<1,2,0,3>(a); + const vfloat4 b1 = b; + return shuffle<1,2,0,3>(msub(a0,b0,a1*b1)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vfloat4& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ">"; + } + +} diff --git a/thirdparty/embree/common/simd/vfloat8_avx.h b/thirdparty/embree/common/simd/vfloat8_avx.h new file mode 100644 index 000000000000..09d7ccc71e96 --- /dev/null +++ b/thirdparty/embree/common/simd/vfloat8_avx.h @@ -0,0 +1,780 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX float type */ + template<> + struct vfloat<8> + { + ALIGNED_STRUCT_(32); + + typedef vboolf8 Bool; + typedef vint8 Int; + typedef vfloat8 Float; + + enum { size = 8 }; // number of SIMD elements + union { __m256 v; float f[8]; int i[8]; }; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat() {} + __forceinline vfloat(const vfloat8& other) { v = other.v; } + __forceinline vfloat8& operator =(const vfloat8& other) { v = other.v; return *this; } + + __forceinline vfloat(__m256 a) : v(a) {} + __forceinline operator const __m256&() const { return v; } + __forceinline operator __m256&() { return v; } + + __forceinline explicit vfloat(const vfloat4& a) : v(_mm256_insertf128_ps(_mm256_castps128_ps256(a),a,1)) {} + __forceinline vfloat(const vfloat4& a, const vfloat4& b) : v(_mm256_insertf128_ps(_mm256_castps128_ps256(a),b,1)) {} + + __forceinline explicit vfloat(const char* a) : v(_mm256_loadu_ps((const float*)a)) {} + __forceinline vfloat(float a) : v(_mm256_set1_ps(a)) {} + __forceinline vfloat(float a, float b) : v(_mm256_set_ps(b, a, b, a, b, a, b, a)) {} + __forceinline vfloat(float a, float b, float c, float d) : v(_mm256_set_ps(d, c, b, a, d, c, b, a)) {} + __forceinline vfloat(float a, float b, float c, float d, float e, float f, float g, float h) : v(_mm256_set_ps(h, g, f, e, d, c, b, a)) {} + + __forceinline explicit vfloat(__m256i a) : v(_mm256_cvtepi32_ps(a)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat(ZeroTy) : v(_mm256_setzero_ps()) {} + __forceinline vfloat(OneTy) : v(_mm256_set1_ps(1.0f)) {} + __forceinline vfloat(PosInfTy) : v(_mm256_set1_ps(pos_inf)) {} + __forceinline vfloat(NegInfTy) : v(_mm256_set1_ps(neg_inf)) {} + __forceinline vfloat(StepTy) : v(_mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f)) {} + __forceinline vfloat(NaNTy) : v(_mm256_set1_ps(nan)) {} + __forceinline vfloat(UndefinedTy) : 
v(_mm256_undefined_ps()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vfloat8 broadcast(const void* a) { + return _mm256_broadcast_ss((float*)a); + } + + static __forceinline vfloat8 broadcast2(const float* a, const float* b) { +#if defined(__INTEL_COMPILER) + const vfloat8 v0 = _mm256_broadcast_ss(a); + const vfloat8 v1 = _mm256_broadcast_ss(b); + return _mm256_blend_ps(v1, v0, 0xf); +#else + return _mm256_set_ps(*b,*b,*b,*b,*a,*a,*a,*a); +#endif + } + + static __forceinline vfloat8 broadcast4f(const vfloat4* ptr) { + return _mm256_broadcast_ps((__m128*)ptr); + } + + static __forceinline vfloat8 load(const char* ptr) { +#if defined(__AVX2__) + return _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)ptr))); +#else + return vfloat8(vfloat4::load(ptr),vfloat4::load(ptr+4)); +#endif + } + + static __forceinline vfloat8 load(const unsigned char* ptr) { +#if defined(__AVX2__) + return _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadu_si128((__m128i*)ptr))); +#else + return vfloat8(vfloat4::load(ptr),vfloat4::load(ptr+4)); +#endif + } + + static __forceinline vfloat8 load(const short* ptr) { +#if defined(__AVX2__) + return _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*)ptr))); +#else + return vfloat8(vfloat4::load(ptr),vfloat4::load(ptr+4)); +#endif + } + + static __forceinline vfloat8 load (const void* ptr) { return _mm256_load_ps((float*)ptr); } + static __forceinline vfloat8 loadu(const void* ptr) { return _mm256_loadu_ps((float*)ptr); } + + static __forceinline void store (void* ptr, const vfloat8& v) { return _mm256_store_ps((float*)ptr,v); } + static __forceinline void storeu(void* ptr, const vfloat8& v) { return _mm256_storeu_ps((float*)ptr,v); } + +#if defined(__AVX512VL__) + + static __forceinline vfloat8 compact(const vboolf8& mask, vfloat8 &v) { + return _mm256_mask_compress_ps(v, mask, v); + } + static __forceinline vfloat8 compact(const vboolf8& mask, vfloat8 &a, const vfloat8& b) { + return _mm256_mask_compress_ps(a, mask, b); + } + + static __forceinline vfloat8 load (const vboolf8& mask, const void* ptr) { return _mm256_mask_load_ps (_mm256_setzero_ps(),mask,(float*)ptr); } + static __forceinline vfloat8 loadu(const vboolf8& mask, const void* ptr) { return _mm256_mask_loadu_ps(_mm256_setzero_ps(),mask,(float*)ptr); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vfloat8& v) { _mm256_mask_store_ps ((float*)ptr,mask,v); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vfloat8& v) { _mm256_mask_storeu_ps((float*)ptr,mask,v); } +#else + static __forceinline vfloat8 load (const vboolf8& mask, const void* ptr) { return _mm256_maskload_ps((float*)ptr,(__m256i)mask); } + static __forceinline vfloat8 loadu(const vboolf8& mask, const void* ptr) { return _mm256_maskload_ps((float*)ptr,(__m256i)mask); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vfloat8& v) { _mm256_maskstore_ps((float*)ptr,(__m256i)mask,v); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vfloat8& v) { _mm256_maskstore_ps((float*)ptr,(__m256i)mask,v); } +#endif + +#if defined(__AVX2__) + static __forceinline vfloat8 load_nt(void* ptr) { + return _mm256_castsi256_ps(_mm256_stream_load_si256((__m256i*)ptr)); + } +#endif + + static __forceinline void store_nt(void* ptr, const vfloat8& v) { + 
_mm256_stream_ps((float*)ptr,v); + } + + template + static __forceinline vfloat8 gather(const float* ptr, const vint8& index) { +#if defined(__AVX2__) + return _mm256_i32gather_ps(ptr, index ,scale); +#else + return vfloat8( + *(float*)(((char*)ptr)+scale*index[0]), + *(float*)(((char*)ptr)+scale*index[1]), + *(float*)(((char*)ptr)+scale*index[2]), + *(float*)(((char*)ptr)+scale*index[3]), + *(float*)(((char*)ptr)+scale*index[4]), + *(float*)(((char*)ptr)+scale*index[5]), + *(float*)(((char*)ptr)+scale*index[6]), + *(float*)(((char*)ptr)+scale*index[7])); +#endif + } + + template + static __forceinline vfloat8 gather(const vboolf8& mask, const float* ptr, const vint8& index) { + vfloat8 r = zero; +#if defined(__AVX512VL__) + return _mm256_mmask_i32gather_ps(r, mask, index, ptr, scale); +#elif defined(__AVX2__) + return _mm256_mask_i32gather_ps(r, ptr, index, mask, scale); +#else + if (likely(mask[0])) r[0] = *(float*)(((char*)ptr)+scale*index[0]); + if (likely(mask[1])) r[1] = *(float*)(((char*)ptr)+scale*index[1]); + if (likely(mask[2])) r[2] = *(float*)(((char*)ptr)+scale*index[2]); + if (likely(mask[3])) r[3] = *(float*)(((char*)ptr)+scale*index[3]); + if (likely(mask[4])) r[4] = *(float*)(((char*)ptr)+scale*index[4]); + if (likely(mask[5])) r[5] = *(float*)(((char*)ptr)+scale*index[5]); + if (likely(mask[6])) r[6] = *(float*)(((char*)ptr)+scale*index[6]); + if (likely(mask[7])) r[7] = *(float*)(((char*)ptr)+scale*index[7]); + return r; + #endif + } + + template + static __forceinline void scatter(void* ptr, const vint8& ofs, const vfloat8& v) + { +#if defined(__AVX512VL__) + _mm256_i32scatter_ps((float*)ptr, ofs, v, scale); +#else + *(float*)(((char*)ptr)+scale*ofs[0]) = v[0]; + *(float*)(((char*)ptr)+scale*ofs[1]) = v[1]; + *(float*)(((char*)ptr)+scale*ofs[2]) = v[2]; + *(float*)(((char*)ptr)+scale*ofs[3]) = v[3]; + *(float*)(((char*)ptr)+scale*ofs[4]) = v[4]; + *(float*)(((char*)ptr)+scale*ofs[5]) = v[5]; + *(float*)(((char*)ptr)+scale*ofs[6]) = v[6]; + *(float*)(((char*)ptr)+scale*ofs[7]) = v[7]; +#endif + } + + template + static __forceinline void scatter(const vboolf8& mask, void* ptr, const vint8& ofs, const vfloat8& v) + { +#if defined(__AVX512VL__) + _mm256_mask_i32scatter_ps((float*)ptr, mask, ofs, v, scale); +#else + if (likely(mask[0])) *(float*)(((char*)ptr)+scale*ofs[0]) = v[0]; + if (likely(mask[1])) *(float*)(((char*)ptr)+scale*ofs[1]) = v[1]; + if (likely(mask[2])) *(float*)(((char*)ptr)+scale*ofs[2]) = v[2]; + if (likely(mask[3])) *(float*)(((char*)ptr)+scale*ofs[3]) = v[3]; + if (likely(mask[4])) *(float*)(((char*)ptr)+scale*ofs[4]) = v[4]; + if (likely(mask[5])) *(float*)(((char*)ptr)+scale*ofs[5]) = v[5]; + if (likely(mask[6])) *(float*)(((char*)ptr)+scale*ofs[6]) = v[6]; + if (likely(mask[7])) *(float*)(((char*)ptr)+scale*ofs[7]) = v[7]; +#endif + } + + static __forceinline void store(const vboolf8& mask, char* ptr, const vint8& ofs, const vfloat8& v) { + scatter<1>(mask,ptr,ofs,v); + } + static __forceinline void store(const vboolf8& mask, float* ptr, const vint8& ofs, const vfloat8& v) { + scatter<4>(mask,ptr,ofs,v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const float& operator [](size_t index) const { assert(index < 8); return f[index]; } + __forceinline float& operator [](size_t index) { assert(index < 8); return f[index]; } + }; + + + 
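The gather helpers in the struct above are templated on a byte stride (the scale used in their bodies). An illustrative usage sketch follows; table and idx are hypothetical names, and the 8-int vint8 constructor is assumed to mirror the 8-float vfloat8 constructor shown above:

// Illustrative only: an 8-wide table lookup via vfloat8::gather.
float table[64];
for (int k = 0; k < 64; k++) table[k] = 0.5f * float(k);
vint8   idx(0, 3, 7, 9, 12, 20, 31, 40);     // per-lane element indices (assumed ctor)
vfloat8 v = vfloat8::gather<4>(table, idx);  // scale = 4 bytes per float; v[k] == table[idx[k]]
// The pre-AVX2 fallback above is exactly this scalar loop:
//   for (int k = 0; k < 8; k++) v[k] = *(const float*)((const char*)table + 4 * idx[k]);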
//////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8 asFloat(const vint8& a) { return _mm256_castsi256_ps(a); } + __forceinline vint8 asInt (const vfloat8& a) { return _mm256_castps_si256(a); } + + __forceinline vint8 toInt (const vfloat8& a) { return vint8(a); } + __forceinline vfloat8 toFloat(const vint8& a) { return vfloat8(a); } + + __forceinline vfloat8 operator +(const vfloat8& a) { return a; } + __forceinline vfloat8 operator -(const vfloat8& a) { + const __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + return _mm256_xor_ps(a, mask); + } + __forceinline vfloat8 abs(const vfloat8& a) { + const __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)); + return _mm256_and_ps(a, mask); + } + __forceinline vfloat8 sign (const vfloat8& a) { return _mm256_blendv_ps(vfloat8(one), -vfloat8(one), _mm256_cmp_ps(a, vfloat8(zero), _CMP_NGE_UQ)); } + __forceinline vfloat8 signmsk(const vfloat8& a) { return _mm256_and_ps(a,_mm256_castsi256_ps(_mm256_set1_epi32(0x80000000))); } + + + static __forceinline vfloat8 rcp(const vfloat8& a) + { +#if defined(__AVX512VL__) + const vfloat8 r = _mm256_rcp14_ps(a); +#else + const vfloat8 r = _mm256_rcp_ps(a); +#endif + +#if defined(__AVX2__) + return _mm256_mul_ps(r, _mm256_fnmadd_ps(r, a, vfloat8(2.0f))); +#else + return _mm256_mul_ps(r, _mm256_sub_ps(vfloat8(2.0f), _mm256_mul_ps(r, a))); +#endif + } + __forceinline vfloat8 sqr (const vfloat8& a) { return _mm256_mul_ps(a,a); } + __forceinline vfloat8 sqrt(const vfloat8& a) { return _mm256_sqrt_ps(a); } + + static __forceinline vfloat8 rsqrt(const vfloat8& a) + { +#if defined(__AVX512VL__) + const vfloat8 r = _mm256_rsqrt14_ps(a); +#else + const vfloat8 r = _mm256_rsqrt_ps(a); +#endif + +#if defined(__AVX2__) + return _mm256_fmadd_ps(_mm256_set1_ps(1.5f), r, + _mm256_mul_ps(_mm256_mul_ps(_mm256_mul_ps(a, _mm256_set1_ps(-0.5f)), r), _mm256_mul_ps(r, r))); +#else + return _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(1.5f), r), + _mm256_mul_ps(_mm256_mul_ps(_mm256_mul_ps(a, _mm256_set1_ps(-0.5f)), r), _mm256_mul_ps(r, r))); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8 operator +(const vfloat8& a, const vfloat8& b) { return _mm256_add_ps(a, b); } + __forceinline vfloat8 operator +(const vfloat8& a, float b) { return a + vfloat8(b); } + __forceinline vfloat8 operator +(float a, const vfloat8& b) { return vfloat8(a) + b; } + + __forceinline vfloat8 operator -(const vfloat8& a, const vfloat8& b) { return _mm256_sub_ps(a, b); } + __forceinline vfloat8 operator -(const vfloat8& a, float b) { return a - vfloat8(b); } + __forceinline vfloat8 operator -(float a, const vfloat8& b) { return vfloat8(a) - b; } + + __forceinline vfloat8 operator *(const vfloat8& a, const vfloat8& b) { return _mm256_mul_ps(a, b); } + __forceinline vfloat8 operator *(const vfloat8& a, float b) { return a * vfloat8(b); } + __forceinline vfloat8 operator *(float a, const vfloat8& b) { return vfloat8(a) * b; } + + __forceinline vfloat8 operator /(const vfloat8& a, const vfloat8& b) { return _mm256_div_ps(a, b); } + __forceinline vfloat8 operator /(const vfloat8& a, float b) { return a / vfloat8(b); } + __forceinline vfloat8 operator /(float a, const vfloat8& b) { return vfloat8(a) / b; } 
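rcp() and rsqrt() above refine the coarse hardware estimate with a single Newton-Raphson step. A scalar sketch of the reciprocal refinement, for reference only:

// Given a coarse estimate r ~= 1/a (about 12 bits from rcp_ps, 14 from rcp14_ps),
// one Newton-Raphson step roughly doubles the number of correct bits:
float rcp_refine(float a, float r) {
  return r * (2.0f - a * r);   // the fnmadd form above computes the same expression
}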
+ + __forceinline vfloat8 operator &(const vfloat8& a, const vfloat8& b) { return _mm256_and_ps(a,b); } + __forceinline vfloat8 operator |(const vfloat8& a, const vfloat8& b) { return _mm256_or_ps(a,b); } + __forceinline vfloat8 operator ^(const vfloat8& a, const vfloat8& b) { return _mm256_xor_ps(a,b); } + __forceinline vfloat8 operator ^(const vfloat8& a, const vint8& b) { return _mm256_xor_ps(a,_mm256_castsi256_ps(b)); } + + __forceinline vfloat8 min(const vfloat8& a, const vfloat8& b) { return _mm256_min_ps(a, b); } + __forceinline vfloat8 min(const vfloat8& a, float b) { return _mm256_min_ps(a, vfloat8(b)); } + __forceinline vfloat8 min(float a, const vfloat8& b) { return _mm256_min_ps(vfloat8(a), b); } + + __forceinline vfloat8 max(const vfloat8& a, const vfloat8& b) { return _mm256_max_ps(a, b); } + __forceinline vfloat8 max(const vfloat8& a, float b) { return _mm256_max_ps(a, vfloat8(b)); } + __forceinline vfloat8 max(float a, const vfloat8& b) { return _mm256_max_ps(vfloat8(a), b); } + + /* need "static __forceinline for MSVC, otherwise we'll link the wrong version in debug mode */ +#if defined(__AVX2__) + + static __forceinline vfloat8 mini(const vfloat8& a, const vfloat8& b) { + const vint8 ai = _mm256_castps_si256(a); + const vint8 bi = _mm256_castps_si256(b); + const vint8 ci = _mm256_min_epi32(ai,bi); + return _mm256_castsi256_ps(ci); + } + + static __forceinline vfloat8 maxi(const vfloat8& a, const vfloat8& b) { + const vint8 ai = _mm256_castps_si256(a); + const vint8 bi = _mm256_castps_si256(b); + const vint8 ci = _mm256_max_epi32(ai,bi); + return _mm256_castsi256_ps(ci); + } + + static __forceinline vfloat8 minui(const vfloat8& a, const vfloat8& b) { + const vint8 ai = _mm256_castps_si256(a); + const vint8 bi = _mm256_castps_si256(b); + const vint8 ci = _mm256_min_epu32(ai,bi); + return _mm256_castsi256_ps(ci); + } + + static __forceinline vfloat8 maxui(const vfloat8& a, const vfloat8& b) { + const vint8 ai = _mm256_castps_si256(a); + const vint8 bi = _mm256_castps_si256(b); + const vint8 ci = _mm256_max_epu32(ai,bi); + return _mm256_castsi256_ps(ci); + } + +#else + + static __forceinline vfloat8 mini(const vfloat8& a, const vfloat8& b) { + return asFloat(min(asInt(a),asInt(b))); + } + + static __forceinline vfloat8 maxi(const vfloat8& a, const vfloat8& b) { + return asFloat(max(asInt(a),asInt(b))); + } + +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Ternary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX2__) + static __forceinline vfloat8 madd (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return _mm256_fmadd_ps(a,b,c); } + static __forceinline vfloat8 msub (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return _mm256_fmsub_ps(a,b,c); } + static __forceinline vfloat8 nmadd (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return _mm256_fnmadd_ps(a,b,c); } + static __forceinline vfloat8 nmsub (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return _mm256_fnmsub_ps(a,b,c); } +#else + static __forceinline vfloat8 madd (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return a*b+c; } + static __forceinline vfloat8 msub (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return a*b-c; } + static __forceinline vfloat8 nmadd (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return -a*b+c;} + static __forceinline vfloat8 nmsub (const vfloat8& a, const vfloat8& b, const vfloat8& c) { return -a*b-c; } +#endif + + 
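Whether or not FMA is available, madd(a, b, c) evaluates a*b + c; the AVX2 path merely fuses it into one instruction and skips the intermediate rounding. The helpers compose directly, e.g. the lerp() defined further down in this header is madd(t, b - a, a). A scalar equivalent, for reference only:

float lerp_ref(float a, float b, float t) {
  return a + t * (b - a);   // == madd(t, b - a, a)
}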
//////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8& operator +=(vfloat8& a, const vfloat8& b) { return a = a + b; } + __forceinline vfloat8& operator +=(vfloat8& a, float b) { return a = a + b; } + + __forceinline vfloat8& operator -=(vfloat8& a, const vfloat8& b) { return a = a - b; } + __forceinline vfloat8& operator -=(vfloat8& a, float b) { return a = a - b; } + + __forceinline vfloat8& operator *=(vfloat8& a, const vfloat8& b) { return a = a * b; } + __forceinline vfloat8& operator *=(vfloat8& a, float b) { return a = a * b; } + + __forceinline vfloat8& operator /=(vfloat8& a, const vfloat8& b) { return a = a / b; } + __forceinline vfloat8& operator /=(vfloat8& a, float b) { return a = a / b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + static __forceinline vboolf8 operator ==(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps_mask(a, b, _MM_CMPINT_EQ); } + static __forceinline vboolf8 operator !=(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps_mask(a, b, _MM_CMPINT_NE); } + static __forceinline vboolf8 operator < (const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps_mask(a, b, _MM_CMPINT_LT); } + static __forceinline vboolf8 operator >=(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps_mask(a, b, _MM_CMPINT_GE); } + static __forceinline vboolf8 operator > (const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps_mask(a, b, _MM_CMPINT_GT); } + static __forceinline vboolf8 operator <=(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps_mask(a, b, _MM_CMPINT_LE); } + + static __forceinline vfloat8 select(const vboolf8& m, const vfloat8& t, const vfloat8& f) { + return _mm256_mask_blend_ps(m, f, t); + } +#else + static __forceinline vboolf8 operator ==(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps(a, b, _CMP_EQ_OQ); } + static __forceinline vboolf8 operator !=(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps(a, b, _CMP_NEQ_UQ); } + static __forceinline vboolf8 operator < (const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps(a, b, _CMP_LT_OS); } + static __forceinline vboolf8 operator >=(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps(a, b, _CMP_NLT_US); } + static __forceinline vboolf8 operator > (const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps(a, b, _CMP_NLE_US); } + static __forceinline vboolf8 operator <=(const vfloat8& a, const vfloat8& b) { return _mm256_cmp_ps(a, b, _CMP_LE_OS); } + + static __forceinline vfloat8 select(const vboolf8& m, const vfloat8& t, const vfloat8& f) { + return _mm256_blendv_ps(f, t, m); + } +#endif + + template + __forceinline vfloat8 select(const vfloat8& t, const vfloat8& f) { + return _mm256_blend_ps(f, t, mask); + } + + __forceinline vboolf8 operator ==(const vfloat8& a, const float& b) { return a == vfloat8(b); } + __forceinline vboolf8 operator ==(const float& a, const vfloat8& b) { return vfloat8(a) == b; } + + __forceinline vboolf8 operator !=(const vfloat8& a, const float& b) { return a != vfloat8(b); } + __forceinline vboolf8 operator !=(const float& a, const vfloat8& b) { return vfloat8(a) != b; } + + __forceinline vboolf8 operator < (const vfloat8& a, const float& b) { return a < vfloat8(b); } + 
__forceinline vboolf8 operator < (const float& a, const vfloat8& b) { return vfloat8(a) < b; } + + __forceinline vboolf8 operator >=(const vfloat8& a, const float& b) { return a >= vfloat8(b); } + __forceinline vboolf8 operator >=(const float& a, const vfloat8& b) { return vfloat8(a) >= b; } + + __forceinline vboolf8 operator > (const vfloat8& a, const float& b) { return a > vfloat8(b); } + __forceinline vboolf8 operator > (const float& a, const vfloat8& b) { return vfloat8(a) > b; } + + __forceinline vboolf8 operator <=(const vfloat8& a, const float& b) { return a <= vfloat8(b); } + __forceinline vboolf8 operator <=(const float& a, const vfloat8& b) { return vfloat8(a) <= b; } + + __forceinline vboolf8 eq(const vfloat8& a, const vfloat8& b) { return a == b; } + __forceinline vboolf8 ne(const vfloat8& a, const vfloat8& b) { return a != b; } + __forceinline vboolf8 lt(const vfloat8& a, const vfloat8& b) { return a < b; } + __forceinline vboolf8 ge(const vfloat8& a, const vfloat8& b) { return a >= b; } + __forceinline vboolf8 gt(const vfloat8& a, const vfloat8& b) { return a > b; } + __forceinline vboolf8 le(const vfloat8& a, const vfloat8& b) { return a <= b; } + +#if defined(__AVX512VL__) + static __forceinline vboolf8 eq(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return _mm256_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_EQ); } + static __forceinline vboolf8 ne(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return _mm256_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_NE); } + static __forceinline vboolf8 lt(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return _mm256_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_LT); } + static __forceinline vboolf8 ge(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return _mm256_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_GE); } + static __forceinline vboolf8 gt(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return _mm256_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_GT); } + static __forceinline vboolf8 le(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return _mm256_mask_cmp_ps_mask(mask, a, b, _MM_CMPINT_LE); } +#else + static __forceinline vboolf8 eq(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return mask & (a == b); } + static __forceinline vboolf8 ne(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return mask & (a != b); } + static __forceinline vboolf8 lt(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return mask & (a < b); } + static __forceinline vboolf8 ge(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return mask & (a >= b); } + static __forceinline vboolf8 gt(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return mask & (a > b); } + static __forceinline vboolf8 le(const vboolf8& mask, const vfloat8& a, const vfloat8& b) { return mask & (a <= b); } +#endif + + __forceinline vfloat8 lerp(const vfloat8& a, const vfloat8& b, const vfloat8& t) { + return madd(t,b-a,a); + } + + __forceinline bool isvalid (const vfloat8& v) { + return all((v > vfloat8(-FLT_LARGE)) & (v < vfloat8(+FLT_LARGE))); + } + + __forceinline bool is_finite (const vfloat8& a) { + return all((a >= vfloat8(-FLT_MAX)) & (a <= vfloat8(+FLT_MAX))); + } + + __forceinline bool is_finite (const vboolf8& valid, const vfloat8& a) { + return all(valid, (a >= vfloat8(-FLT_MAX)) & (a <= vfloat8(+FLT_MAX))); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Rounding Functions + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8 floor(const vfloat8& a) { return _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF ); } + __forceinline vfloat8 ceil (const vfloat8& a) { return _mm256_round_ps(a, _MM_FROUND_TO_POS_INF ); } + __forceinline vfloat8 trunc(const vfloat8& a) { return _mm256_round_ps(a, _MM_FROUND_TO_ZERO ); } + __forceinline vfloat8 round(const vfloat8& a) { return _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT); } + __forceinline vfloat8 frac (const vfloat8& a) { return a-floor(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8 unpacklo(const vfloat8& a, const vfloat8& b) { return _mm256_unpacklo_ps(a, b); } + __forceinline vfloat8 unpackhi(const vfloat8& a, const vfloat8& b) { return _mm256_unpackhi_ps(a, b); } + + template + __forceinline vfloat8 shuffle(const vfloat8& v) { + return _mm256_permute_ps(v, _MM_SHUFFLE(i, i, i, i)); + } + + template + __forceinline vfloat8 shuffle4(const vfloat8& v) { + return _mm256_permute2f128_ps(v, v, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vfloat8 shuffle4(const vfloat8& a, const vfloat8& b) { + return _mm256_permute2f128_ps(a, b, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vfloat8 shuffle(const vfloat8& v) { + return _mm256_permute_ps(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template + __forceinline vfloat8 shuffle(const vfloat8& a, const vfloat8& b) { + return _mm256_shuffle_ps(a, b, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<> __forceinline vfloat8 shuffle<0, 0, 2, 2>(const vfloat8& v) { return _mm256_moveldup_ps(v); } + template<> __forceinline vfloat8 shuffle<1, 1, 3, 3>(const vfloat8& v) { return _mm256_movehdup_ps(v); } + template<> __forceinline vfloat8 shuffle<0, 1, 0, 1>(const vfloat8& v) { return _mm256_castpd_ps(_mm256_movedup_pd(_mm256_castps_pd(v))); } + + __forceinline vfloat8 broadcast(const float* ptr) { return _mm256_broadcast_ss(ptr); } + template __forceinline vfloat8 insert4(const vfloat8& a, const vfloat4& b) { return _mm256_insertf128_ps(a, b, i); } + template __forceinline vfloat4 extract4 (const vfloat8& a) { return _mm256_extractf128_ps(a, i); } + template<> __forceinline vfloat4 extract4<0>(const vfloat8& a) { return _mm256_castps256_ps128(a); } + + __forceinline float toScalar(const vfloat8& v) { return _mm_cvtss_f32(_mm256_castps256_ps128(v)); } + + __forceinline vfloat8 assign(const vfloat4& a) { return _mm256_castps128_ps256(a); } + +#if defined (__AVX2__) + static __forceinline vfloat8 permute(const vfloat8& a, const __m256i& index) { + return _mm256_permutevar8x32_ps(a, index); + } +#endif + +#if defined(__AVX512VL__) + template + static __forceinline vfloat8 align_shift_right(const vfloat8& a, const vfloat8& b) { + return _mm256_castsi256_ps(_mm256_alignr_epi32(_mm256_castps_si256(a), _mm256_castps_si256(b), i)); + } +#endif + +#if defined (__AVX_I__) + template + static __forceinline vint4 convert_to_hf16(const vfloat8& a) { + return _mm256_cvtps_ph(a, mode); + } + + static __forceinline vfloat8 convert_from_hf16(const vint4& a) { + return _mm256_cvtph_ps(a); + } +#endif + + __forceinline vfloat4 broadcast4f(const vfloat8& a, const size_t k) { + return vfloat4::broadcast(&a[k]); + } + + __forceinline vfloat8 broadcast8f(const vfloat8& a, const size_t k) { + return vfloat8::broadcast(&a[k]); + } + +#if 
defined(__AVX512VL__) + static __forceinline vfloat8 shift_right_1(const vfloat8& x) { + return align_shift_right<1>(zero,x); + } +#else + static __forceinline vfloat8 shift_right_1(const vfloat8& x) { + const vfloat8 t0 = shuffle<1,2,3,0>(x); + const vfloat8 t1 = shuffle4<1,0>(t0); + return _mm256_blend_ps(t0,t1,0x88); + } +#endif + + __forceinline vint8 floori(const vfloat8& a) { + return vint8(floor(a)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Transpose + //////////////////////////////////////////////////////////////////////////////// + + __forceinline void transpose(const vfloat8& r0, const vfloat8& r1, const vfloat8& r2, const vfloat8& r3, vfloat8& c0, vfloat8& c1, vfloat8& c2, vfloat8& c3) + { + vfloat8 l02 = unpacklo(r0,r2); + vfloat8 h02 = unpackhi(r0,r2); + vfloat8 l13 = unpacklo(r1,r3); + vfloat8 h13 = unpackhi(r1,r3); + c0 = unpacklo(l02,l13); + c1 = unpackhi(l02,l13); + c2 = unpacklo(h02,h13); + c3 = unpackhi(h02,h13); + } + + __forceinline void transpose(const vfloat8& r0, const vfloat8& r1, const vfloat8& r2, const vfloat8& r3, vfloat8& c0, vfloat8& c1, vfloat8& c2) + { + vfloat8 l02 = unpacklo(r0,r2); + vfloat8 h02 = unpackhi(r0,r2); + vfloat8 l13 = unpacklo(r1,r3); + vfloat8 h13 = unpackhi(r1,r3); + c0 = unpacklo(l02,l13); + c1 = unpackhi(l02,l13); + c2 = unpacklo(h02,h13); + } + + __forceinline void transpose(const vfloat8& r0, const vfloat8& r1, const vfloat8& r2, const vfloat8& r3, const vfloat8& r4, const vfloat8& r5, const vfloat8& r6, const vfloat8& r7, + vfloat8& c0, vfloat8& c1, vfloat8& c2, vfloat8& c3, vfloat8& c4, vfloat8& c5, vfloat8& c6, vfloat8& c7) + { + vfloat8 h0,h1,h2,h3; transpose(r0,r1,r2,r3,h0,h1,h2,h3); + vfloat8 h4,h5,h6,h7; transpose(r4,r5,r6,r7,h4,h5,h6,h7); + c0 = shuffle4<0,2>(h0,h4); + c1 = shuffle4<0,2>(h1,h5); + c2 = shuffle4<0,2>(h2,h6); + c3 = shuffle4<0,2>(h3,h7); + c4 = shuffle4<1,3>(h0,h4); + c5 = shuffle4<1,3>(h1,h5); + c6 = shuffle4<1,3>(h2,h6); + c7 = shuffle4<1,3>(h3,h7); + } + + __forceinline void transpose(const vfloat4& r0, const vfloat4& r1, const vfloat4& r2, const vfloat4& r3, const vfloat4& r4, const vfloat4& r5, const vfloat4& r6, const vfloat4& r7, + vfloat8& c0, vfloat8& c1, vfloat8& c2, vfloat8& c3) + { + transpose(vfloat8(r0,r4), vfloat8(r1,r5), vfloat8(r2,r6), vfloat8(r3,r7), c0, c1, c2, c3); + } + + __forceinline void transpose(const vfloat4& r0, const vfloat4& r1, const vfloat4& r2, const vfloat4& r3, const vfloat4& r4, const vfloat4& r5, const vfloat4& r6, const vfloat4& r7, + vfloat8& c0, vfloat8& c1, vfloat8& c2) + { + transpose(vfloat8(r0,r4), vfloat8(r1,r5), vfloat8(r2,r6), vfloat8(r3,r7), c0, c1, c2); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8 vreduce_min2(const vfloat8& v) { return min(v,shuffle<1,0,3,2>(v)); } + __forceinline vfloat8 vreduce_min4(const vfloat8& v) { vfloat8 v1 = vreduce_min2(v); return min(v1,shuffle<2,3,0,1>(v1)); } + __forceinline vfloat8 vreduce_min (const vfloat8& v) { vfloat8 v1 = vreduce_min4(v); return min(v1,shuffle4<1,0>(v1)); } + + __forceinline vfloat8 vreduce_max2(const vfloat8& v) { return max(v,shuffle<1,0,3,2>(v)); } + __forceinline vfloat8 vreduce_max4(const vfloat8& v) { vfloat8 v1 = vreduce_max2(v); return max(v1,shuffle<2,3,0,1>(v1)); } + __forceinline vfloat8 vreduce_max (const vfloat8& v) { vfloat8 v1 = vreduce_max4(v); return max(v1,shuffle4<1,0>(v1)); 
} + + __forceinline vfloat8 vreduce_add2(const vfloat8& v) { return v + shuffle<1,0,3,2>(v); } + __forceinline vfloat8 vreduce_add4(const vfloat8& v) { vfloat8 v1 = vreduce_add2(v); return v1 + shuffle<2,3,0,1>(v1); } + __forceinline vfloat8 vreduce_add (const vfloat8& v) { vfloat8 v1 = vreduce_add4(v); return v1 + shuffle4<1,0>(v1); } + + __forceinline float reduce_min(const vfloat8& v) { return toScalar(vreduce_min(v)); } + __forceinline float reduce_max(const vfloat8& v) { return toScalar(vreduce_max(v)); } + __forceinline float reduce_add(const vfloat8& v) { return toScalar(vreduce_add(v)); } + + __forceinline size_t select_min(const vboolf8& valid, const vfloat8& v) + { + const vfloat8 a = select(valid,v,vfloat8(pos_inf)); + const vbool8 valid_min = valid & (a == vreduce_min(a)); + return bsf(movemask(any(valid_min) ? valid_min : valid)); + } + + __forceinline size_t select_max(const vboolf8& valid, const vfloat8& v) + { + const vfloat8 a = select(valid,v,vfloat8(neg_inf)); + const vbool8 valid_max = valid & (a == vreduce_max(a)); + return bsf(movemask(any(valid_max) ? valid_max : valid)); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Euclidian Space Operators (pairs of Vec3fa's) + //////////////////////////////////////////////////////////////////////////////// + + //__forceinline vfloat8 dot(const vfloat8& a, const vfloat8& b) { + // return vreduce_add4(a*b); + //} + + __forceinline vfloat8 dot(const vfloat8& a, const vfloat8& b) { + return _mm256_dp_ps(a,b,0x7F); + } + + __forceinline vfloat8 cross(const vfloat8& a, const vfloat8& b) + { + const vfloat8 a0 = a; + const vfloat8 b0 = shuffle<1,2,0,3>(b); + const vfloat8 a1 = shuffle<1,2,0,3>(a); + const vfloat8 b1 = b; + return shuffle<1,2,0,3>(msub(a0,b0,a1*b1)); + } + + //__forceinline float sqr_length (const vfloat<8>& a) { return dot(a,a); } + //__forceinline float rcp_length (const vfloat<8>& a) { return rsqrt(dot(a,a)); } + //__forceinline float rcp_length2(const vfloat<8>& a) { return rcp(dot(a,a)); } + //__forceinline float length (const vfloat<8>& a) { return sqrt(dot(a,a)); } + __forceinline vfloat<8> normalize(const vfloat<8>& a) { return a*rsqrt(dot(a,a)); } + //__forceinline float distance(const vfloat<8>& a, const vfloat<8>& b) { return length(a-b); } + //__forceinline float halfArea(const vfloat<8>& d) { return madd(d.x,(d.y+d.z),d.y*d.z); } + //__forceinline float area (const vfloat<8>& d) { return 2.0f*halfArea(d); } + //__forceinline vfloat<8> reflect(const vfloat<8>& V, const vfloat<8>& N) { return 2.0f*dot(V,N)*N-V; } + + //__forceinline vfloat<8> normalize_safe(const vfloat<8>& a) { + // const float d = dot(a,a); if (unlikely(d == 0.0f)) return a; else return a*rsqrt(d); + //} + + //////////////////////////////////////////////////////////////////////////////// + /// In Register Sorting + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vfloat8 sort_ascending(const vfloat8& v) + { + const vfloat8 a0 = v; + const vfloat8 b0 = shuffle<1,0,3,2>(a0); + const vfloat8 c0 = min(a0,b0); + const vfloat8 d0 = max(a0,b0); + const vfloat8 a1 = select<0x99 /* 0b10011001 */>(c0,d0); + const vfloat8 b1 = shuffle<2,3,0,1>(a1); + const vfloat8 c1 = min(a1,b1); + const vfloat8 d1 = max(a1,b1); + const vfloat8 a2 = select<0xc3 /* 0b11000011 */>(c1,d1); + const vfloat8 b2 = shuffle<1,0,3,2>(a2); + const vfloat8 c2 = min(a2,b2); + const vfloat8 d2 = max(a2,b2); + const vfloat8 a3 = select<0xa5 /* 0b10100101 */>(c2,d2); + const vfloat8 
b3 = shuffle4<1,0>(a3); + const vfloat8 c3 = min(a3,b3); + const vfloat8 d3 = max(a3,b3); + const vfloat8 a4 = select<0xf /* 0b00001111 */>(c3,d3); + const vfloat8 b4 = shuffle<2,3,0,1>(a4); + const vfloat8 c4 = min(a4,b4); + const vfloat8 d4 = max(a4,b4); + const vfloat8 a5 = select<0x33 /* 0b00110011 */>(c4,d4); + const vfloat8 b5 = shuffle<1,0,3,2>(a5); + const vfloat8 c5 = min(a5,b5); + const vfloat8 d5 = max(a5,b5); + const vfloat8 a6 = select<0x55 /* 0b01010101 */>(c5,d5); + return a6; + } + + __forceinline vfloat8 sort_descending(const vfloat8& v) + { + const vfloat8 a0 = v; + const vfloat8 b0 = shuffle<1,0,3,2>(a0); + const vfloat8 c0 = max(a0,b0); + const vfloat8 d0 = min(a0,b0); + const vfloat8 a1 = select<0x99 /* 0b10011001 */>(c0,d0); + const vfloat8 b1 = shuffle<2,3,0,1>(a1); + const vfloat8 c1 = max(a1,b1); + const vfloat8 d1 = min(a1,b1); + const vfloat8 a2 = select<0xc3 /* 0b11000011 */>(c1,d1); + const vfloat8 b2 = shuffle<1,0,3,2>(a2); + const vfloat8 c2 = max(a2,b2); + const vfloat8 d2 = min(a2,b2); + const vfloat8 a3 = select<0xa5 /* 0b10100101 */>(c2,d2); + const vfloat8 b3 = shuffle4<1,0>(a3); + const vfloat8 c3 = max(a3,b3); + const vfloat8 d3 = min(a3,b3); + const vfloat8 a4 = select<0xf /* 0b00001111 */>(c3,d3); + const vfloat8 b4 = shuffle<2,3,0,1>(a4); + const vfloat8 c4 = max(a4,b4); + const vfloat8 d4 = min(a4,b4); + const vfloat8 a5 = select<0x33 /* 0b00110011 */>(c4,d4); + const vfloat8 b5 = shuffle<1,0,3,2>(a5); + const vfloat8 c5 = max(a5,b5); + const vfloat8 d5 = min(a5,b5); + const vfloat8 a6 = select<0x55 /* 0b01010101 */>(c5,d5); + return a6; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vfloat8& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vint16_avx512.h b/thirdparty/embree/common/simd/vint16_avx512.h new file mode 100644 index 000000000000..34e3d5ca0785 --- /dev/null +++ b/thirdparty/embree/common/simd/vint16_avx512.h @@ -0,0 +1,490 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 16-wide AVX-512 integer type */ + template<> + struct vint<16> + { + ALIGNED_STRUCT_(64); + + typedef vboolf16 Bool; + typedef vint16 Int; + typedef vfloat16 Float; + + enum { size = 16 }; // number of SIMD elements + union { // data + __m512i v; + int i[16]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint() {} + __forceinline vint(const vint16& t) { v = t.v; } + __forceinline vint16& operator =(const vint16& f) { v = f.v; return *this; } + + __forceinline vint(const __m512i& t) { v = t; } + __forceinline operator __m512i() const { return v; } + __forceinline operator __m256i() const { return _mm512_castsi512_si256(v); } + + __forceinline vint(int i) { + v = _mm512_set1_epi32(i); + } + + __forceinline vint(int a, int b, int c, int d) { + v = _mm512_set4_epi32(d,c,b,a); + } + + __forceinline vint(int a0 , int a1 , int a2 , int a3, + int a4 , int a5 , int a6 , int a7, + int a8 , int a9 , int a10, int a11, + int a12, int a13, int 
a14, int a15) + { + v = _mm512_set_epi32(a15,a14,a13,a12,a11,a10,a9,a8,a7,a6,a5,a4,a3,a2,a1,a0); + } + + __forceinline vint(const vint4& i) { + v = _mm512_broadcast_i32x4(i); + } + + __forceinline vint(const vint4& a, const vint4& b, const vint4& c, const vint4& d) { + v = _mm512_castsi128_si512(a); + v = _mm512_inserti32x4(v, b, 1); + v = _mm512_inserti32x4(v, c, 2); + v = _mm512_inserti32x4(v, d, 3); + } + + __forceinline vint(const vint8& i) { + v = _mm512_castps_si512(_mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castsi256_pd(i)))); + } + + __forceinline vint(const vint8& a, const vint8& b) { + v = _mm512_castsi256_si512(a); + v = _mm512_inserti64x4(v, b, 1); + } + + __forceinline explicit vint(const __m512& f) { + v = _mm512_cvtps_epi32(f); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint(ZeroTy) : v(_mm512_setzero_epi32()) {} + __forceinline vint(OneTy) : v(_mm512_set1_epi32(1)) {} + __forceinline vint(PosInfTy) : v(_mm512_set1_epi32(pos_inf)) {} + __forceinline vint(NegInfTy) : v(_mm512_set1_epi32(neg_inf)) {} + __forceinline vint(StepTy) : v(_mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {} + __forceinline vint(ReverseStepTy) : v(_mm512_setr_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vint16 load (const void* addr) { return _mm512_load_si512((int*)addr); } + + static __forceinline vint16 load(const unsigned char* ptr) { return _mm512_cvtepu8_epi32(_mm_load_si128((__m128i*)ptr)); } + static __forceinline vint16 load(const unsigned short* ptr) { return _mm512_cvtepu16_epi32(_mm256_load_si256((__m256i*)ptr)); } + + static __forceinline vint16 loadu(const unsigned char* ptr) { return _mm512_cvtepu8_epi32(_mm_loadu_si128((__m128i*)ptr)); } + static __forceinline vint16 loadu(const unsigned short* ptr) { return _mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i*)ptr)); } + + static __forceinline vint16 loadu(const void* addr) { return _mm512_loadu_si512(addr); } + + static __forceinline vint16 load (const vboolf16& mask, const void* addr) { return _mm512_mask_load_epi32 (_mm512_setzero_epi32(),mask,addr); } + static __forceinline vint16 loadu(const vboolf16& mask, const void* addr) { return _mm512_mask_loadu_epi32(_mm512_setzero_epi32(),mask,addr); } + + static __forceinline void store (void* ptr, const vint16& v) { _mm512_store_si512 (ptr,v); } + static __forceinline void storeu(void* ptr, const vint16& v) { _mm512_storeu_si512(ptr,v); } + + static __forceinline void store (const vboolf16& mask, void* addr, const vint16& v2) { _mm512_mask_store_epi32(addr,mask,v2); } + static __forceinline void storeu(const vboolf16& mask, void* ptr, const vint16& f) { _mm512_mask_storeu_epi32((int*)ptr,mask,f); } + + static __forceinline void store_nt(void* __restrict__ ptr, const vint16& a) { _mm512_stream_si512((__m512i*)ptr,a); } + + /* pass by value to avoid compiler generating inefficient code */ + static __forceinline void storeu_compact(const vboolf16 mask, void* addr, vint16 reg) { + _mm512_mask_compressstoreu_epi32(addr,mask,reg); + } + + static __forceinline void storeu_compact_single(const vboolf16 mask, void* addr, vint16 reg) { + //_mm512_mask_compressstoreu_epi32(addr,mask,reg); + *(float*)addr = 
mm512_cvtss_f32(_mm512_mask_compress_ps(_mm512_castsi512_ps(reg),mask,_mm512_castsi512_ps(reg))); + } + + static __forceinline vint16 compact64bit(const vboolf16& mask, vint16 &v) { + return _mm512_mask_compress_epi64(v,mask,v); + } + + static __forceinline vint16 compact(const vboolf16& mask, vint16 &v) { + return _mm512_mask_compress_epi32(v,mask,v); + } + + static __forceinline vint16 compact(const vboolf16& mask, const vint16 &a, vint16 &b) { + return _mm512_mask_compress_epi32(a,mask,b); + } + + static __forceinline vint16 expand(const vboolf16& mask, const vint16& a, vint16& b) { + return _mm512_mask_expand_epi32(b,mask,a); + } + + template + static __forceinline vint16 gather(const int* ptr, const vint16& index) { + return _mm512_i32gather_epi32(index,ptr,scale); + } + + template + static __forceinline vint16 gather(const vboolf16& mask, const int* ptr, const vint16& index) { + return _mm512_mask_i32gather_epi32(_mm512_undefined_epi32(),mask,index,ptr,scale); + } + + template + static __forceinline vint16 gather(const vboolf16& mask, vint16& dest, const int* ptr, const vint16& index) { + return _mm512_mask_i32gather_epi32(dest,mask,index,ptr,scale); + } + + template + static __forceinline void scatter(int* ptr, const vint16& index, const vint16& v) { + _mm512_i32scatter_epi32((int*)ptr,index,v,scale); + } + + template + static __forceinline void scatter(const vboolf16& mask, int* ptr, const vint16& index, const vint16& v) { + _mm512_mask_i32scatter_epi32((int*)ptr,mask,index,v,scale); + } + + static __forceinline vint16 broadcast64bit(size_t v) { + return _mm512_set1_epi64(v); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline int& operator [](size_t index) { assert(index < 16); return i[index]; } + __forceinline const int& operator [](size_t index) const { assert(index < 16); return i[index]; } + + __forceinline unsigned int uint (size_t index) const { assert(index < 16); return ((unsigned int*)i)[index]; } + __forceinline size_t& uint64_t(size_t index) const { assert(index < 8); return ((size_t*)i)[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 asBool(const vint16& a) { return _mm512_movepi32_mask(a); } + + __forceinline vint16 operator +(const vint16& a) { return a; } + __forceinline vint16 operator -(const vint16& a) { return _mm512_sub_epi32(_mm512_setzero_epi32(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint16 operator +(const vint16& a, const vint16& b) { return _mm512_add_epi32(a, b); } + __forceinline vint16 operator +(const vint16& a, int b) { return a + vint16(b); } + __forceinline vint16 operator +(int a, const vint16& b) { return vint16(a) + b; } + + __forceinline vint16 operator -(const vint16& a, const vint16& b) { return _mm512_sub_epi32(a, b); } + __forceinline vint16 operator -(const vint16& a, int b) { return a - vint16(b); } + __forceinline vint16 operator -(int a, const vint16& b) { return vint16(a) - b; } + + __forceinline vint16 operator *(const vint16& a, const vint16& b) { return _mm512_mullo_epi32(a, b); } + __forceinline vint16 operator *(const 
vint16& a, int b) { return a * vint16(b); } + __forceinline vint16 operator *(int a, const vint16& b) { return vint16(a) * b; } + + __forceinline vint16 operator &(const vint16& a, const vint16& b) { return _mm512_and_epi32(a, b); } + __forceinline vint16 operator &(const vint16& a, int b) { return a & vint16(b); } + __forceinline vint16 operator &(int a, const vint16& b) { return vint16(a) & b; } + + __forceinline vint16 operator |(const vint16& a, const vint16& b) { return _mm512_or_epi32(a, b); } + __forceinline vint16 operator |(const vint16& a, int b) { return a | vint16(b); } + __forceinline vint16 operator |(int a, const vint16& b) { return vint16(a) | b; } + + __forceinline vint16 operator ^(const vint16& a, const vint16& b) { return _mm512_xor_epi32(a, b); } + __forceinline vint16 operator ^(const vint16& a, int b) { return a ^ vint16(b); } + __forceinline vint16 operator ^(int a, const vint16& b) { return vint16(a) ^ b; } + + __forceinline vint16 operator <<(const vint16& a, int n) { return _mm512_slli_epi32(a, n); } + __forceinline vint16 operator >>(const vint16& a, int n) { return _mm512_srai_epi32(a, n); } + + __forceinline vint16 operator <<(const vint16& a, const vint16& n) { return _mm512_sllv_epi32(a, n); } + __forceinline vint16 operator >>(const vint16& a, const vint16& n) { return _mm512_srav_epi32(a, n); } + + __forceinline vint16 sll (const vint16& a, int b) { return _mm512_slli_epi32(a, b); } + __forceinline vint16 sra (const vint16& a, int b) { return _mm512_srai_epi32(a, b); } + __forceinline vint16 srl (const vint16& a, int b) { return _mm512_srli_epi32(a, b); } + + __forceinline vint16 min(const vint16& a, const vint16& b) { return _mm512_min_epi32(a, b); } + __forceinline vint16 min(const vint16& a, int b) { return min(a,vint16(b)); } + __forceinline vint16 min(int a, const vint16& b) { return min(vint16(a),b); } + + __forceinline vint16 max(const vint16& a, const vint16& b) { return _mm512_max_epi32(a, b); } + __forceinline vint16 max(const vint16& a, int b) { return max(a,vint16(b)); } + __forceinline vint16 max(int a, const vint16& b) { return max(vint16(a),b); } + + __forceinline vint16 umin(const vint16& a, const vint16& b) { return _mm512_min_epu32(a, b); } + __forceinline vint16 umax(const vint16& a, const vint16& b) { return _mm512_max_epu32(a, b); } + + __forceinline vint16 mask_add(const vboolf16& mask, vint16& c, const vint16& a, const vint16& b) { return _mm512_mask_add_epi32(c,mask,a,b); } + __forceinline vint16 mask_sub(const vboolf16& mask, vint16& c, const vint16& a, const vint16& b) { return _mm512_mask_sub_epi32(c,mask,a,b); } + + __forceinline vint16 mask_and(const vboolf16& m, vint16& c, const vint16& a, const vint16& b) { return _mm512_mask_and_epi32(c,m,a,b); } + __forceinline vint16 mask_or (const vboolf16& m, vint16& c, const vint16& a, const vint16& b) { return _mm512_mask_or_epi32(c,m,a,b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint16& operator +=(vint16& a, const vint16& b) { return a = a + b; } + __forceinline vint16& operator +=(vint16& a, int b) { return a = a + b; } + + __forceinline vint16& operator -=(vint16& a, const vint16& b) { return a = a - b; } + __forceinline vint16& operator -=(vint16& a, int b) { return a = a - b; } + + __forceinline vint16& operator *=(vint16& a, const vint16& b) { return a = a * b; } + __forceinline vint16& operator *=(vint16& 
a, int b) { return a = a * b; } + + __forceinline vint16& operator &=(vint16& a, const vint16& b) { return a = a & b; } + __forceinline vint16& operator &=(vint16& a, int b) { return a = a & b; } + + __forceinline vint16& operator |=(vint16& a, const vint16& b) { return a = a | b; } + __forceinline vint16& operator |=(vint16& a, int b) { return a = a | b; } + + __forceinline vint16& operator <<=(vint16& a, int b) { return a = a << b; } + __forceinline vint16& operator >>=(vint16& a, int b) { return a = a >> b; } + + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 operator ==(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 operator ==(const vint16& a, int b) { return a == vint16(b); } + __forceinline vboolf16 operator ==(int a, const vint16& b) { return vint16(a) == b; } + + __forceinline vboolf16 operator !=(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 operator !=(const vint16& a, int b) { return a != vint16(b); } + __forceinline vboolf16 operator !=(int a, const vint16& b) { return vint16(a) != b; } + + __forceinline vboolf16 operator < (const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 operator < (const vint16& a, int b) { return a < vint16(b); } + __forceinline vboolf16 operator < (int a, const vint16& b) { return vint16(a) < b; } + + __forceinline vboolf16 operator >=(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 operator >=(const vint16& a, int b) { return a >= vint16(b); } + __forceinline vboolf16 operator >=(int a, const vint16& b) { return vint16(a) >= b; } + + __forceinline vboolf16 operator > (const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 operator > (const vint16& a, int b) { return a > vint16(b); } + __forceinline vboolf16 operator > (int a, const vint16& b) { return vint16(a) > b; } + + __forceinline vboolf16 operator <=(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 operator <=(const vint16& a, int b) { return a <= vint16(b); } + __forceinline vboolf16 operator <=(int a, const vint16& b) { return vint16(a) <= b; } + + __forceinline vboolf16 eq(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 ne(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 lt(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 ge(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 gt(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 le(const vint16& a, const vint16& b) { return _mm512_cmp_epi32_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 uint_le(const vint16& a, const vint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 uint_gt(const vint16& a, const vint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_GT); } + + __forceinline vboolf16 eq(const vboolf16 mask, const vint16& a, const vint16& b) { 
return _mm512_mask_cmp_epi32_mask(mask,a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 ne(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epi32_mask(mask,a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 lt(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epi32_mask(mask,a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 ge(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epi32_mask(mask,a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 gt(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epi32_mask(mask,a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 le(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epi32_mask(mask,a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 uint_le(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 uint_gt(const vboolf16 mask, const vint16& a, const vint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_GT); } + + + __forceinline vint16 select(const vboolf16& m, const vint16& t, const vint16& f) { + return _mm512_mask_or_epi32(f,m,t,t); + } + + __forceinline void xchg(const vboolf16& m, vint16& a, vint16& b) { + const vint16 c = a; a = select(m,b,a); b = select(m,c,b); + } + + __forceinline vboolf16 test(const vboolf16& m, const vint16& a, const vint16& b) { + return _mm512_mask_test_epi32_mask(m,a,b); + } + + __forceinline vboolf16 test(const vint16& a, const vint16& b) { + return _mm512_test_epi32_mask(a,b); + } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint16 unpacklo(const vint16& a, const vint16& b) { return _mm512_unpacklo_epi32(a, b); } + __forceinline vint16 unpackhi(const vint16& a, const vint16& b) { return _mm512_unpackhi_epi32(a, b); } + + template + __forceinline vint16 shuffle(const vint16& v) { + return _mm512_castps_si512(_mm512_permute_ps(_mm512_castsi512_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template + __forceinline vint16 shuffle(const vint16& v) { + return _mm512_castps_si512(_mm512_permute_ps(_mm512_castsi512_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template + __forceinline vint16 shuffle4(const vint16& v) { + return _mm512_castps_si512(_mm512_shuffle_f32x4(_mm512_castsi512_ps(v), _mm512_castsi512_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template + __forceinline vint16 shuffle4(const vint16& v) { + return _mm512_castps_si512(_mm512_shuffle_f32x4(_mm512_castsi512_ps(v), _mm512_castsi512_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template + __forceinline vint16 align_shift_right(const vint16& a, const vint16& b) { + return _mm512_alignr_epi32(a, b, i); + }; + + __forceinline int toScalar(const vint16& v) { + return _mm_cvtsi128_si32(_mm512_castsi512_si128(v)); + } + + template __forceinline vint16 insert4(const vint16& a, const vint4& b) { return _mm512_inserti32x4(a, b, i); } + + __forceinline size_t extract64bit(const vint16& v) { + return _mm_cvtsi128_si64(_mm512_castsi512_si128(v)); + } + + template + vint extractN(const vint16& v); + + template<> __forceinline vint4 extractN<4,0>(const vint16& v) { return _mm512_castsi512_si128(v); } + template<> __forceinline vint4 extractN<4,1>(const vint16& v) { return _mm512_extracti32x4_epi32(v, 1); } + template<> __forceinline vint4 extractN<4,2>(const vint16& v) { 
return _mm512_extracti32x4_epi32(v, 2); } + template<> __forceinline vint4 extractN<4,3>(const vint16& v) { return _mm512_extracti32x4_epi32(v, 3); } + + template<> __forceinline vint8 extractN<8,0>(const vint16& v) { return _mm512_castsi512_si256(v); } + template<> __forceinline vint8 extractN<8,1>(const vint16& v) { return _mm512_extracti32x8_epi32(v, 1); } + + template __forceinline vint4 extract4 (const vint16& v) { return _mm512_extracti32x4_epi32(v, i); } + template<> __forceinline vint4 extract4<0>(const vint16& v) { return _mm512_castsi512_si128(v); } + + template __forceinline vint8 extract8 (const vint16& v) { return _mm512_extracti32x8_epi32(v, i); } + template<> __forceinline vint8 extract8<0>(const vint16& v) { return _mm512_castsi512_si256(v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint16 vreduce_min2(vint16 x) { return min(x, shuffle<1,0,3,2>(x)); } + __forceinline vint16 vreduce_min4(vint16 x) { x = vreduce_min2(x); return min(x, shuffle<2,3,0,1>(x)); } + __forceinline vint16 vreduce_min8(vint16 x) { x = vreduce_min4(x); return min(x, shuffle4<1,0,3,2>(x)); } + __forceinline vint16 vreduce_min (vint16 x) { x = vreduce_min8(x); return min(x, shuffle4<2,3,0,1>(x)); } + + __forceinline vint16 vreduce_max2(vint16 x) { return max(x, shuffle<1,0,3,2>(x)); } + __forceinline vint16 vreduce_max4(vint16 x) { x = vreduce_max2(x); return max(x, shuffle<2,3,0,1>(x)); } + __forceinline vint16 vreduce_max8(vint16 x) { x = vreduce_max4(x); return max(x, shuffle4<1,0,3,2>(x)); } + __forceinline vint16 vreduce_max (vint16 x) { x = vreduce_max8(x); return max(x, shuffle4<2,3,0,1>(x)); } + + __forceinline vint16 vreduce_and2(vint16 x) { return x & shuffle<1,0,3,2>(x); } + __forceinline vint16 vreduce_and4(vint16 x) { x = vreduce_and2(x); return x & shuffle<2,3,0,1>(x); } + __forceinline vint16 vreduce_and8(vint16 x) { x = vreduce_and4(x); return x & shuffle4<1,0,3,2>(x); } + __forceinline vint16 vreduce_and (vint16 x) { x = vreduce_and8(x); return x & shuffle4<2,3,0,1>(x); } + + __forceinline vint16 vreduce_or2(vint16 x) { return x | shuffle<1,0,3,2>(x); } + __forceinline vint16 vreduce_or4(vint16 x) { x = vreduce_or2(x); return x | shuffle<2,3,0,1>(x); } + __forceinline vint16 vreduce_or8(vint16 x) { x = vreduce_or4(x); return x | shuffle4<1,0,3,2>(x); } + __forceinline vint16 vreduce_or (vint16 x) { x = vreduce_or8(x); return x | shuffle4<2,3,0,1>(x); } + + __forceinline vint16 vreduce_add2(vint16 x) { return x + shuffle<1,0,3,2>(x); } + __forceinline vint16 vreduce_add4(vint16 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); } + __forceinline vint16 vreduce_add8(vint16 x) { x = vreduce_add4(x); return x + shuffle4<1,0,3,2>(x); } + __forceinline vint16 vreduce_add (vint16 x) { x = vreduce_add8(x); return x + shuffle4<2,3,0,1>(x); } + + __forceinline int reduce_min(const vint16& v) { return toScalar(vreduce_min(v)); } + __forceinline int reduce_max(const vint16& v) { return toScalar(vreduce_max(v)); } + __forceinline int reduce_and(const vint16& v) { return toScalar(vreduce_and(v)); } + __forceinline int reduce_or (const vint16& v) { return toScalar(vreduce_or (v)); } + __forceinline int reduce_add(const vint16& v) { return toScalar(vreduce_add(v)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Memory load and store operations + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline vint16 conflict(const vint16& index) + { + return _mm512_conflict_epi32(index); + } + + __forceinline vint16 conflict(const vboolf16& mask, vint16& dest, const vint16& index) + { + return _mm512_mask_conflict_epi32(dest,mask,index); + } + + __forceinline vint16 convert_uint32_t(const __m512& f) { + return _mm512_cvtps_epu32(f); + } + + __forceinline vint16 permute(vint16 v, vint16 index) { + return _mm512_permutexvar_epi32(index,v); + } + + __forceinline vint16 reverse(const vint16 &a) { + return permute(a,vint16(reverse_step)); + } + + __forceinline vint16 prefix_sum(const vint16& a) + { + const vint16 z(zero); + vint16 v = a; + v = v + align_shift_right<16-1>(v,z); + v = v + align_shift_right<16-2>(v,z); + v = v + align_shift_right<16-4>(v,z); + v = v + align_shift_right<16-8>(v,z); + return v; + } + + __forceinline vint16 reverse_prefix_sum(const vint16& a) + { + const vint16 z(zero); + vint16 v = a; + v = v + align_shift_right<1>(z,v); + v = v + align_shift_right<2>(z,v); + v = v + align_shift_right<4>(z,v); + v = v + align_shift_right<8>(z,v); + return v; + } + + /* this should use a vbool8 and a vint8_64...*/ + template + __forceinline void gather_prefetch64(const void* base_addr, const vbool16& mask, const vint16& offset) + { +#if defined(__AVX512PF__) + _mm512_mask_prefetch_i64gather_pd(offset, mask, base_addr, scale, hint); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vint16& v) + { + cout << "<" << v[0]; + for (int i=1; i<16; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git a/thirdparty/embree/common/simd/vint4_sse2.h b/thirdparty/embree/common/simd/vint4_sse2.h new file mode 100644 index 000000000000..458f8cfaa670 --- /dev/null +++ b/thirdparty/embree/common/simd/vint4_sse2.h @@ -0,0 +1,583 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../math/math.h" + +namespace embree +{ + /* 4-wide SSE integer type */ + template<> + struct vint<4> + { + ALIGNED_STRUCT_(16); + + typedef vboolf4 Bool; + typedef vint4 Int; + typedef vfloat4 Float; + + enum { size = 4 }; // number of SIMD elements + union { __m128i v; int i[4]; }; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint() {} + __forceinline vint(const vint4& a) { v = a.v; } + __forceinline vint4& operator =(const vint4& a) { v = a.v; return *this; } + + __forceinline vint(__m128i a) : v(a) {} + __forceinline operator const __m128i&() const { return v; } + __forceinline operator __m128i&() { return v; } + + __forceinline vint(int a) : v(_mm_set1_epi32(a)) {} + __forceinline vint(int a, int b, int c, int d) : v(_mm_set_epi32(d, c, b, a)) {} + + __forceinline explicit vint(__m128 a) : v(_mm_cvtps_epi32(a)) {} +#if defined(__AVX512VL__) + __forceinline explicit vint(const vboolf4& a) : v(_mm_movm_epi32(a)) {} +#else + __forceinline explicit vint(const vboolf4& a) : v(_mm_castps_si128((__m128)a)) {} +#endif + + __forceinline vint(long long a, long long b) : v(_mm_set_epi64x(b,a)) {} + + 
//////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint(ZeroTy) : v(_mm_setzero_si128()) {} + __forceinline vint(OneTy) : v(_mm_set_epi32(1, 1, 1, 1)) {} + __forceinline vint(PosInfTy) : v(_mm_set_epi32(pos_inf, pos_inf, pos_inf, pos_inf)) {} + __forceinline vint(NegInfTy) : v(_mm_set_epi32(neg_inf, neg_inf, neg_inf, neg_inf)) {} + __forceinline vint(StepTy) : v(_mm_set_epi32(3, 2, 1, 0)) {} + __forceinline vint(ReverseStepTy) : v(_mm_set_epi32(0, 1, 2, 3)) {} + + __forceinline vint(TrueTy) { v = _mm_cmpeq_epi32(v,v); } + __forceinline vint(UndefinedTy) : v(_mm_castps_si128(_mm_undefined_ps())) {} + + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vint4 load (const void* a) { return _mm_load_si128((__m128i*)a); } + static __forceinline vint4 loadu(const void* a) { return _mm_loadu_si128((__m128i*)a); } + + static __forceinline void store (void* ptr, const vint4& v) { _mm_store_si128((__m128i*)ptr,v); } + static __forceinline void storeu(void* ptr, const vint4& v) { _mm_storeu_si128((__m128i*)ptr,v); } + +#if defined(__AVX512VL__) + + static __forceinline vint4 compact(const vboolf4& mask, vint4 &v) { + return _mm_mask_compress_epi32(v, mask, v); + } + static __forceinline vint4 compact(const vboolf4& mask, vint4 &a, const vint4& b) { + return _mm_mask_compress_epi32(a, mask, b); + } + + static __forceinline vint4 load (const vboolf4& mask, const void* ptr) { return _mm_mask_load_epi32 (_mm_setzero_si128(),mask,ptr); } + static __forceinline vint4 loadu(const vboolf4& mask, const void* ptr) { return _mm_mask_loadu_epi32(_mm_setzero_si128(),mask,ptr); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vint4& v) { _mm_mask_store_epi32 (ptr,mask,v); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vint4& v) { _mm_mask_storeu_epi32(ptr,mask,v); } +#elif defined(__AVX__) + static __forceinline vint4 load (const vbool4& mask, const void* a) { return _mm_castps_si128(_mm_maskload_ps((float*)a,mask)); } + static __forceinline vint4 loadu(const vbool4& mask, const void* a) { return _mm_castps_si128(_mm_maskload_ps((float*)a,mask)); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vint4& i) { _mm_maskstore_ps((float*)ptr,(__m128i)mask,_mm_castsi128_ps(i)); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vint4& i) { _mm_maskstore_ps((float*)ptr,(__m128i)mask,_mm_castsi128_ps(i)); } +#else + static __forceinline vint4 load (const vbool4& mask, const void* a) { return _mm_and_si128(_mm_load_si128 ((__m128i*)a),mask); } + static __forceinline vint4 loadu(const vbool4& mask, const void* a) { return _mm_and_si128(_mm_loadu_si128((__m128i*)a),mask); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vint4& i) { store (ptr,select(mask,i,load (ptr))); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vint4& i) { storeu(ptr,select(mask,i,loadu(ptr))); } +#endif + + +#if defined(__SSE4_1__) + static __forceinline vint4 load(const unsigned char* ptr) { + return _mm_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); + } + + static __forceinline vint4 loadu(const unsigned char* ptr) { + return _mm_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); + } +#else 
+ + static __forceinline vint4 load(const unsigned char* ptr) { + return vint4(ptr[0],ptr[1],ptr[2],ptr[3]); + } + + static __forceinline vint4 loadu(const unsigned char* ptr) { + return vint4(ptr[0],ptr[1],ptr[2],ptr[3]); + } + +#endif + + static __forceinline vint4 load(const unsigned short* ptr) { +#if defined (__SSE4_1__) + return _mm_cvtepu16_epi32(_mm_loadu_si128((__m128i*)ptr)); +#else + return vint4(ptr[0],ptr[1],ptr[2],ptr[3]); +#endif + } + + static __forceinline void store(unsigned char* ptr, const vint4& v) { +#if defined(__SSE4_1__) + __m128i x = v; + x = _mm_packus_epi32(x, x); + x = _mm_packus_epi16(x, x); + *(int*)ptr = _mm_cvtsi128_si32(x); +#else + for (size_t i=0;i<4;i++) + ptr[i] = (unsigned char)v[i]; +#endif + } + + static __forceinline void store(unsigned short* ptr, const vint4& v) { + for (size_t i=0;i<4;i++) + ptr[i] = (unsigned short)v[i]; + } + + static __forceinline vint4 load_nt(void* ptr) { +#if defined(__SSE4_1__) + return _mm_stream_load_si128((__m128i*)ptr); +#else + return _mm_load_si128((__m128i*)ptr); +#endif + } + + static __forceinline void store_nt(void* ptr, const vint4& v) { +#if defined(__SSE4_1__) + _mm_stream_ps((float*)ptr, _mm_castsi128_ps(v)); +#else + _mm_store_si128((__m128i*)ptr,v); +#endif + } + + template + static __forceinline vint4 gather(const int* ptr, const vint4& index) { +#if defined(__AVX2__) + return _mm_i32gather_epi32(ptr, index, scale); +#else + return vint4( + *(int*)(((char*)ptr)+scale*index[0]), + *(int*)(((char*)ptr)+scale*index[1]), + *(int*)(((char*)ptr)+scale*index[2]), + *(int*)(((char*)ptr)+scale*index[3])); +#endif + } + + template + static __forceinline vint4 gather(const vboolf4& mask, const int* ptr, const vint4& index) { + vint4 r = zero; +#if defined(__AVX512VL__) + return _mm_mmask_i32gather_epi32(r, mask, index, ptr, scale); +#elif defined(__AVX2__) + return _mm_mask_i32gather_epi32(r, ptr, index, mask, scale); +#else + if (likely(mask[0])) r[0] = *(int*)(((char*)ptr)+scale*index[0]); + if (likely(mask[1])) r[1] = *(int*)(((char*)ptr)+scale*index[1]); + if (likely(mask[2])) r[2] = *(int*)(((char*)ptr)+scale*index[2]); + if (likely(mask[3])) r[3] = *(int*)(((char*)ptr)+scale*index[3]); + return r; +#endif + } + + template + static __forceinline void scatter(void* ptr, const vint4& index, const vint4& v) + { +#if defined(__AVX512VL__) + _mm_i32scatter_epi32((int*)ptr, index, v, scale); +#else + *(int*)(((char*)ptr)+scale*index[0]) = v[0]; + *(int*)(((char*)ptr)+scale*index[1]) = v[1]; + *(int*)(((char*)ptr)+scale*index[2]) = v[2]; + *(int*)(((char*)ptr)+scale*index[3]) = v[3]; +#endif + } + + template + static __forceinline void scatter(const vboolf4& mask, void* ptr, const vint4& index, const vint4& v) + { +#if defined(__AVX512VL__) + _mm_mask_i32scatter_epi32((int*)ptr, mask, index, v, scale); +#else + if (likely(mask[0])) *(int*)(((char*)ptr)+scale*index[0]) = v[0]; + if (likely(mask[1])) *(int*)(((char*)ptr)+scale*index[1]) = v[1]; + if (likely(mask[2])) *(int*)(((char*)ptr)+scale*index[2]) = v[2]; + if (likely(mask[3])) *(int*)(((char*)ptr)+scale*index[3]) = v[3]; +#endif + } + +#if defined(__x86_64__) + static __forceinline vint4 broadcast64(long long a) { return _mm_set1_epi64x(a); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const int& operator [](size_t index) const { assert(index < 4); return i[index]; } + __forceinline int& operator 
[](size_t index) { assert(index < 4); return i[index]; } + + friend __forceinline vint4 select(const vboolf4& m, const vint4& t, const vint4& f) { +#if defined(__AVX512VL__) + return _mm_mask_blend_epi32(m, (__m128i)f, (__m128i)t); +#elif defined(__SSE4_1__) + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(f), _mm_castsi128_ps(t), m)); +#else + return _mm_or_si128(_mm_and_si128(m, t), _mm_andnot_si128(m, f)); +#endif + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf4 asBool(const vint4& a) { return _mm_movepi32_mask(a); } +#else + __forceinline vboolf4 asBool(const vint4& a) { return _mm_castsi128_ps(a); } +#endif + + __forceinline vint4 operator +(const vint4& a) { return a; } + __forceinline vint4 operator -(const vint4& a) { return _mm_sub_epi32(_mm_setzero_si128(), a); } +#if defined(__SSSE3__) + __forceinline vint4 abs(const vint4& a) { return _mm_abs_epi32(a); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint4 operator +(const vint4& a, const vint4& b) { return _mm_add_epi32(a, b); } + __forceinline vint4 operator +(const vint4& a, int b) { return a + vint4(b); } + __forceinline vint4 operator +(int a, const vint4& b) { return vint4(a) + b; } + + __forceinline vint4 operator -(const vint4& a, const vint4& b) { return _mm_sub_epi32(a, b); } + __forceinline vint4 operator -(const vint4& a, int b) { return a - vint4(b); } + __forceinline vint4 operator -(int a, const vint4& b) { return vint4(a) - b; } + +#if defined(__SSE4_1__) + __forceinline vint4 operator *(const vint4& a, const vint4& b) { return _mm_mullo_epi32(a, b); } +#else + __forceinline vint4 operator *(const vint4& a, const vint4& b) { return vint4(a[0]*b[0],a[1]*b[1],a[2]*b[2],a[3]*b[3]); } +#endif + __forceinline vint4 operator *(const vint4& a, int b) { return a * vint4(b); } + __forceinline vint4 operator *(int a, const vint4& b) { return vint4(a) * b; } + + __forceinline vint4 operator &(const vint4& a, const vint4& b) { return _mm_and_si128(a, b); } + __forceinline vint4 operator &(const vint4& a, int b) { return a & vint4(b); } + __forceinline vint4 operator &(int a, const vint4& b) { return vint4(a) & b; } + + __forceinline vint4 operator |(const vint4& a, const vint4& b) { return _mm_or_si128(a, b); } + __forceinline vint4 operator |(const vint4& a, int b) { return a | vint4(b); } + __forceinline vint4 operator |(int a, const vint4& b) { return vint4(a) | b; } + + __forceinline vint4 operator ^(const vint4& a, const vint4& b) { return _mm_xor_si128(a, b); } + __forceinline vint4 operator ^(const vint4& a, int b) { return a ^ vint4(b); } + __forceinline vint4 operator ^(int a, const vint4& b) { return vint4(a) ^ b; } + + __forceinline vint4 operator <<(const vint4& a, int n) { return _mm_slli_epi32(a, n); } + __forceinline vint4 operator >>(const vint4& a, int n) { return _mm_srai_epi32(a, n); } + + __forceinline vint4 sll (const vint4& a, int b) { return _mm_slli_epi32(a, b); } + __forceinline vint4 sra (const vint4& a, int b) { return _mm_srai_epi32(a, b); } + __forceinline vint4 srl (const vint4& a, int b) { return _mm_srli_epi32(a, b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment 
Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint4& operator +=(vint4& a, const vint4& b) { return a = a + b; } + __forceinline vint4& operator +=(vint4& a, int b) { return a = a + b; } + + __forceinline vint4& operator -=(vint4& a, const vint4& b) { return a = a - b; } + __forceinline vint4& operator -=(vint4& a, int b) { return a = a - b; } + +#if defined(__SSE4_1__) + __forceinline vint4& operator *=(vint4& a, const vint4& b) { return a = a * b; } + __forceinline vint4& operator *=(vint4& a, int b) { return a = a * b; } +#endif + + __forceinline vint4& operator &=(vint4& a, const vint4& b) { return a = a & b; } + __forceinline vint4& operator &=(vint4& a, int b) { return a = a & b; } + + __forceinline vint4& operator |=(vint4& a, const vint4& b) { return a = a | b; } + __forceinline vint4& operator |=(vint4& a, int b) { return a = a | b; } + + __forceinline vint4& operator <<=(vint4& a, int b) { return a = a << b; } + __forceinline vint4& operator >>=(vint4& a, int b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf4 operator ==(const vint4& a, const vint4& b) { return _mm_cmp_epi32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf4 operator !=(const vint4& a, const vint4& b) { return _mm_cmp_epi32_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf4 operator < (const vint4& a, const vint4& b) { return _mm_cmp_epi32_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf4 operator >=(const vint4& a, const vint4& b) { return _mm_cmp_epi32_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf4 operator > (const vint4& a, const vint4& b) { return _mm_cmp_epi32_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf4 operator <=(const vint4& a, const vint4& b) { return _mm_cmp_epi32_mask(a,b,_MM_CMPINT_LE); } +#else + __forceinline vboolf4 operator ==(const vint4& a, const vint4& b) { return _mm_castsi128_ps(_mm_cmpeq_epi32(a, b)); } + __forceinline vboolf4 operator !=(const vint4& a, const vint4& b) { return !(a == b); } + __forceinline vboolf4 operator < (const vint4& a, const vint4& b) { return _mm_castsi128_ps(_mm_cmplt_epi32(a, b)); } + __forceinline vboolf4 operator >=(const vint4& a, const vint4& b) { return !(a < b); } + __forceinline vboolf4 operator > (const vint4& a, const vint4& b) { return _mm_castsi128_ps(_mm_cmpgt_epi32(a, b)); } + __forceinline vboolf4 operator <=(const vint4& a, const vint4& b) { return !(a > b); } +#endif + + __forceinline vboolf4 operator ==(const vint4& a, int b) { return a == vint4(b); } + __forceinline vboolf4 operator ==(int a, const vint4& b) { return vint4(a) == b; } + + __forceinline vboolf4 operator !=(const vint4& a, int b) { return a != vint4(b); } + __forceinline vboolf4 operator !=(int a, const vint4& b) { return vint4(a) != b; } + + __forceinline vboolf4 operator < (const vint4& a, int b) { return a < vint4(b); } + __forceinline vboolf4 operator < (int a, const vint4& b) { return vint4(a) < b; } + + __forceinline vboolf4 operator >=(const vint4& a, int b) { return a >= vint4(b); } + __forceinline vboolf4 operator >=(int a, const vint4& b) { return vint4(a) >= b; } + + __forceinline vboolf4 operator > (const vint4& a, int b) { return a > vint4(b); } + __forceinline vboolf4 operator > (int a, const vint4& b) { return vint4(a) > b; } + + __forceinline vboolf4 operator 
<=(const vint4& a, int b) { return a <= vint4(b); } + __forceinline vboolf4 operator <=(int a, const vint4& b) { return vint4(a) <= b; } + + __forceinline vboolf4 eq(const vint4& a, const vint4& b) { return a == b; } + __forceinline vboolf4 ne(const vint4& a, const vint4& b) { return a != b; } + __forceinline vboolf4 lt(const vint4& a, const vint4& b) { return a < b; } + __forceinline vboolf4 ge(const vint4& a, const vint4& b) { return a >= b; } + __forceinline vboolf4 gt(const vint4& a, const vint4& b) { return a > b; } + __forceinline vboolf4 le(const vint4& a, const vint4& b) { return a <= b; } + +#if defined(__AVX512VL__) + __forceinline vboolf4 eq(const vboolf4& mask, const vint4& a, const vint4& b) { return _mm_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_EQ); } + __forceinline vboolf4 ne(const vboolf4& mask, const vint4& a, const vint4& b) { return _mm_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_NE); } + __forceinline vboolf4 lt(const vboolf4& mask, const vint4& a, const vint4& b) { return _mm_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_LT); } + __forceinline vboolf4 ge(const vboolf4& mask, const vint4& a, const vint4& b) { return _mm_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_GE); } + __forceinline vboolf4 gt(const vboolf4& mask, const vint4& a, const vint4& b) { return _mm_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_GT); } + __forceinline vboolf4 le(const vboolf4& mask, const vint4& a, const vint4& b) { return _mm_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_LE); } +#else + __forceinline vboolf4 eq(const vboolf4& mask, const vint4& a, const vint4& b) { return mask & (a == b); } + __forceinline vboolf4 ne(const vboolf4& mask, const vint4& a, const vint4& b) { return mask & (a != b); } + __forceinline vboolf4 lt(const vboolf4& mask, const vint4& a, const vint4& b) { return mask & (a < b); } + __forceinline vboolf4 ge(const vboolf4& mask, const vint4& a, const vint4& b) { return mask & (a >= b); } + __forceinline vboolf4 gt(const vboolf4& mask, const vint4& a, const vint4& b) { return mask & (a > b); } + __forceinline vboolf4 le(const vboolf4& mask, const vint4& a, const vint4& b) { return mask & (a <= b); } +#endif + + template + __forceinline vint4 select(const vint4& t, const vint4& f) { +#if defined(__SSE4_1__) + return _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(f), _mm_castsi128_ps(t), mask)); +#else + return select(vboolf4(mask), t, f); +#endif + } + +#if defined(__SSE4_1__) + __forceinline vint4 min(const vint4& a, const vint4& b) { return _mm_min_epi32(a, b); } + __forceinline vint4 max(const vint4& a, const vint4& b) { return _mm_max_epi32(a, b); } + + __forceinline vint4 umin(const vint4& a, const vint4& b) { return _mm_min_epu32(a, b); } + __forceinline vint4 umax(const vint4& a, const vint4& b) { return _mm_max_epu32(a, b); } + +#else + __forceinline vint4 min(const vint4& a, const vint4& b) { return select(a < b,a,b); } + __forceinline vint4 max(const vint4& a, const vint4& b) { return select(a < b,b,a); } +#endif + + __forceinline vint4 min(const vint4& a, int b) { return min(a,vint4(b)); } + __forceinline vint4 min(int a, const vint4& b) { return min(vint4(a),b); } + __forceinline vint4 max(const vint4& a, int b) { return max(a,vint4(b)); } + __forceinline vint4 max(int a, const vint4& b) { return max(vint4(a),b); } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint4 unpacklo(const vint4& a, 
const vint4& b) { return _mm_castps_si128(_mm_unpacklo_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b))); } + __forceinline vint4 unpackhi(const vint4& a, const vint4& b) { return _mm_castps_si128(_mm_unpackhi_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b))); } + + template + __forceinline vint4 shuffle(const vint4& v) { + return _mm_shuffle_epi32(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template + __forceinline vint4 shuffle(const vint4& a, const vint4& b) { + return _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), _MM_SHUFFLE(i3, i2, i1, i0))); + } + +#if defined(__SSE3__) + template<> __forceinline vint4 shuffle<0, 0, 2, 2>(const vint4& v) { return _mm_castps_si128(_mm_moveldup_ps(_mm_castsi128_ps(v))); } + template<> __forceinline vint4 shuffle<1, 1, 3, 3>(const vint4& v) { return _mm_castps_si128(_mm_movehdup_ps(_mm_castsi128_ps(v))); } + template<> __forceinline vint4 shuffle<0, 1, 0, 1>(const vint4& v) { return _mm_castpd_si128(_mm_movedup_pd (_mm_castsi128_pd(v))); } +#endif + + template + __forceinline vint4 shuffle(const vint4& v) { + return shuffle(v); + } + +#if defined(__SSE4_1__) + template __forceinline int extract(const vint4& b) { return _mm_extract_epi32(b, src); } + template __forceinline vint4 insert(const vint4& a, const int b) { return _mm_insert_epi32(a, b, dst); } +#else + template __forceinline int extract(const vint4& b) { return b[src&3]; } + template __forceinline vint4 insert(const vint4& a, int b) { vint4 c = a; c[dst&3] = b; return c; } +#endif + + + template<> __forceinline int extract<0>(const vint4& b) { return _mm_cvtsi128_si32(b); } + + __forceinline int toScalar(const vint4& v) { return _mm_cvtsi128_si32(v); } + + __forceinline size_t toSizeT(const vint4& v) { +#if defined(__WIN32__) && !defined(__X86_64__) // win32 workaround + return toScalar(v); +#else + return _mm_cvtsi128_si64(v); +#endif + } + +#if defined(__AVX512VL__) + + __forceinline vint4 permute(const vint4 &a, const vint4 &index) { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(a),index)); + } + + template + __forceinline vint4 align_shift_right(const vint4& a, const vint4& b) { + return _mm_alignr_epi32(a, b, i); + } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__SSE4_1__) + __forceinline vint4 vreduce_min(const vint4& v) { vint4 h = min(shuffle<1,0,3,2>(v),v); return min(shuffle<2,3,0,1>(h),h); } + __forceinline vint4 vreduce_max(const vint4& v) { vint4 h = max(shuffle<1,0,3,2>(v),v); return max(shuffle<2,3,0,1>(h),h); } + __forceinline vint4 vreduce_add(const vint4& v) { vint4 h = shuffle<1,0,3,2>(v) + v ; return shuffle<2,3,0,1>(h) + h ; } + + __forceinline int reduce_min(const vint4& v) { return toScalar(vreduce_min(v)); } + __forceinline int reduce_max(const vint4& v) { return toScalar(vreduce_max(v)); } + __forceinline int reduce_add(const vint4& v) { return toScalar(vreduce_add(v)); } + + __forceinline size_t select_min(const vint4& v) { return bsf(movemask(v == vreduce_min(v))); } + __forceinline size_t select_max(const vint4& v) { return bsf(movemask(v == vreduce_max(v))); } + + __forceinline size_t select_min(const vboolf4& valid, const vint4& v) { const vint4 a = select(valid,v,vint4(pos_inf)); return bsf(movemask(valid & (a == vreduce_min(a)))); } + __forceinline size_t select_max(const vboolf4& valid, const vint4& v) { const vint4 a = select(valid,v,vint4(neg_inf)); return 
bsf(movemask(valid & (a == vreduce_max(a)))); } + +#else + + __forceinline int reduce_min(const vint4& v) { return min(v[0],v[1],v[2],v[3]); } + __forceinline int reduce_max(const vint4& v) { return max(v[0],v[1],v[2],v[3]); } + __forceinline int reduce_add(const vint4& v) { return v[0]+v[1]+v[2]+v[3]; } + +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Sorting networks + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__SSE4_1__) + + __forceinline vint4 usort_ascending(const vint4& v) + { + const vint4 a0 = v; + const vint4 b0 = shuffle<1,0,3,2>(a0); + const vint4 c0 = umin(a0,b0); + const vint4 d0 = umax(a0,b0); + const vint4 a1 = select<0x5 /* 0b0101 */>(c0,d0); + const vint4 b1 = shuffle<2,3,0,1>(a1); + const vint4 c1 = umin(a1,b1); + const vint4 d1 = umax(a1,b1); + const vint4 a2 = select<0x3 /* 0b0011 */>(c1,d1); + const vint4 b2 = shuffle<0,2,1,3>(a2); + const vint4 c2 = umin(a2,b2); + const vint4 d2 = umax(a2,b2); + const vint4 a3 = select<0x2 /* 0b0010 */>(c2,d2); + return a3; + } + + __forceinline vint4 usort_descending(const vint4& v) + { + const vint4 a0 = v; + const vint4 b0 = shuffle<1,0,3,2>(a0); + const vint4 c0 = umax(a0,b0); + const vint4 d0 = umin(a0,b0); + const vint4 a1 = select<0x5 /* 0b0101 */>(c0,d0); + const vint4 b1 = shuffle<2,3,0,1>(a1); + const vint4 c1 = umax(a1,b1); + const vint4 d1 = umin(a1,b1); + const vint4 a2 = select<0x3 /* 0b0011 */>(c1,d1); + const vint4 b2 = shuffle<0,2,1,3>(a2); + const vint4 c2 = umax(a2,b2); + const vint4 d2 = umin(a2,b2); + const vint4 a3 = select<0x2 /* 0b0010 */>(c2,d2); + return a3; + } + +#else + + __forceinline vint4 usort_ascending(const vint4& v) + { + const vint4 a0 = v-vint4(0x80000000); + const vint4 b0 = shuffle<1,0,3,2>(a0); + const vint4 c0 = min(a0,b0); + const vint4 d0 = max(a0,b0); + const vint4 a1 = select<0x5 /* 0b0101 */>(c0,d0); + const vint4 b1 = shuffle<2,3,0,1>(a1); + const vint4 c1 = min(a1,b1); + const vint4 d1 = max(a1,b1); + const vint4 a2 = select<0x3 /* 0b0011 */>(c1,d1); + const vint4 b2 = shuffle<0,2,1,3>(a2); + const vint4 c2 = min(a2,b2); + const vint4 d2 = max(a2,b2); + const vint4 a3 = select<0x2 /* 0b0010 */>(c2,d2); + return a3+vint4(0x80000000); + } + + __forceinline vint4 usort_descending(const vint4& v) + { + const vint4 a0 = v-vint4(0x80000000); + const vint4 b0 = shuffle<1,0,3,2>(a0); + const vint4 c0 = max(a0,b0); + const vint4 d0 = min(a0,b0); + const vint4 a1 = select<0x5 /* 0b0101 */>(c0,d0); + const vint4 b1 = shuffle<2,3,0,1>(a1); + const vint4 c1 = max(a1,b1); + const vint4 d1 = min(a1,b1); + const vint4 a2 = select<0x3 /* 0b0011 */>(c1,d1); + const vint4 b2 = shuffle<0,2,1,3>(a2); + const vint4 c2 = max(a2,b2); + const vint4 d2 = min(a2,b2); + const vint4 a3 = select<0x2 /* 0b0010 */>(c2,d2); + return a3+vint4(0x80000000); + } + +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vint4& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ">"; + } +} + diff --git a/thirdparty/embree/common/simd/vint8_avx.h b/thirdparty/embree/common/simd/vint8_avx.h new file mode 100644 index 000000000000..c373907e9c32 --- /dev/null +++ b/thirdparty/embree/common/simd/vint8_avx.h @@ -0,0 +1,459 @@ +// Copyright 2009-2020 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX integer type */ + template<> + struct vint<8> + { + ALIGNED_STRUCT_(32); + + typedef vboolf8 Bool; + typedef vint8 Int; + typedef vfloat8 Float; + + enum { size = 8 }; // number of SIMD elements + union { // data + __m256i v; + struct { __m128i vl,vh; }; + int i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint() {} + __forceinline vint(const vint8& a) { v = a.v; } + __forceinline vint8& operator =(const vint8& a) { v = a.v; return *this; } + + __forceinline vint(__m256i a) : v(a) {} + __forceinline operator const __m256i&() const { return v; } + __forceinline operator __m256i&() { return v; } + + __forceinline explicit vint(const vint4& a) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),a,1)) {} + __forceinline vint(const vint4& a, const vint4& b) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),b,1)) {} + __forceinline vint(const __m128i& a, const __m128i& b) : vl(a), vh(b) {} + + __forceinline explicit vint(const int* a) : v(_mm256_castps_si256(_mm256_loadu_ps((const float*)a))) {} + __forceinline vint(int a) : v(_mm256_set1_epi32(a)) {} + __forceinline vint(int a, int b) : v(_mm256_set_epi32(b, a, b, a, b, a, b, a)) {} + __forceinline vint(int a, int b, int c, int d) : v(_mm256_set_epi32(d, c, b, a, d, c, b, a)) {} + __forceinline vint(int a, int b, int c, int d, int e, int f, int g, int vh) : v(_mm256_set_epi32(vh, g, f, e, d, c, b, a)) {} + + __forceinline explicit vint(__m256 a) : v(_mm256_cvtps_epi32(a)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint(ZeroTy) : v(_mm256_setzero_si256()) {} + __forceinline vint(OneTy) : v(_mm256_set_epi32(1,1,1,1,1,1,1,1)) {} + __forceinline vint(PosInfTy) : v(_mm256_set_epi32(pos_inf,pos_inf,pos_inf,pos_inf,pos_inf,pos_inf,pos_inf,pos_inf)) {} + __forceinline vint(NegInfTy) : v(_mm256_set_epi32(neg_inf,neg_inf,neg_inf,neg_inf,neg_inf,neg_inf,neg_inf,neg_inf)) {} + __forceinline vint(StepTy) : v(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)) {} + __forceinline vint(ReverseStepTy) : v(_mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7)) {} + __forceinline vint(UndefinedTy) : v(_mm256_undefined_si256()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vint8 load (const void* a) { return _mm256_castps_si256(_mm256_load_ps((float*)a)); } + static __forceinline vint8 loadu(const void* a) { return _mm256_castps_si256(_mm256_loadu_ps((float*)a)); } + + static __forceinline vint8 load (const vboolf8& mask, const void* a) { return _mm256_castps_si256(_mm256_maskload_ps((float*)a,mask)); } + static __forceinline vint8 loadu(const vboolf8& mask, const void* a) { return _mm256_castps_si256(_mm256_maskload_ps((float*)a,mask)); } + + static __forceinline void store (void* ptr, const vint8& f) { _mm256_store_ps((float*)ptr,_mm256_castsi256_ps(f)); } + static __forceinline void storeu(void* ptr, const vint8& f) { _mm256_storeu_ps((float*)ptr,_mm256_castsi256_ps(f)); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vint8& f) { 
_mm256_maskstore_ps((float*)ptr,(__m256i)mask,_mm256_castsi256_ps(f)); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vint8& f) { _mm256_maskstore_ps((float*)ptr,(__m256i)mask,_mm256_castsi256_ps(f)); } + + static __forceinline void store_nt(void* ptr, const vint8& v) { + _mm256_stream_ps((float*)ptr,_mm256_castsi256_ps(v)); + } + + static __forceinline vint8 load(const unsigned char* ptr) { + vint4 il = vint4::load(ptr+0); + vint4 ih = vint4::load(ptr+4); + return vint8(il,ih); + } + + static __forceinline vint8 loadu(const unsigned char* ptr) { + vint4 il = vint4::loadu(ptr+0); + vint4 ih = vint4::loadu(ptr+4); + return vint8(il,ih); + } + + static __forceinline vint8 load(const unsigned short* ptr) { + vint4 il = vint4::load(ptr+0); + vint4 ih = vint4::load(ptr+4); + return vint8(il,ih); + } + + static __forceinline vint8 loadu(const unsigned short* ptr) { + vint4 il = vint4::loadu(ptr+0); + vint4 ih = vint4::loadu(ptr+4); + return vint8(il,ih); + } + + static __forceinline void store(unsigned char* ptr, const vint8& i) { + vint4 il(i.vl); + vint4 ih(i.vh); + vint4::store(ptr + 0,il); + vint4::store(ptr + 4,ih); + } + + static __forceinline void store(unsigned short* ptr, const vint8& v) { + for (size_t i=0;i<8;i++) + ptr[i] = (unsigned short)v[i]; + } + + template + static __forceinline vint8 gather(const int* ptr, const vint8& index) { + return vint8( + *(int*)(((char*)ptr)+scale*index[0]), + *(int*)(((char*)ptr)+scale*index[1]), + *(int*)(((char*)ptr)+scale*index[2]), + *(int*)(((char*)ptr)+scale*index[3]), + *(int*)(((char*)ptr)+scale*index[4]), + *(int*)(((char*)ptr)+scale*index[5]), + *(int*)(((char*)ptr)+scale*index[6]), + *(int*)(((char*)ptr)+scale*index[7])); + } + + template + static __forceinline vint8 gather(const vboolf8& mask, const int* ptr, const vint8& index) { + vint8 r = zero; + if (likely(mask[0])) r[0] = *(int*)(((char*)ptr)+scale*index[0]); + if (likely(mask[1])) r[1] = *(int*)(((char*)ptr)+scale*index[1]); + if (likely(mask[2])) r[2] = *(int*)(((char*)ptr)+scale*index[2]); + if (likely(mask[3])) r[3] = *(int*)(((char*)ptr)+scale*index[3]); + if (likely(mask[4])) r[4] = *(int*)(((char*)ptr)+scale*index[4]); + if (likely(mask[5])) r[5] = *(int*)(((char*)ptr)+scale*index[5]); + if (likely(mask[6])) r[6] = *(int*)(((char*)ptr)+scale*index[6]); + if (likely(mask[7])) r[7] = *(int*)(((char*)ptr)+scale*index[7]); + return r; + } + + template + static __forceinline void scatter(void* ptr, const vint8& ofs, const vint8& v) + { + *(int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + *(int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + *(int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + *(int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + *(int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + *(int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + *(int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + *(int*)(((char*)ptr)+scale*ofs[7]) = v[7]; + } + + template + static __forceinline void scatter(const vboolf8& mask, void* ptr, const vint8& ofs, const vint8& v) + { + if (likely(mask[0])) *(int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + if (likely(mask[1])) *(int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + if (likely(mask[2])) *(int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + if (likely(mask[3])) *(int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + if (likely(mask[4])) *(int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + if (likely(mask[5])) *(int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + if (likely(mask[6])) *(int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + if (likely(mask[7])) *(int*)(((char*)ptr)+scale*ofs[7]) = v[7]; + } + + + static 
__forceinline vint8 broadcast64(const long long& a) { return _mm256_set1_epi64x(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const int& operator [](size_t index) const { assert(index < 8); return i[index]; } + __forceinline int& operator [](size_t index) { assert(index < 8); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 asBool(const vint8& a) { return _mm256_castsi256_ps(a); } + + __forceinline vint8 operator +(const vint8& a) { return a; } + __forceinline vint8 operator -(const vint8& a) { return vint8(_mm_sub_epi32(_mm_setzero_si128(), a.vl), _mm_sub_epi32(_mm_setzero_si128(), a.vh)); } + __forceinline vint8 abs (const vint8& a) { return vint8(_mm_abs_epi32(a.vl), _mm_abs_epi32(a.vh)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 operator +(const vint8& a, const vint8& b) { return vint8(_mm_add_epi32(a.vl, b.vl), _mm_add_epi32(a.vh, b.vh)); } + __forceinline vint8 operator +(const vint8& a, int b) { return a + vint8(b); } + __forceinline vint8 operator +(int a, const vint8& b) { return vint8(a) + b; } + + __forceinline vint8 operator -(const vint8& a, const vint8& b) { return vint8(_mm_sub_epi32(a.vl, b.vl), _mm_sub_epi32(a.vh, b.vh)); } + __forceinline vint8 operator -(const vint8& a, int b) { return a - vint8(b); } + __forceinline vint8 operator -(int a, const vint8& b) { return vint8(a) - b; } + + __forceinline vint8 operator *(const vint8& a, const vint8& b) { return vint8(_mm_mullo_epi32(a.vl, b.vl), _mm_mullo_epi32(a.vh, b.vh)); } + __forceinline vint8 operator *(const vint8& a, int b) { return a * vint8(b); } + __forceinline vint8 operator *(int a, const vint8& b) { return vint8(a) * b; } + + __forceinline vint8 operator &(const vint8& a, const vint8& b) { return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vint8 operator &(const vint8& a, int b) { return a & vint8(b); } + __forceinline vint8 operator &(int a, const vint8& b) { return vint8(a) & b; } + + __forceinline vint8 operator |(const vint8& a, const vint8& b) { return _mm256_castps_si256(_mm256_or_ps (_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vint8 operator |(const vint8& a, int b) { return a | vint8(b); } + __forceinline vint8 operator |(int a, const vint8& b) { return vint8(a) | b; } + + __forceinline vint8 operator ^(const vint8& a, const vint8& b) { return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vint8 operator ^(const vint8& a, int b) { return a ^ vint8(b); } + __forceinline vint8 operator ^(int a, const vint8& b) { return vint8(a) ^ b; } + + __forceinline vint8 operator <<(const vint8& a, int n) { return vint8(_mm_slli_epi32(a.vl, n), _mm_slli_epi32(a.vh, n)); } + __forceinline vint8 operator >>(const vint8& a, int n) { return vint8(_mm_srai_epi32(a.vl, n), _mm_srai_epi32(a.vh, n)); } + + __forceinline vint8 sll (const vint8& a, int b) { return vint8(_mm_slli_epi32(a.vl, b), _mm_slli_epi32(a.vh, b)); } + __forceinline vint8 sra (const vint8& a, int b) { 
return vint8(_mm_srai_epi32(a.vl, b), _mm_srai_epi32(a.vh, b)); } + __forceinline vint8 srl (const vint8& a, int b) { return vint8(_mm_srli_epi32(a.vl, b), _mm_srli_epi32(a.vh, b)); } + + __forceinline vint8 min(const vint8& a, const vint8& b) { return vint8(_mm_min_epi32(a.vl, b.vl), _mm_min_epi32(a.vh, b.vh)); } + __forceinline vint8 min(const vint8& a, int b) { return min(a,vint8(b)); } + __forceinline vint8 min(int a, const vint8& b) { return min(vint8(a),b); } + + __forceinline vint8 max(const vint8& a, const vint8& b) { return vint8(_mm_max_epi32(a.vl, b.vl), _mm_max_epi32(a.vh, b.vh)); } + __forceinline vint8 max(const vint8& a, int b) { return max(a,vint8(b)); } + __forceinline vint8 max(int a, const vint8& b) { return max(vint8(a),b); } + + __forceinline vint8 umin(const vint8& a, const vint8& b) { return vint8(_mm_min_epu32(a.vl, b.vl), _mm_min_epu32(a.vh, b.vh)); } + __forceinline vint8 umin(const vint8& a, int b) { return umin(a,vint8(b)); } + __forceinline vint8 umin(int a, const vint8& b) { return umin(vint8(a),b); } + + __forceinline vint8 umax(const vint8& a, const vint8& b) { return vint8(_mm_max_epu32(a.vl, b.vl), _mm_max_epu32(a.vh, b.vh)); } + __forceinline vint8 umax(const vint8& a, int b) { return umax(a,vint8(b)); } + __forceinline vint8 umax(int a, const vint8& b) { return umax(vint8(a),b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8& operator +=(vint8& a, const vint8& b) { return a = a + b; } + __forceinline vint8& operator +=(vint8& a, int b) { return a = a + b; } + + __forceinline vint8& operator -=(vint8& a, const vint8& b) { return a = a - b; } + __forceinline vint8& operator -=(vint8& a, int b) { return a = a - b; } + + __forceinline vint8& operator *=(vint8& a, const vint8& b) { return a = a * b; } + __forceinline vint8& operator *=(vint8& a, int b) { return a = a * b; } + + __forceinline vint8& operator &=(vint8& a, const vint8& b) { return a = a & b; } + __forceinline vint8& operator &=(vint8& a, int b) { return a = a & b; } + + __forceinline vint8& operator |=(vint8& a, const vint8& b) { return a = a | b; } + __forceinline vint8& operator |=(vint8& a, int b) { return a = a | b; } + + __forceinline vint8& operator <<=(vint8& a, int b) { return a = a << b; } + __forceinline vint8& operator >>=(vint8& a, int b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator ==(const vint8& a, const vint8& b) { return vboolf8(_mm_castsi128_ps(_mm_cmpeq_epi32 (a.vl, b.vl)), + _mm_castsi128_ps(_mm_cmpeq_epi32 (a.vh, b.vh))); } + __forceinline vboolf8 operator ==(const vint8& a, int b) { return a == vint8(b); } + __forceinline vboolf8 operator ==(int a, const vint8& b) { return vint8(a) == b; } + + __forceinline vboolf8 operator !=(const vint8& a, const vint8& b) { return !(a == b); } + __forceinline vboolf8 operator !=(const vint8& a, int b) { return a != vint8(b); } + __forceinline vboolf8 operator !=(int a, const vint8& b) { return vint8(a) != b; } + + __forceinline vboolf8 operator < (const vint8& a, const vint8& b) { return vboolf8(_mm_castsi128_ps(_mm_cmplt_epi32 (a.vl, b.vl)), + _mm_castsi128_ps(_mm_cmplt_epi32 (a.vh, b.vh))); } + __forceinline vboolf8 operator < (const vint8& a, int b) { 
return a < vint8(b); } + __forceinline vboolf8 operator < (int a, const vint8& b) { return vint8(a) < b; } + + __forceinline vboolf8 operator >=(const vint8& a, const vint8& b) { return !(a < b); } + __forceinline vboolf8 operator >=(const vint8& a, int b) { return a >= vint8(b); } + __forceinline vboolf8 operator >=(int a, const vint8& b) { return vint8(a) >= b; } + + __forceinline vboolf8 operator > (const vint8& a, const vint8& b) { return vboolf8(_mm_castsi128_ps(_mm_cmpgt_epi32 (a.vl, b.vl)), + _mm_castsi128_ps(_mm_cmpgt_epi32 (a.vh, b.vh))); } + __forceinline vboolf8 operator > (const vint8& a, int b) { return a > vint8(b); } + __forceinline vboolf8 operator > (int a, const vint8& b) { return vint8(a) > b; } + + __forceinline vboolf8 operator <=(const vint8& a, const vint8& b) { return !(a > b); } + __forceinline vboolf8 operator <=(const vint8& a, int b) { return a <= vint8(b); } + __forceinline vboolf8 operator <=(int a, const vint8& b) { return vint8(a) <= b; } + + __forceinline vboolf8 eq(const vint8& a, const vint8& b) { return a == b; } + __forceinline vboolf8 ne(const vint8& a, const vint8& b) { return a != b; } + __forceinline vboolf8 lt(const vint8& a, const vint8& b) { return a < b; } + __forceinline vboolf8 ge(const vint8& a, const vint8& b) { return a >= b; } + __forceinline vboolf8 gt(const vint8& a, const vint8& b) { return a > b; } + __forceinline vboolf8 le(const vint8& a, const vint8& b) { return a <= b; } + + __forceinline vboolf8 eq(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a == b); } + __forceinline vboolf8 ne(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a != b); } + __forceinline vboolf8 lt(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a < b); } + __forceinline vboolf8 ge(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a >= b); } + __forceinline vboolf8 gt(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a > b); } + __forceinline vboolf8 le(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a <= b); } + + __forceinline vint8 select(const vboolf8& m, const vint8& t, const vint8& f) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(f), _mm256_castsi256_ps(t), m)); + } + + __forceinline vint8 notand(const vboolf8& m, const vint8& f) { + return _mm256_castps_si256(_mm256_andnot_ps(m, _mm256_castsi256_ps(f))); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 unpacklo(const vint8& a, const vint8& b) { return _mm256_castps_si256(_mm256_unpacklo_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vint8 unpackhi(const vint8& a, const vint8& b) { return _mm256_castps_si256(_mm256_unpackhi_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + + template + __forceinline vint8 shuffle(const vint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template + __forceinline vint8 shuffle4(const vint8& v) { + return _mm256_permute2f128_si256(v, v, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vint8 shuffle4(const vint8& a, const vint8& b) { + return _mm256_permute2f128_si256(a, b, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vint8 shuffle(const vint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), 
_MM_SHUFFLE(i3, i2, i1, i0))); + } + + template + __forceinline vint8 shuffle(const vint8& a, const vint8& b) { + return _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<> __forceinline vint8 shuffle<0, 0, 2, 2>(const vint8& v) { return _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vint8 shuffle<1, 1, 3, 3>(const vint8& v) { return _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vint8 shuffle<0, 1, 0, 1>(const vint8& v) { return _mm256_castps_si256(_mm256_castpd_ps(_mm256_movedup_pd(_mm256_castps_pd(_mm256_castsi256_ps(v))))); } + + __forceinline vint8 broadcast(const int* ptr) { return _mm256_castps_si256(_mm256_broadcast_ss((const float*)ptr)); } + template __forceinline vint8 insert4(const vint8& a, const vint4& b) { return _mm256_insertf128_si256(a, b, i); } + template __forceinline vint4 extract4(const vint8& a) { return _mm256_extractf128_si256(a, i); } + template<> __forceinline vint4 extract4<0>(const vint8& a) { return _mm256_castsi256_si128(a); } + + __forceinline int toScalar(const vint8& v) { return _mm_cvtsi128_si32(_mm256_castsi256_si128(v)); } + + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 vreduce_min2(const vint8& v) { return min(v,shuffle<1,0,3,2>(v)); } + __forceinline vint8 vreduce_min4(const vint8& v) { vint8 v1 = vreduce_min2(v); return min(v1,shuffle<2,3,0,1>(v1)); } + __forceinline vint8 vreduce_min (const vint8& v) { vint8 v1 = vreduce_min4(v); return min(v1,shuffle4<1,0>(v1)); } + + __forceinline vint8 vreduce_max2(const vint8& v) { return max(v,shuffle<1,0,3,2>(v)); } + __forceinline vint8 vreduce_max4(const vint8& v) { vint8 v1 = vreduce_max2(v); return max(v1,shuffle<2,3,0,1>(v1)); } + __forceinline vint8 vreduce_max (const vint8& v) { vint8 v1 = vreduce_max4(v); return max(v1,shuffle4<1,0>(v1)); } + + __forceinline vint8 vreduce_add2(const vint8& v) { return v + shuffle<1,0,3,2>(v); } + __forceinline vint8 vreduce_add4(const vint8& v) { vint8 v1 = vreduce_add2(v); return v1 + shuffle<2,3,0,1>(v1); } + __forceinline vint8 vreduce_add (const vint8& v) { vint8 v1 = vreduce_add4(v); return v1 + shuffle4<1,0>(v1); } + + __forceinline int reduce_min(const vint8& v) { return toScalar(vreduce_min(v)); } + __forceinline int reduce_max(const vint8& v) { return toScalar(vreduce_max(v)); } + __forceinline int reduce_add(const vint8& v) { return toScalar(vreduce_add(v)); } + + __forceinline size_t select_min(const vint8& v) { return bsf(movemask(v == vreduce_min(v))); } + __forceinline size_t select_max(const vint8& v) { return bsf(movemask(v == vreduce_max(v))); } + + __forceinline size_t select_min(const vboolf8& valid, const vint8& v) { const vint8 a = select(valid,v,vint8(pos_inf)); return bsf(movemask(valid & (a == vreduce_min(a)))); } + __forceinline size_t select_max(const vboolf8& valid, const vint8& v) { const vint8 a = select(valid,v,vint8(neg_inf)); return bsf(movemask(valid & (a == vreduce_max(a)))); } + + //////////////////////////////////////////////////////////////////////////////// + /// Sorting networks + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 usort_ascending(const vint8& v) + { + const vint8 a0 = v; + const vint8 b0 = shuffle<1,0,3,2>(a0); + const vint8 
c0 = umin(a0,b0); + const vint8 d0 = umax(a0,b0); + const vint8 a1 = select(0x99 /* 0b10011001 */,c0,d0); + const vint8 b1 = shuffle<2,3,0,1>(a1); + const vint8 c1 = umin(a1,b1); + const vint8 d1 = umax(a1,b1); + const vint8 a2 = select(0xc3 /* 0b11000011 */,c1,d1); + const vint8 b2 = shuffle<1,0,3,2>(a2); + const vint8 c2 = umin(a2,b2); + const vint8 d2 = umax(a2,b2); + const vint8 a3 = select(0xa5 /* 0b10100101 */,c2,d2); + const vint8 b3 = shuffle4<1,0>(a3); + const vint8 c3 = umin(a3,b3); + const vint8 d3 = umax(a3,b3); + const vint8 a4 = select(0xf /* 0b00001111 */,c3,d3); + const vint8 b4 = shuffle<2,3,0,1>(a4); + const vint8 c4 = umin(a4,b4); + const vint8 d4 = umax(a4,b4); + const vint8 a5 = select(0x33 /* 0b00110011 */,c4,d4); + const vint8 b5 = shuffle<1,0,3,2>(a5); + const vint8 c5 = umin(a5,b5); + const vint8 d5 = umax(a5,b5); + const vint8 a6 = select(0x55 /* 0b01010101 */,c5,d5); + return a6; + } + + __forceinline vint8 usort_descending(const vint8& v) + { + const vint8 a0 = v; + const vint8 b0 = shuffle<1,0,3,2>(a0); + const vint8 c0 = umax(a0,b0); + const vint8 d0 = umin(a0,b0); + const vint8 a1 = select(0x99 /* 0b10011001 */,c0,d0); + const vint8 b1 = shuffle<2,3,0,1>(a1); + const vint8 c1 = umax(a1,b1); + const vint8 d1 = umin(a1,b1); + const vint8 a2 = select(0xc3 /* 0b11000011 */,c1,d1); + const vint8 b2 = shuffle<1,0,3,2>(a2); + const vint8 c2 = umax(a2,b2); + const vint8 d2 = umin(a2,b2); + const vint8 a3 = select(0xa5 /* 0b10100101 */,c2,d2); + const vint8 b3 = shuffle4<1,0>(a3); + const vint8 c3 = umax(a3,b3); + const vint8 d3 = umin(a3,b3); + const vint8 a4 = select(0xf /* 0b00001111 */,c3,d3); + const vint8 b4 = shuffle<2,3,0,1>(a4); + const vint8 c4 = umax(a4,b4); + const vint8 d4 = umin(a4,b4); + const vint8 a5 = select(0x33 /* 0b00110011 */,c4,d4); + const vint8 b5 = shuffle<1,0,3,2>(a5); + const vint8 c5 = umax(a5,b5); + const vint8 d5 = umin(a5,b5); + const vint8 a6 = select(0x55 /* 0b01010101 */,c5,d5); + return a6; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vint8& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vint8_avx2.h b/thirdparty/embree/common/simd/vint8_avx2.h new file mode 100644 index 000000000000..ea97d3eb346e --- /dev/null +++ b/thirdparty/embree/common/simd/vint8_avx2.h @@ -0,0 +1,505 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX integer type */ + template<> + struct vint<8> + { + ALIGNED_STRUCT_(32); + + typedef vboolf8 Bool; + typedef vint8 Int; + typedef vfloat8 Float; + + enum { size = 8 }; // number of SIMD elements + union { // data + __m256i v; + int i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint() {} + __forceinline vint(const vint8& a) { v = a.v; } + __forceinline vint8& operator =(const vint8& a) { v = a.v; return *this; } + + __forceinline vint(__m256i a) : v(a) {} + __forceinline operator const __m256i&() const { return v; } + __forceinline operator __m256i&() { return 
v; } + + __forceinline explicit vint(const vint4& a) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),a,1)) {} + __forceinline vint(const vint4& a, const vint4& b) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),b,1)) {} + __forceinline vint(const __m128i& a, const __m128i& b) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),b,1)) {} + + __forceinline explicit vint(const int* a) : v(_mm256_castps_si256(_mm256_loadu_ps((const float*)a))) {} + __forceinline vint(int a) : v(_mm256_set1_epi32(a)) {} + __forceinline vint(int a, int b) : v(_mm256_set_epi32(b, a, b, a, b, a, b, a)) {} + __forceinline vint(int a, int b, int c, int d) : v(_mm256_set_epi32(d, c, b, a, d, c, b, a)) {} + __forceinline vint(int a, int b, int c, int d, int e, int f, int g, int h) : v(_mm256_set_epi32(h, g, f, e, d, c, b, a)) {} + + __forceinline explicit vint(__m256 a) : v(_mm256_cvtps_epi32(a)) {} + +#if defined(__AVX512VL__) + __forceinline explicit vint(const vboolf8& a) : v(_mm256_movm_epi32(a)) {} +#else + __forceinline explicit vint(const vboolf8& a) : v(_mm256_castps_si256((__m256)a)) {} +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint(ZeroTy) : v(_mm256_setzero_si256()) {} + __forceinline vint(OneTy) : v(_mm256_set1_epi32(1)) {} + __forceinline vint(PosInfTy) : v(_mm256_set1_epi32(pos_inf)) {} + __forceinline vint(NegInfTy) : v(_mm256_set1_epi32(neg_inf)) {} + __forceinline vint(StepTy) : v(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)) {} + __forceinline vint(ReverseStepTy) : v(_mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7)) {} + __forceinline vint(UndefinedTy) : v(_mm256_undefined_si256()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vint8 load(const unsigned char* ptr) { return _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); } + static __forceinline vint8 loadu(const unsigned char* ptr) { return _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); } + static __forceinline vint8 load(const unsigned short* ptr) { return _mm256_cvtepu16_epi32(_mm_load_si128((__m128i*)ptr)); } + static __forceinline vint8 loadu(const unsigned short* ptr) { return _mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i*)ptr)); } + + static __forceinline vint8 load(const void* ptr) { return _mm256_load_si256((__m256i*)ptr); } + static __forceinline vint8 loadu(const void* ptr) { return _mm256_loadu_si256((__m256i*)ptr); } + + static __forceinline void store (void* ptr, const vint8& v) { _mm256_store_si256((__m256i*)ptr,v); } + static __forceinline void storeu(void* ptr, const vint8& v) { _mm256_storeu_ps((float*)ptr,_mm256_castsi256_ps(v)); } + +#if defined(__AVX512VL__) + + static __forceinline vint8 compact(const vboolf8& mask, vint8 &v) { + return _mm256_mask_compress_epi32(v, mask, v); + } + static __forceinline vint8 compact(const vboolf8& mask, vint8 &a, const vint8& b) { + return _mm256_mask_compress_epi32(a, mask, b); + } + + static __forceinline vint8 load (const vboolf8& mask, const void* ptr) { return _mm256_mask_load_epi32 (_mm256_setzero_si256(),mask,ptr); } + static __forceinline vint8 loadu(const vboolf8& mask, const void* ptr) { return _mm256_mask_loadu_epi32(_mm256_setzero_si256(),mask,ptr); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vint8& v) { 
_mm256_mask_store_epi32 (ptr,mask,v); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vint8& v) { _mm256_mask_storeu_epi32(ptr,mask,v); } +#else + static __forceinline vint8 load (const vboolf8& mask, const void* ptr) { return _mm256_castps_si256(_mm256_maskload_ps((float*)ptr,mask)); } + static __forceinline vint8 loadu(const vboolf8& mask, const void* ptr) { return _mm256_castps_si256(_mm256_maskload_ps((float*)ptr,mask)); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vint8& v) { _mm256_maskstore_epi32((int*)ptr,mask,v); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vint8& v) { _mm256_maskstore_epi32((int*)ptr,mask,v); } +#endif + + static __forceinline vint8 load_nt(void* ptr) { + return _mm256_stream_load_si256((__m256i*)ptr); + } + + static __forceinline void store_nt(void* ptr, const vint8& v) { + _mm256_stream_ps((float*)ptr,_mm256_castsi256_ps(v)); + } + + static __forceinline void store(unsigned char* ptr, const vint8& i) + { + for (size_t j=0; j<8; j++) + ptr[j] = i[j]; + } + + static __forceinline void store(unsigned short* ptr, const vint8& v) { + for (size_t i=0;i<8;i++) + ptr[i] = (unsigned short)v[i]; + } + + template<int scale> + static __forceinline vint8 gather(const int *const ptr, const vint8& index) { + return _mm256_i32gather_epi32(ptr, index, scale); + } + + template<int scale> + static __forceinline vint8 gather(const vboolf8& mask, const int *const ptr, const vint8& index) { + vint8 r = zero; +#if defined(__AVX512VL__) + return _mm256_mmask_i32gather_epi32(r, mask, index, ptr, scale); +#else + return _mm256_mask_i32gather_epi32(r, ptr, index, mask, scale); +#endif + } + + template<int scale> + static __forceinline void scatter(void* ptr, const vint8& ofs, const vint8& v) + { +#if defined(__AVX512VL__) + _mm256_i32scatter_epi32((int*)ptr, ofs, v, scale); +#else + *(int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + *(int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + *(int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + *(int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + *(int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + *(int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + *(int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + *(int*)(((char*)ptr)+scale*ofs[7]) = v[7]; +#endif + } + + template<int scale> + static __forceinline void scatter(const vboolf8& mask, void* ptr, const vint8& ofs, const vint8& v) + { +#if defined(__AVX512VL__) + _mm256_mask_i32scatter_epi32((int*)ptr, mask, ofs, v, scale); +#else + if (likely(mask[0])) *(int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + if (likely(mask[1])) *(int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + if (likely(mask[2])) *(int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + if (likely(mask[3])) *(int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + if (likely(mask[4])) *(int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + if (likely(mask[5])) *(int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + if (likely(mask[6])) *(int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + if (likely(mask[7])) *(int*)(((char*)ptr)+scale*ofs[7]) = v[7]; +#endif + } + + static __forceinline vint8 broadcast64(const long long &a) { return _mm256_set1_epi64x(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const int& operator [](size_t index) const { assert(index < 8); return i[index]; } + __forceinline int& operator [](size_t index) { assert(index < 8); return i[index]; } + }; + 
//////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + static __forceinline vboolf8 asBool(const vint8& a) { return _mm256_movepi32_mask(a); } +#else + static __forceinline vboolf8 asBool(const vint8& a) { return _mm256_castsi256_ps(a); } +#endif + + __forceinline vint8 operator +(const vint8& a) { return a; } + __forceinline vint8 operator -(const vint8& a) { return _mm256_sub_epi32(_mm256_setzero_si256(), a); } + __forceinline vint8 abs (const vint8& a) { return _mm256_abs_epi32(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 operator +(const vint8& a, const vint8& b) { return _mm256_add_epi32(a, b); } + __forceinline vint8 operator +(const vint8& a, int b) { return a + vint8(b); } + __forceinline vint8 operator +(int a, const vint8& b) { return vint8(a) + b; } + + __forceinline vint8 operator -(const vint8& a, const vint8& b) { return _mm256_sub_epi32(a, b); } + __forceinline vint8 operator -(const vint8& a, int b) { return a - vint8(b); } + __forceinline vint8 operator -(int a, const vint8& b) { return vint8(a) - b; } + + __forceinline vint8 operator *(const vint8& a, const vint8& b) { return _mm256_mullo_epi32(a, b); } + __forceinline vint8 operator *(const vint8& a, int b) { return a * vint8(b); } + __forceinline vint8 operator *(int a, const vint8& b) { return vint8(a) * b; } + + __forceinline vint8 operator &(const vint8& a, const vint8& b) { return _mm256_and_si256(a, b); } + __forceinline vint8 operator &(const vint8& a, int b) { return a & vint8(b); } + __forceinline vint8 operator &(int a, const vint8& b) { return vint8(a) & b; } + + __forceinline vint8 operator |(const vint8& a, const vint8& b) { return _mm256_or_si256(a, b); } + __forceinline vint8 operator |(const vint8& a, int b) { return a | vint8(b); } + __forceinline vint8 operator |(int a, const vint8& b) { return vint8(a) | b; } + + __forceinline vint8 operator ^(const vint8& a, const vint8& b) { return _mm256_xor_si256(a, b); } + __forceinline vint8 operator ^(const vint8& a, int b) { return a ^ vint8(b); } + __forceinline vint8 operator ^(int a, const vint8& b) { return vint8(a) ^ b; } + + __forceinline vint8 operator <<(const vint8& a, int n) { return _mm256_slli_epi32(a, n); } + __forceinline vint8 operator >>(const vint8& a, int n) { return _mm256_srai_epi32(a, n); } + + __forceinline vint8 operator <<(const vint8& a, const vint8& n) { return _mm256_sllv_epi32(a, n); } + __forceinline vint8 operator >>(const vint8& a, const vint8& n) { return _mm256_srav_epi32(a, n); } + + __forceinline vint8 sll(const vint8& a, int b) { return _mm256_slli_epi32(a, b); } + __forceinline vint8 sra(const vint8& a, int b) { return _mm256_srai_epi32(a, b); } + __forceinline vint8 srl(const vint8& a, int b) { return _mm256_srli_epi32(a, b); } + + __forceinline vint8 sll(const vint8& a, const vint8& b) { return _mm256_sllv_epi32(a, b); } + __forceinline vint8 sra(const vint8& a, const vint8& b) { return _mm256_srav_epi32(a, b); } + __forceinline vint8 srl(const vint8& a, const vint8& b) { return _mm256_srlv_epi32(a, b); } + + __forceinline vint8 min(const vint8& a, const vint8& b) { return _mm256_min_epi32(a, b); } + __forceinline vint8 min(const vint8& a, int b) { return min(a,vint8(b)); } + __forceinline 
vint8 min(int a, const vint8& b) { return min(vint8(a),b); } + + __forceinline vint8 max(const vint8& a, const vint8& b) { return _mm256_max_epi32(a, b); } + __forceinline vint8 max(const vint8& a, int b) { return max(a,vint8(b)); } + __forceinline vint8 max(int a, const vint8& b) { return max(vint8(a),b); } + + __forceinline vint8 umin(const vint8& a, const vint8& b) { return _mm256_min_epu32(a, b); } + __forceinline vint8 umax(const vint8& a, const vint8& b) { return _mm256_max_epu32(a, b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8& operator +=(vint8& a, const vint8& b) { return a = a + b; } + __forceinline vint8& operator +=(vint8& a, int b) { return a = a + b; } + + __forceinline vint8& operator -=(vint8& a, const vint8& b) { return a = a - b; } + __forceinline vint8& operator -=(vint8& a, int b) { return a = a - b; } + + __forceinline vint8& operator *=(vint8& a, const vint8& b) { return a = a * b; } + __forceinline vint8& operator *=(vint8& a, int b) { return a = a * b; } + + __forceinline vint8& operator &=(vint8& a, const vint8& b) { return a = a & b; } + __forceinline vint8& operator &=(vint8& a, int b) { return a = a & b; } + + __forceinline vint8& operator |=(vint8& a, const vint8& b) { return a = a | b; } + __forceinline vint8& operator |=(vint8& a, int b) { return a = a | b; } + + __forceinline vint8& operator <<=(vint8& a, const int b) { return a = a << b; } + __forceinline vint8& operator >>=(vint8& a, const int b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + static __forceinline vboolf8 operator ==(const vint8& a, const vint8& b) { return _mm256_cmp_epi32_mask(a,b,_MM_CMPINT_EQ); } + static __forceinline vboolf8 operator !=(const vint8& a, const vint8& b) { return _mm256_cmp_epi32_mask(a,b,_MM_CMPINT_NE); } + static __forceinline vboolf8 operator < (const vint8& a, const vint8& b) { return _mm256_cmp_epi32_mask(a,b,_MM_CMPINT_LT); } + static __forceinline vboolf8 operator >=(const vint8& a, const vint8& b) { return _mm256_cmp_epi32_mask(a,b,_MM_CMPINT_GE); } + static __forceinline vboolf8 operator > (const vint8& a, const vint8& b) { return _mm256_cmp_epi32_mask(a,b,_MM_CMPINT_GT); } + static __forceinline vboolf8 operator <=(const vint8& a, const vint8& b) { return _mm256_cmp_epi32_mask(a,b,_MM_CMPINT_LE); } + + static __forceinline vint8 select(const vboolf8& m, const vint8& t, const vint8& f) { + return _mm256_mask_blend_epi32(m, (__m256i)f, (__m256i)t); + } +#else + static __forceinline vboolf8 operator ==(const vint8& a, const vint8& b) { return _mm256_castsi256_ps(_mm256_cmpeq_epi32(a, b)); } + static __forceinline vboolf8 operator !=(const vint8& a, const vint8& b) { return !(a == b); } + static __forceinline vboolf8 operator < (const vint8& a, const vint8& b) { return _mm256_castsi256_ps(_mm256_cmpgt_epi32(b, a)); } + static __forceinline vboolf8 operator >=(const vint8& a, const vint8& b) { return !(a < b); } + static __forceinline vboolf8 operator > (const vint8& a, const vint8& b) { return _mm256_castsi256_ps(_mm256_cmpgt_epi32(a, b)); } + static __forceinline vboolf8 operator <=(const vint8& a, const vint8& b) { return !(a > b); } + + static __forceinline vint8 select(const 
vboolf8& m, const vint8& t, const vint8& f) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(f), _mm256_castsi256_ps(t), m)); + } +#endif + + template<int mask> + __forceinline vint8 select(const vint8& t, const vint8& f) { + return _mm256_blend_epi32(f, t, mask); + } + + __forceinline vboolf8 operator ==(const vint8& a, int b) { return a == vint8(b); } + __forceinline vboolf8 operator ==(int a, const vint8& b) { return vint8(a) == b; } + + __forceinline vboolf8 operator !=(const vint8& a, int b) { return a != vint8(b); } + __forceinline vboolf8 operator !=(int a, const vint8& b) { return vint8(a) != b; } + + __forceinline vboolf8 operator < (const vint8& a, int b) { return a < vint8(b); } + __forceinline vboolf8 operator < (int a, const vint8& b) { return vint8(a) < b; } + + __forceinline vboolf8 operator >=(const vint8& a, int b) { return a >= vint8(b); } + __forceinline vboolf8 operator >=(int a, const vint8& b) { return vint8(a) >= b; } + + __forceinline vboolf8 operator > (const vint8& a, int b) { return a > vint8(b); } + __forceinline vboolf8 operator > (int a, const vint8& b) { return vint8(a) > b; } + + __forceinline vboolf8 operator <=(const vint8& a, int b) { return a <= vint8(b); } + __forceinline vboolf8 operator <=(int a, const vint8& b) { return vint8(a) <= b; } + + __forceinline vboolf8 eq(const vint8& a, const vint8& b) { return a == b; } + __forceinline vboolf8 ne(const vint8& a, const vint8& b) { return a != b; } + __forceinline vboolf8 lt(const vint8& a, const vint8& b) { return a < b; } + __forceinline vboolf8 ge(const vint8& a, const vint8& b) { return a >= b; } + __forceinline vboolf8 gt(const vint8& a, const vint8& b) { return a > b; } + __forceinline vboolf8 le(const vint8& a, const vint8& b) { return a <= b; } + +#if defined(__AVX512VL__) + static __forceinline vboolf8 eq(const vboolf8& mask, const vint8& a, const vint8& b) { return _mm256_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_EQ); } + static __forceinline vboolf8 ne(const vboolf8& mask, const vint8& a, const vint8& b) { return _mm256_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_NE); } + static __forceinline vboolf8 lt(const vboolf8& mask, const vint8& a, const vint8& b) { return _mm256_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_LT); } + static __forceinline vboolf8 ge(const vboolf8& mask, const vint8& a, const vint8& b) { return _mm256_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_GE); } + static __forceinline vboolf8 gt(const vboolf8& mask, const vint8& a, const vint8& b) { return _mm256_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_GT); } + static __forceinline vboolf8 le(const vboolf8& mask, const vint8& a, const vint8& b) { return _mm256_mask_cmp_epi32_mask(mask, a, b, _MM_CMPINT_LE); } +#else + static __forceinline vboolf8 eq(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a == b); } + static __forceinline vboolf8 ne(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a != b); } + static __forceinline vboolf8 lt(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a < b); } + static __forceinline vboolf8 ge(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a >= b); } + static __forceinline vboolf8 gt(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a > b); } + static __forceinline vboolf8 le(const vboolf8& mask, const vint8& a, const vint8& b) { return mask & (a <= b); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 unpacklo(const vint8& a, const vint8& b) { return _mm256_unpacklo_epi32(a, b); } + __forceinline vint8 unpackhi(const vint8& a, const vint8& b) { return _mm256_unpackhi_epi32(a, b); } + + template<int i> + __forceinline vint8 shuffle(const vint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template<int i0, int i1> + __forceinline vint8 shuffle4(const vint8& v) { + return _mm256_permute2f128_si256(v, v, (i1 << 4) | (i0 << 0)); + } + + template<int i0, int i1> + __forceinline vint8 shuffle4(const vint8& a, const vint8& b) { + return _mm256_permute2f128_si256(a, b, (i1 << 4) | (i0 << 0)); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vint8 shuffle(const vint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vint8 shuffle(const vint8& a, const vint8& b) { + return _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<> __forceinline vint8 shuffle<0, 0, 2, 2>(const vint8& v) { return _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vint8 shuffle<1, 1, 3, 3>(const vint8& v) { return _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vint8 shuffle<0, 1, 0, 1>(const vint8& v) { return _mm256_castps_si256(_mm256_castpd_ps(_mm256_movedup_pd(_mm256_castps_pd(_mm256_castsi256_ps(v))))); } + + __forceinline vint8 broadcast(const int* ptr) { return _mm256_castps_si256(_mm256_broadcast_ss((const float*)ptr)); } + + template<int i> __forceinline vint8 insert4(const vint8& a, const vint4& b) { return _mm256_insertf128_si256(a, b, i); } + template<int i> __forceinline vint4 extract4(const vint8& a) { return _mm256_extractf128_si256(a, i); } + template<> __forceinline vint4 extract4<0>(const vint8& a) { return _mm256_castsi256_si128(a); } + + __forceinline int toScalar(const vint8& v) { return _mm_cvtsi128_si32(_mm256_castsi256_si128(v)); } + + __forceinline vint8 permute(const vint8& v, const __m256i& index) { + return _mm256_permutevar8x32_epi32(v, index); + } + + __forceinline vint8 shuffle(const vint8& v, const __m256i& index) { + return _mm256_castps_si256(_mm256_permutevar_ps(_mm256_castsi256_ps(v), index)); + } + + template<int i> + static __forceinline vint8 align_shift_right(const vint8& a, const vint8& b) { +#if defined(__AVX512VL__) + return _mm256_alignr_epi32(a, b, i); +#else + return _mm256_alignr_epi8(a, b, 4*i); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 vreduce_min2(const vint8& v) { return min(v,shuffle<1,0,3,2>(v)); } + __forceinline vint8 vreduce_min4(const vint8& v) { vint8 v1 = vreduce_min2(v); return min(v1,shuffle<2,3,0,1>(v1)); } + __forceinline vint8 vreduce_min (const vint8& v) { vint8 v1 = vreduce_min4(v); return min(v1,shuffle4<1,0>(v1)); } + + __forceinline vint8 vreduce_max2(const vint8& v) { return max(v,shuffle<1,0,3,2>(v)); } + __forceinline vint8 vreduce_max4(const vint8& v) { vint8 v1 = vreduce_max2(v); return max(v1,shuffle<2,3,0,1>(v1)); } + __forceinline vint8 vreduce_max (const vint8& v) { vint8 v1 = vreduce_max4(v); return max(v1,shuffle4<1,0>(v1)); } + + __forceinline vint8 vreduce_add2(const vint8& v) { return v + shuffle<1,0,3,2>(v); } + 
__forceinline vint8 vreduce_add4(const vint8& v) { vint8 v1 = vreduce_add2(v); return v1 + shuffle<2,3,0,1>(v1); } + __forceinline vint8 vreduce_add (const vint8& v) { vint8 v1 = vreduce_add4(v); return v1 + shuffle4<1,0>(v1); } + + __forceinline int reduce_min(const vint8& v) { return toScalar(vreduce_min(v)); } + __forceinline int reduce_max(const vint8& v) { return toScalar(vreduce_max(v)); } + __forceinline int reduce_add(const vint8& v) { return toScalar(vreduce_add(v)); } + + __forceinline size_t select_min(const vint8& v) { return bsf(movemask(v == vreduce_min(v))); } + __forceinline size_t select_max(const vint8& v) { return bsf(movemask(v == vreduce_max(v))); } + + __forceinline size_t select_min(const vboolf8& valid, const vint8& v) { const vint8 a = select(valid,v,vint8(pos_inf)); return bsf(movemask(valid & (a == vreduce_min(a)))); } + __forceinline size_t select_max(const vboolf8& valid, const vint8& v) { const vint8 a = select(valid,v,vint8(neg_inf)); return bsf(movemask(valid & (a == vreduce_max(a)))); } + + + __forceinline vint8 assign(const vint4& a) { return _mm256_castsi128_si256(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Sorting networks + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vint8 usort_ascending(const vint8& v) + { + const vint8 a0 = v; + const vint8 b0 = shuffle<1,0,3,2>(a0); + const vint8 c0 = umin(a0,b0); + const vint8 d0 = umax(a0,b0); + const vint8 a1 = select<0x99 /* 0b10011001 */>(c0,d0); + const vint8 b1 = shuffle<2,3,0,1>(a1); + const vint8 c1 = umin(a1,b1); + const vint8 d1 = umax(a1,b1); + const vint8 a2 = select<0xc3 /* 0b11000011 */>(c1,d1); + const vint8 b2 = shuffle<1,0,3,2>(a2); + const vint8 c2 = umin(a2,b2); + const vint8 d2 = umax(a2,b2); + const vint8 a3 = select<0xa5 /* 0b10100101 */>(c2,d2); + const vint8 b3 = shuffle4<1,0>(a3); + const vint8 c3 = umin(a3,b3); + const vint8 d3 = umax(a3,b3); + const vint8 a4 = select<0xf /* 0b00001111 */>(c3,d3); + const vint8 b4 = shuffle<2,3,0,1>(a4); + const vint8 c4 = umin(a4,b4); + const vint8 d4 = umax(a4,b4); + const vint8 a5 = select<0x33 /* 0b00110011 */>(c4,d4); + const vint8 b5 = shuffle<1,0,3,2>(a5); + const vint8 c5 = umin(a5,b5); + const vint8 d5 = umax(a5,b5); + const vint8 a6 = select<0x55 /* 0b01010101 */>(c5,d5); + return a6; + } + + __forceinline vint8 usort_descending(const vint8& v) + { + const vint8 a0 = v; + const vint8 b0 = shuffle<1,0,3,2>(a0); + const vint8 c0 = umax(a0,b0); + const vint8 d0 = umin(a0,b0); + const vint8 a1 = select<0x99 /* 0b10011001 */>(c0,d0); + const vint8 b1 = shuffle<2,3,0,1>(a1); + const vint8 c1 = umax(a1,b1); + const vint8 d1 = umin(a1,b1); + const vint8 a2 = select<0xc3 /* 0b11000011 */>(c1,d1); + const vint8 b2 = shuffle<1,0,3,2>(a2); + const vint8 c2 = umax(a2,b2); + const vint8 d2 = umin(a2,b2); + const vint8 a3 = select<0xa5 /* 0b10100101 */>(c2,d2); + const vint8 b3 = shuffle4<1,0>(a3); + const vint8 c3 = umax(a3,b3); + const vint8 d3 = umin(a3,b3); + const vint8 a4 = select<0xf /* 0b00001111 */>(c3,d3); + const vint8 b4 = shuffle<2,3,0,1>(a4); + const vint8 c4 = umax(a4,b4); + const vint8 d4 = umin(a4,b4); + const vint8 a5 = select<0x33 /* 0b00110011 */>(c4,d4); + const vint8 b5 = shuffle<1,0,3,2>(a5); + const vint8 c5 = umax(a5,b5); + const vint8 d5 = umin(a5,b5); + const vint8 a6 = select<0x55 /* 0b01010101 */>(c5,d5); + return a6; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output 
Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vint8& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vllong4_avx2.h b/thirdparty/embree/common/simd/vllong4_avx2.h new file mode 100644 index 000000000000..de3ebc16a728 --- /dev/null +++ b/thirdparty/embree/common/simd/vllong4_avx2.h @@ -0,0 +1,358 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 4-wide AVX2 64-bit long long type */ + template<> + struct vllong<4> + { + ALIGNED_STRUCT_(32); + + typedef vboold4 Bool; + + enum { size = 4 }; // number of SIMD elements + union { // data + __m256i v; + long long i[4]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong() {} + __forceinline vllong(const vllong4& t) { v = t.v; } + __forceinline vllong4& operator =(const vllong4& f) { v = f.v; return *this; } + + __forceinline vllong(const __m256i& t) { v = t; } + __forceinline operator __m256i() const { return v; } + __forceinline operator __m256d() const { return _mm256_castsi256_pd(v); } + + + __forceinline vllong(long long i) { + v = _mm256_set1_epi64x(i); + } + + __forceinline vllong(long long a, long long b, long long c, long long d) { + v = _mm256_set_epi64x(d,c,b,a); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong(ZeroTy) : v(_mm256_setzero_si256()) {} + __forceinline vllong(OneTy) : v(_mm256_set1_epi64x(1)) {} + __forceinline vllong(StepTy) : v(_mm256_set_epi64x(3,2,1,0)) {} + __forceinline vllong(ReverseStepTy) : v(_mm256_set_epi64x(0,1,2,3)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline void store_nt(void* __restrict__ ptr, const vllong4& a) { + _mm256_stream_ps((float*)ptr,_mm256_castsi256_ps(a)); + } + + static __forceinline vllong4 loadu(const void* addr) + { + return _mm256_loadu_si256((__m256i*)addr); + } + + static __forceinline vllong4 load(const vllong4* addr) { + return _mm256_load_si256((__m256i*)addr); + } + + static __forceinline vllong4 load(const long long* addr) { + return _mm256_load_si256((__m256i*)addr); + } + + static __forceinline void store(void* ptr, const vllong4& v) { + _mm256_store_si256((__m256i*)ptr,v); + } + + static __forceinline void storeu(void* ptr, const vllong4& v) { + _mm256_storeu_si256((__m256i*)ptr,v); + } + + static __forceinline void storeu(const vboold4& mask, long long* ptr, const vllong4& f) { +#if defined(__AVX512VL__) + _mm256_mask_storeu_epi64(ptr,mask,f); +#else + _mm256_maskstore_pd((double*)ptr,mask,_mm256_castsi256_pd(f)); +#endif + } + + static __forceinline void store(const vboold4& mask, void* ptr, const vllong4& f) { +#if defined(__AVX512VL__) + _mm256_mask_store_epi64(ptr,mask,f); +#else + _mm256_maskstore_pd((double*)ptr,mask,_mm256_castsi256_pd(f)); +#endif + } + + static __forceinline vllong4 
broadcast64bit(size_t v) { + return _mm256_set1_epi64x(v); + } + + static __forceinline size_t extract64bit(const vllong4& v) + { + return _mm_cvtsi128_si64(_mm256_castsi256_si128(v)); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline long long& operator [](size_t index) { assert(index < 4); return i[index]; } + __forceinline const long long& operator [](size_t index) const { assert(index < 4); return i[index]; } + + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong4 select(const vboold4& m, const vllong4& t, const vllong4& f) { + #if defined(__AVX512VL__) + return _mm256_mask_blend_epi64(m, f, t); + #else + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(f), _mm256_castsi256_pd(t), m)); + #endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboold4 asBool(const vllong4& a) { return _mm256_movepi64_mask(a); } +#else + __forceinline vboold4 asBool(const vllong4& a) { return _mm256_castsi256_pd(a); } +#endif + + __forceinline vllong4 operator +(const vllong4& a) { return a; } + __forceinline vllong4 operator -(const vllong4& a) { return _mm256_sub_epi64(_mm256_setzero_si256(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong4 operator +(const vllong4& a, const vllong4& b) { return _mm256_add_epi64(a, b); } + __forceinline vllong4 operator +(const vllong4& a, long long b) { return a + vllong4(b); } + __forceinline vllong4 operator +(long long a, const vllong4& b) { return vllong4(a) + b; } + + __forceinline vllong4 operator -(const vllong4& a, const vllong4& b) { return _mm256_sub_epi64(a, b); } + __forceinline vllong4 operator -(const vllong4& a, long long b) { return a - vllong4(b); } + __forceinline vllong4 operator -(long long a, const vllong4& b) { return vllong4(a) - b; } + + /* only low 32bit part */ + __forceinline vllong4 operator *(const vllong4& a, const vllong4& b) { return _mm256_mul_epi32(a, b); } + __forceinline vllong4 operator *(const vllong4& a, long long b) { return a * vllong4(b); } + __forceinline vllong4 operator *(long long a, const vllong4& b) { return vllong4(a) * b; } + + __forceinline vllong4 operator &(const vllong4& a, const vllong4& b) { return _mm256_and_si256(a, b); } + __forceinline vllong4 operator &(const vllong4& a, long long b) { return a & vllong4(b); } + __forceinline vllong4 operator &(long long a, const vllong4& b) { return vllong4(a) & b; } + + __forceinline vllong4 operator |(const vllong4& a, const vllong4& b) { return _mm256_or_si256(a, b); } + __forceinline vllong4 operator |(const vllong4& a, long long b) { return a | vllong4(b); } + __forceinline vllong4 operator |(long long a, const vllong4& b) { return vllong4(a) | b; } + + __forceinline vllong4 operator ^(const vllong4& a, const vllong4& b) { return _mm256_xor_si256(a, b); } + __forceinline vllong4 operator ^(const vllong4& a, long long b) { return a ^ vllong4(b); } + __forceinline vllong4 operator 
^(long long a, const vllong4& b) { return vllong4(a) ^ b; } + + __forceinline vllong4 operator <<(const vllong4& a, long long n) { return _mm256_slli_epi64(a, (int)n); } + //__forceinline vllong4 operator >>(const vllong4& a, long long n) { return _mm256_srai_epi64(a, n); } + + __forceinline vllong4 operator <<(const vllong4& a, const vllong4& n) { return _mm256_sllv_epi64(a, n); } + //__forceinline vllong4 operator >>(const vllong4& a, const vllong4& n) { return _mm256_srav_epi64(a, n); } + //__forceinline vllong4 sra(const vllong4& a, long long b) { return _mm256_srai_epi64(a, b); } + + __forceinline vllong4 srl(const vllong4& a, long long b) { return _mm256_srli_epi64(a, (int)b); } + + //__forceinline vllong4 min(const vllong4& a, const vllong4& b) { return _mm256_min_epi64(a, b); } + //__forceinline vllong4 min(const vllong4& a, long long b) { return min(a,vllong4(b)); } + //__forceinline vllong4 min(long long a, const vllong4& b) { return min(vllong4(a),b); } + + //__forceinline vllong4 max(const vllong4& a, const vllong4& b) { return _mm256_max_epi64(a, b); } + //__forceinline vllong4 max(const vllong4& a, long long b) { return max(a,vllong4(b)); } + //__forceinline vllong4 max(long long a, const vllong4& b) { return max(vllong4(a),b); } + +#if defined(__AVX512VL__) + __forceinline vllong4 mask_and(const vboold4& m, const vllong4& c, const vllong4& a, const vllong4& b) { return _mm256_mask_and_epi64(c,m,a,b); } + __forceinline vllong4 mask_or (const vboold4& m, const vllong4& c, const vllong4& a, const vllong4& b) { return _mm256_mask_or_epi64(c,m,a,b); } +#else + __forceinline vllong4 mask_and(const vboold4& m, const vllong4& c, const vllong4& a, const vllong4& b) { return select(m, a & b, c); } + __forceinline vllong4 mask_or (const vboold4& m, const vllong4& c, const vllong4& a, const vllong4& b) { return select(m, a | b, c); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong4& operator +=(vllong4& a, const vllong4& b) { return a = a + b; } + __forceinline vllong4& operator +=(vllong4& a, long long b) { return a = a + b; } + + __forceinline vllong4& operator -=(vllong4& a, const vllong4& b) { return a = a - b; } + __forceinline vllong4& operator -=(vllong4& a, long long b) { return a = a - b; } + + __forceinline vllong4& operator *=(vllong4& a, const vllong4& b) { return a = a * b; } + __forceinline vllong4& operator *=(vllong4& a, long long b) { return a = a * b; } + + __forceinline vllong4& operator &=(vllong4& a, const vllong4& b) { return a = a & b; } + __forceinline vllong4& operator &=(vllong4& a, long long b) { return a = a & b; } + + __forceinline vllong4& operator |=(vllong4& a, const vllong4& b) { return a = a | b; } + __forceinline vllong4& operator |=(vllong4& a, long long b) { return a = a | b; } + + __forceinline vllong4& operator <<=(vllong4& a, long long b) { return a = a << b; } + //__forceinline vllong4& operator >>=(vllong4& a, long long b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboold4 operator ==(const vllong4& a, const vllong4& b) { return _mm256_cmp_epi64_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboold4 operator !=(const vllong4& a, const vllong4& b) { return 
_mm256_cmp_epi64_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboold4 operator < (const vllong4& a, const vllong4& b) { return _mm256_cmp_epi64_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboold4 operator >=(const vllong4& a, const vllong4& b) { return _mm256_cmp_epi64_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboold4 operator > (const vllong4& a, const vllong4& b) { return _mm256_cmp_epi64_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboold4 operator <=(const vllong4& a, const vllong4& b) { return _mm256_cmp_epi64_mask(a,b,_MM_CMPINT_LE); } +#else + __forceinline vboold4 operator ==(const vllong4& a, const vllong4& b) { return _mm256_cmpeq_epi64(a,b); } + __forceinline vboold4 operator !=(const vllong4& a, const vllong4& b) { return !(a == b); } + __forceinline vboold4 operator > (const vllong4& a, const vllong4& b) { return _mm256_cmpgt_epi64(a,b); } + __forceinline vboold4 operator < (const vllong4& a, const vllong4& b) { return _mm256_cmpgt_epi64(b,a); } + __forceinline vboold4 operator >=(const vllong4& a, const vllong4& b) { return !(a < b); } + __forceinline vboold4 operator <=(const vllong4& a, const vllong4& b) { return !(a > b); } +#endif + + __forceinline vboold4 operator ==(const vllong4& a, long long b) { return a == vllong4(b); } + __forceinline vboold4 operator ==(long long a, const vllong4& b) { return vllong4(a) == b; } + + __forceinline vboold4 operator !=(const vllong4& a, long long b) { return a != vllong4(b); } + __forceinline vboold4 operator !=(long long a, const vllong4& b) { return vllong4(a) != b; } + + __forceinline vboold4 operator > (const vllong4& a, long long b) { return a > vllong4(b); } + __forceinline vboold4 operator > (long long a, const vllong4& b) { return vllong4(a) > b; } + + __forceinline vboold4 operator < (const vllong4& a, long long b) { return a < vllong4(b); } + __forceinline vboold4 operator < (long long a, const vllong4& b) { return vllong4(a) < b; } + + __forceinline vboold4 operator >=(const vllong4& a, long long b) { return a >= vllong4(b); } + __forceinline vboold4 operator >=(long long a, const vllong4& b) { return vllong4(a) >= b; } + + __forceinline vboold4 operator <=(const vllong4& a, long long b) { return a <= vllong4(b); } + __forceinline vboold4 operator <=(long long a, const vllong4& b) { return vllong4(a) <= b; } + + __forceinline vboold4 eq(const vllong4& a, const vllong4& b) { return a == b; } + __forceinline vboold4 ne(const vllong4& a, const vllong4& b) { return a != b; } + __forceinline vboold4 lt(const vllong4& a, const vllong4& b) { return a < b; } + __forceinline vboold4 ge(const vllong4& a, const vllong4& b) { return a >= b; } + __forceinline vboold4 gt(const vllong4& a, const vllong4& b) { return a > b; } + __forceinline vboold4 le(const vllong4& a, const vllong4& b) { return a <= b; } + +#if defined(__AVX512VL__) + __forceinline vboold4 eq(const vboold4& mask, const vllong4& a, const vllong4& b) { return _mm256_mask_cmp_epi64_mask(mask, a, b, _MM_CMPINT_EQ); } + __forceinline vboold4 ne(const vboold4& mask, const vllong4& a, const vllong4& b) { return _mm256_mask_cmp_epi64_mask(mask, a, b, _MM_CMPINT_NE); } + __forceinline vboold4 lt(const vboold4& mask, const vllong4& a, const vllong4& b) { return _mm256_mask_cmp_epi64_mask(mask, a, b, _MM_CMPINT_LT); } + __forceinline vboold4 ge(const vboold4& mask, const vllong4& a, const vllong4& b) { return _mm256_mask_cmp_epi64_mask(mask, a, b, _MM_CMPINT_GE); } + __forceinline vboold4 gt(const vboold4& mask, const vllong4& a, const vllong4& b) { return 
_mm256_mask_cmp_epi64_mask(mask, a, b, _MM_CMPINT_GT); } + __forceinline vboold4 le(const vboold4& mask, const vllong4& a, const vllong4& b) { return _mm256_mask_cmp_epi64_mask(mask, a, b, _MM_CMPINT_LE); } +#else + __forceinline vboold4 eq(const vboold4& mask, const vllong4& a, const vllong4& b) { return mask & (a == b); } + __forceinline vboold4 ne(const vboold4& mask, const vllong4& a, const vllong4& b) { return mask & (a != b); } + __forceinline vboold4 lt(const vboold4& mask, const vllong4& a, const vllong4& b) { return mask & (a < b); } + __forceinline vboold4 ge(const vboold4& mask, const vllong4& a, const vllong4& b) { return mask & (a >= b); } + __forceinline vboold4 gt(const vboold4& mask, const vllong4& a, const vllong4& b) { return mask & (a > b); } + __forceinline vboold4 le(const vboold4& mask, const vllong4& a, const vllong4& b) { return mask & (a <= b); } +#endif + + __forceinline void xchg(const vboold4& m, vllong4& a, vllong4& b) { + const vllong4 c = a; a = select(m,b,a); b = select(m,c,b); + } + + __forceinline vboold4 test(const vllong4& a, const vllong4& b) { +#if defined(__AVX512VL__) + return _mm256_test_epi64_mask(a,b); +#else + return _mm256_testz_si256(a,b); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + template<int i0, int i1> + __forceinline vllong4 shuffle(const vllong4& v) { + return _mm256_castpd_si256(_mm256_permute_pd(_mm256_castsi256_pd(v), (i1 << 3) | (i0 << 2) | (i1 << 1) | i0)); + } + + template<int i> + __forceinline vllong4 shuffle(const vllong4& v) { + return shuffle<i,i>(v); + } + + template<int i0, int i1> + __forceinline vllong4 shuffle2(const vllong4& v) { + return _mm256_castpd_si256(_mm256_permute2f128_pd(_mm256_castsi256_pd(v), _mm256_castsi256_pd(v), (i1 << 4) | i0)); + } + + __forceinline long long toScalar(const vllong4& v) { + return _mm_cvtsi128_si64(_mm256_castsi256_si128(v)); + } + +#if defined(__AVX512VL__) + __forceinline vllong4 permute(const vllong4& a, const __m256i& index) { + // workaround for GCC 7.x +#if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__) + return _mm256_permutex2var_epi64(a,index,a); +#else + return _mm256_permutexvar_epi64(index,a); +#endif + } + + __forceinline vllong4 permutex2var(const vllong4& index, const vllong4& a, const vllong4& b) { + return _mm256_permutex2var_epi64(a,index,b); + } + +#endif + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + + __forceinline vllong4 vreduce_and2(const vllong4& x) { return x & shuffle<1,0>(x); } + __forceinline vllong4 vreduce_and (const vllong4& y) { const vllong4 x = vreduce_and2(y); return x & shuffle2<1,0>(x); } + + __forceinline vllong4 vreduce_or2(const vllong4& x) { return x | shuffle<1,0>(x); } + __forceinline vllong4 vreduce_or (const vllong4& y) { const vllong4 x = vreduce_or2(y); return x | shuffle2<1,0>(x); } + + __forceinline vllong4 vreduce_add2(const vllong4& x) { return x + shuffle<1,0>(x); } + __forceinline vllong4 vreduce_add (const vllong4& y) { const vllong4 x = vreduce_add2(y); return x + shuffle2<1,0>(x); } + + __forceinline long long reduce_add(const vllong4& a) { return toScalar(vreduce_add(a)); } + __forceinline long long reduce_or (const vllong4& a) { return toScalar(vreduce_or(a)); } + __forceinline long long reduce_and(const vllong4& a) { return 
toScalar(vreduce_and(a)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vllong4& v) + { + cout << "<" << v[0]; + for (size_t i=1; i<4; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git a/thirdparty/embree/common/simd/vllong8_avx512.h b/thirdparty/embree/common/simd/vllong8_avx512.h new file mode 100644 index 000000000000..4a724de0621b --- /dev/null +++ b/thirdparty/embree/common/simd/vllong8_avx512.h @@ -0,0 +1,381 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX-512 64-bit long long type */ + template<> + struct vllong<8> + { + ALIGNED_STRUCT_(64); + + typedef vboold8 Bool; + + enum { size = 8 }; // number of SIMD elements + union { // data + __m512i v; + long long i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong() {} + __forceinline vllong(const vllong8& t) { v = t.v; } + __forceinline vllong8& operator =(const vllong8& f) { v = f.v; return *this; } + + __forceinline vllong(const __m512i& t) { v = t; } + __forceinline operator __m512i() const { return v; } + __forceinline operator __m256i() const { return _mm512_castsi512_si256(v); } + + __forceinline vllong(long long i) { + v = _mm512_set1_epi64(i); + } + + __forceinline vllong(long long a, long long b, long long c, long long d) { + v = _mm512_set4_epi64(d,c,b,a); + } + + __forceinline vllong(long long a0, long long a1, long long a2, long long a3, + long long a4, long long a5, long long a6, long long a7) + { + v = _mm512_set_epi64(a7,a6,a5,a4,a3,a2,a1,a0); + } + + __forceinline vllong(const vllong<4>& i) { + v = _mm512_broadcast_i64x4(i); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong(ZeroTy) : v(_mm512_setzero_epi32()) {} + __forceinline vllong(OneTy) : v(_mm512_set1_epi64(1)) {} + __forceinline vllong(StepTy) : v(_mm512_set_epi64(7,6,5,4,3,2,1,0)) {} + __forceinline vllong(ReverseStepTy) : v(_mm512_setr_epi64(7,6,5,4,3,2,1,0)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline void store_nt(void* __restrict__ ptr, const vllong8& a) { + _mm512_stream_si512((__m512i*)ptr,a); + } + + static __forceinline vllong8 loadu(const void* addr) { + return _mm512_loadu_si512(addr); + } + + static __forceinline vllong8 load(const vllong8* addr) { + return _mm512_load_si512(addr); + } + + static __forceinline vllong8 load(const long long* addr) { + return _mm512_load_si512(addr); + } + + static __forceinline vllong8 load(const unsigned char* ptr) { + return _mm512_cvtepu8_epi64(*(__m128i*)ptr); + } + + static __forceinline void store(void* ptr, const vllong8& v) { + _mm512_store_si512(ptr,v); + } + + static __forceinline void storeu(void* ptr, const vllong8& v) { + _mm512_storeu_si512(ptr,v); + } + + static __forceinline void storeu(const vboold8& mask, long long* ptr, const vllong8& f) { + 
_mm512_mask_storeu_epi64(ptr,mask,f); + } + + static __forceinline void store(const vboold8& mask, void* addr, const vllong8& v2) { + _mm512_mask_store_epi64(addr,mask,v2); + } + + /* pass by value to avoid compiler generating inefficient code */ + static __forceinline void storeu_compact(const vboold8 mask, void* addr, const vllong8& reg) { + _mm512_mask_compressstoreu_epi64(addr,mask,reg); + } + + static __forceinline vllong8 compact64bit(const vboold8& mask, vllong8& v) { + return _mm512_mask_compress_epi64(v,mask,v); + } + + static __forceinline vllong8 compact64bit(const vboold8& mask, vllong8& dest, const vllong8& source) { + return _mm512_mask_compress_epi64(dest,mask,source); + } + + static __forceinline vllong8 compact(const vboold8& mask, vllong8& v) { + return _mm512_mask_compress_epi64(v,mask,v); + } + + static __forceinline vllong8 compact(const vboold8& mask, const vllong8& a, vllong8& b) { + return _mm512_mask_compress_epi64(a,mask,b); + } + + static __forceinline vllong8 expand(const vboold8& mask, const vllong8& a, vllong8& b) { + return _mm512_mask_expand_epi64(b,mask,a); + } + + static __forceinline vllong8 broadcast64bit(size_t v) { + return _mm512_set1_epi64(v); + } + + static __forceinline size_t extract64bit(const vllong8& v) + { + return _mm_cvtsi128_si64(_mm512_castsi512_si128(v)); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline long long& operator [](size_t index) { assert(index < 8); return i[index]; } + __forceinline const long long& operator [](size_t index) const { assert(index < 8); return i[index]; } + + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8 asBool(const vllong8& a) { return _mm512_movepi64_mask(a); } + + __forceinline vllong8 operator +(const vllong8& a) { return a; } + __forceinline vllong8 operator -(const vllong8& a) { return _mm512_sub_epi64(_mm512_setzero_epi32(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong8 operator +(const vllong8& a, const vllong8& b) { return _mm512_add_epi64(a, b); } + __forceinline vllong8 operator +(const vllong8& a, long long b) { return a + vllong8(b); } + __forceinline vllong8 operator +(long long a, const vllong8& b) { return vllong8(a) + b; } + + __forceinline vllong8 operator -(const vllong8& a, const vllong8& b) { return _mm512_sub_epi64(a, b); } + __forceinline vllong8 operator -(const vllong8& a, long long b) { return a - vllong8(b); } + __forceinline vllong8 operator -(long long a, const vllong8& b) { return vllong8(a) - b; } + + __forceinline vllong8 operator *(const vllong8& a, const vllong8& b) { return _mm512_mullo_epi64(a, b); } + __forceinline vllong8 operator *(const vllong8& a, long long b) { return a * vllong8(b); } + __forceinline vllong8 operator *(long long a, const vllong8& b) { return vllong8(a) * b; } + + __forceinline vllong8 operator &(const vllong8& a, const vllong8& b) { return _mm512_and_epi64(a, b); } + __forceinline vllong8 operator &(const vllong8& a, long long b) { return a & vllong8(b); } + __forceinline vllong8 operator &(long long a, const vllong8& b) { return vllong8(a) & b; } + 
+ __forceinline vllong8 operator |(const vllong8& a, const vllong8& b) { return _mm512_or_epi64(a, b); } + __forceinline vllong8 operator |(const vllong8& a, long long b) { return a | vllong8(b); } + __forceinline vllong8 operator |(long long a, const vllong8& b) { return vllong8(a) | b; } + + __forceinline vllong8 operator ^(const vllong8& a, const vllong8& b) { return _mm512_xor_epi64(a, b); } + __forceinline vllong8 operator ^(const vllong8& a, long long b) { return a ^ vllong8(b); } + __forceinline vllong8 operator ^(long long a, const vllong8& b) { return vllong8(a) ^ b; } + + __forceinline vllong8 operator <<(const vllong8& a, long long n) { return _mm512_slli_epi64(a, n); } + __forceinline vllong8 operator >>(const vllong8& a, long long n) { return _mm512_srai_epi64(a, n); } + + __forceinline vllong8 operator <<(const vllong8& a, const vllong8& n) { return _mm512_sllv_epi64(a, n); } + __forceinline vllong8 operator >>(const vllong8& a, const vllong8& n) { return _mm512_srav_epi64(a, n); } + + __forceinline vllong8 sll (const vllong8& a, long long b) { return _mm512_slli_epi64(a, b); } + __forceinline vllong8 sra (const vllong8& a, long long b) { return _mm512_srai_epi64(a, b); } + __forceinline vllong8 srl (const vllong8& a, long long b) { return _mm512_srli_epi64(a, b); } + + __forceinline vllong8 min(const vllong8& a, const vllong8& b) { return _mm512_min_epi64(a, b); } + __forceinline vllong8 min(const vllong8& a, long long b) { return min(a,vllong8(b)); } + __forceinline vllong8 min(long long a, const vllong8& b) { return min(vllong8(a),b); } + + __forceinline vllong8 max(const vllong8& a, const vllong8& b) { return _mm512_max_epi64(a, b); } + __forceinline vllong8 max(const vllong8& a, long long b) { return max(a,vllong8(b)); } + __forceinline vllong8 max(long long a, const vllong8& b) { return max(vllong8(a),b); } + + __forceinline vllong8 mask_add(const vboold8& m, const vllong8& c, const vllong8& a, const vllong8& b) { return _mm512_mask_add_epi64(c,m,a,b); } + __forceinline vllong8 mask_sub(const vboold8& m, const vllong8& c, const vllong8& a, const vllong8& b) { return _mm512_mask_sub_epi64(c,m,a,b); } + + __forceinline vllong8 mask_and(const vboold8& m, const vllong8& c, const vllong8& a, const vllong8& b) { return _mm512_mask_and_epi64(c,m,a,b); } + __forceinline vllong8 mask_or (const vboold8& m, const vllong8& c, const vllong8& a, const vllong8& b) { return _mm512_mask_or_epi64(c,m,a,b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong8& operator +=(vllong8& a, const vllong8& b) { return a = a + b; } + __forceinline vllong8& operator +=(vllong8& a, long long b) { return a = a + b; } + + __forceinline vllong8& operator -=(vllong8& a, const vllong8& b) { return a = a - b; } + __forceinline vllong8& operator -=(vllong8& a, long long b) { return a = a - b; } + + __forceinline vllong8& operator *=(vllong8& a, const vllong8& b) { return a = a * b; } + __forceinline vllong8& operator *=(vllong8& a, long long b) { return a = a * b; } + + __forceinline vllong8& operator &=(vllong8& a, const vllong8& b) { return a = a & b; } + __forceinline vllong8& operator &=(vllong8& a, long long b) { return a = a & b; } + + __forceinline vllong8& operator |=(vllong8& a, const vllong8& b) { return a = a | b; } + __forceinline vllong8& operator |=(vllong8& a, long long b) { return a = a | b; } + + __forceinline vllong8& 
operator <<=(vllong8& a, long long b) { return a = a << b; } + __forceinline vllong8& operator >>=(vllong8& a, long long b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboold8 operator ==(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboold8 operator ==(const vllong8& a, long long b) { return a == vllong8(b); } + __forceinline vboold8 operator ==(long long a, const vllong8& b) { return vllong8(a) == b; } + + __forceinline vboold8 operator !=(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboold8 operator !=(const vllong8& a, long long b) { return a != vllong8(b); } + __forceinline vboold8 operator !=(long long a, const vllong8& b) { return vllong8(a) != b; } + + __forceinline vboold8 operator < (const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboold8 operator < (const vllong8& a, long long b) { return a < vllong8(b); } + __forceinline vboold8 operator < (long long a, const vllong8& b) { return vllong8(a) < b; } + + __forceinline vboold8 operator >=(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboold8 operator >=(const vllong8& a, long long b) { return a >= vllong8(b); } + __forceinline vboold8 operator >=(long long a, const vllong8& b) { return vllong8(a) >= b; } + + __forceinline vboold8 operator > (const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboold8 operator > (const vllong8& a, long long b) { return a > vllong8(b); } + __forceinline vboold8 operator > (long long a, const vllong8& b) { return vllong8(a) > b; } + + __forceinline vboold8 operator <=(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboold8 operator <=(const vllong8& a, long long b) { return a <= vllong8(b); } + __forceinline vboold8 operator <=(long long a, const vllong8& b) { return vllong8(a) <= b; } + + __forceinline vboold8 eq(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboold8 ne(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboold8 lt(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboold8 ge(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboold8 gt(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboold8 le(const vllong8& a, const vllong8& b) { return _mm512_cmp_epi64_mask(a,b,_MM_CMPINT_LE); } + + __forceinline vboold8 eq(const vboold8 mask, const vllong8& a, const vllong8& b) { return _mm512_mask_cmp_epi64_mask(mask,a,b,_MM_CMPINT_EQ); } + __forceinline vboold8 ne(const vboold8 mask, const vllong8& a, const vllong8& b) { return _mm512_mask_cmp_epi64_mask(mask,a,b,_MM_CMPINT_NE); } + __forceinline vboold8 lt(const vboold8 mask, const vllong8& a, const vllong8& b) { return _mm512_mask_cmp_epi64_mask(mask,a,b,_MM_CMPINT_LT); } + __forceinline vboold8 ge(const vboold8 mask, const vllong8& a, const vllong8& b) { return _mm512_mask_cmp_epi64_mask(mask,a,b,_MM_CMPINT_GE); } + __forceinline 
vboold8 gt(const vboold8 mask, const vllong8& a, const vllong8& b) { return _mm512_mask_cmp_epi64_mask(mask,a,b,_MM_CMPINT_GT); } + __forceinline vboold8 le(const vboold8 mask, const vllong8& a, const vllong8& b) { return _mm512_mask_cmp_epi64_mask(mask,a,b,_MM_CMPINT_LE); } + + __forceinline vllong8 select(const vboold8& m, const vllong8& t, const vllong8& f) { + return _mm512_mask_or_epi64(f,m,t,t); + } + + __forceinline void xchg(const vboold8& m, vllong8& a, vllong8& b) { + const vllong8 c = a; a = select(m,b,a); b = select(m,c,b); + } + + __forceinline vboold8 test(const vboold8& m, const vllong8& a, const vllong8& b) { + return _mm512_mask_test_epi64_mask(m,a,b); + } + + __forceinline vboold8 test(const vllong8& a, const vllong8& b) { + return _mm512_test_epi64_mask(a,b); + } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + template<int i0, int i1> + __forceinline vllong8 shuffle(const vllong8& v) { + return _mm512_castpd_si512(_mm512_permute_pd(_mm512_castsi512_pd(v), (i1 << 7) | (i0 << 6) | (i1 << 5) | (i0 << 4) | (i1 << 3) | (i0 << 2) | (i1 << 1) | i0)); + } + + template<int i> + __forceinline vllong8 shuffle(const vllong8& v) { + return shuffle<i, i>(v); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vllong8 shuffle(const vllong8& v) { + return _mm512_permutex_epi64(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<int i0, int i1> + __forceinline vllong8 shuffle4(const vllong8& v) { + return _mm512_shuffle_i64x2(v, v, _MM_SHUFFLE(i1*2+1, i1*2, i0*2+1, i0*2)); + } + + template<int i> + __forceinline vllong8 shuffle4(const vllong8& v) { + return shuffle4<i, i>(v); + } + + template<int i> + __forceinline vllong8 align_shift_right(const vllong8& a, const vllong8& b) { + return _mm512_alignr_epi64(a, b, i); + }; + + __forceinline long long toScalar(const vllong8& v) { + return _mm_cvtsi128_si64(_mm512_castsi512_si128(v)); + } + + __forceinline vllong8 zeroExtend32Bit(const __m512i& a) { + return _mm512_cvtepu32_epi64(_mm512_castsi512_si256(a)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong8 vreduce_min2(vllong8 x) { return min(x, shuffle<1,0,3,2>(x)); } + __forceinline vllong8 vreduce_min4(vllong8 x) { x = vreduce_min2(x); return min(x, shuffle<2,3,0,1>(x)); } + __forceinline vllong8 vreduce_min (vllong8 x) { x = vreduce_min4(x); return min(x, shuffle4<1,0>(x)); } + + __forceinline vllong8 vreduce_max2(vllong8 x) { return max(x, shuffle<1,0,3,2>(x)); } + __forceinline vllong8 vreduce_max4(vllong8 x) { x = vreduce_max2(x); return max(x, shuffle<2,3,0,1>(x)); } + __forceinline vllong8 vreduce_max (vllong8 x) { x = vreduce_max4(x); return max(x, shuffle4<1,0>(x)); } + + __forceinline vllong8 vreduce_and2(vllong8 x) { return x & shuffle<1,0,3,2>(x); } + __forceinline vllong8 vreduce_and4(vllong8 x) { x = vreduce_and2(x); return x & shuffle<2,3,0,1>(x); } + __forceinline vllong8 vreduce_and (vllong8 x) { x = vreduce_and4(x); return x & shuffle4<1,0>(x); } + + __forceinline vllong8 vreduce_or2(vllong8 x) { return x | shuffle<1,0,3,2>(x); } + __forceinline vllong8 vreduce_or4(vllong8 x) { x = vreduce_or2(x); return x | shuffle<2,3,0,1>(x); } + __forceinline vllong8 vreduce_or (vllong8 x) { x = vreduce_or4(x); return x | shuffle4<1,0>(x); } + + __forceinline vllong8 vreduce_add2(vllong8 x) { return x + shuffle<1,0,3,2>(x); } + __forceinline
vllong8 vreduce_add4(vllong8 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); } + __forceinline vllong8 vreduce_add (vllong8 x) { x = vreduce_add4(x); return x + shuffle4<1,0>(x); } + + __forceinline long long reduce_min(const vllong8& v) { return toScalar(vreduce_min(v)); } + __forceinline long long reduce_max(const vllong8& v) { return toScalar(vreduce_max(v)); } + __forceinline long long reduce_and(const vllong8& v) { return toScalar(vreduce_and(v)); } + __forceinline long long reduce_or (const vllong8& v) { return toScalar(vreduce_or (v)); } + __forceinline long long reduce_add(const vllong8& v) { return toScalar(vreduce_add(v)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Memory load and store operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vllong8 permute(const vllong8& v, const vllong8& index) { + return _mm512_permutexvar_epi64(index,v); + } + + __forceinline vllong8 reverse(const vllong8& a) { + return permute(a,vllong8(reverse_step)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vllong8& v) + { + cout << "<" << v[0]; + for (size_t i=1; i<8; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git a/thirdparty/embree/common/simd/vuint16_avx512.h b/thirdparty/embree/common/simd/vuint16_avx512.h new file mode 100644 index 000000000000..c5a2bb047840 --- /dev/null +++ b/thirdparty/embree/common/simd/vuint16_avx512.h @@ -0,0 +1,443 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 16-wide AVX-512 unsigned integer type */ + template<> + struct vuint<16> + { + ALIGNED_STRUCT_(64); + + typedef vboolf16 Bool; + typedef vuint16 UInt; + typedef vfloat16 Float; + + enum { size = 16 }; // number of SIMD elements + union { // data + __m512i v; + unsigned int i[16]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint() {} + __forceinline vuint(const vuint16& t) { v = t.v; } + __forceinline vuint16& operator =(const vuint16& f) { v = f.v; return *this; } + + __forceinline vuint(const __m512i& t) { v = t; } + __forceinline operator __m512i() const { return v; } + __forceinline operator __m256i() const { return _mm512_castsi512_si256(v); } + + __forceinline vuint(unsigned int i) { + v = _mm512_set1_epi32(i); + } + + __forceinline vuint(const vuint4& i) { + v = _mm512_broadcast_i32x4(i); + } + + __forceinline vuint(const vuint8& i) { + v = _mm512_castps_si512(_mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castsi256_pd(i)))); + } + + __forceinline vuint(unsigned int a, unsigned int b, unsigned int c, unsigned int d) { + v = _mm512_set4_epi32(d,c,b,a); + } + + __forceinline vuint(unsigned int a0 , unsigned int a1 , unsigned int a2 , unsigned int a3, + unsigned int a4 , unsigned int a5 , unsigned int a6 , unsigned int a7, + unsigned int a8 , unsigned int a9 , unsigned int a10, unsigned int a11, + unsigned int a12, unsigned int a13, unsigned int a14, unsigned int a15) + { + v = _mm512_set_epi32(a15,a14,a13,a12,a11,a10,a9,a8,a7,a6,a5,a4,a3,a2,a1,a0); + } + + __forceinline explicit 
vuint(const __m512& f) { + v = _mm512_cvtps_epu32(f); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint(ZeroTy) : v(_mm512_setzero_epi32()) {} + __forceinline vuint(OneTy) : v(_mm512_set1_epi32(1)) {} + __forceinline vuint(StepTy) : v(_mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {} + __forceinline vuint(ReverseStepTy) : v(_mm512_setr_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline void store_nt(void* __restrict__ ptr, const vuint16& a) { + _mm512_stream_si512((__m512i*)ptr,a); + } + + static __forceinline vuint16 loadu(const void* addr) + { + return _mm512_loadu_si512(addr); + } + + static __forceinline vuint16 loadu(const unsigned char* ptr) { return _mm512_cvtepu8_epi32(_mm_loadu_si128((__m128i*)ptr)); } + static __forceinline vuint16 loadu(const unsigned short* ptr) { return _mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i*)ptr)); } + + static __forceinline vuint16 load(const vuint16* addr) { + return _mm512_load_si512(addr); + } + + static __forceinline vuint16 load(const unsigned int* addr) { + return _mm512_load_si512(addr); + } + + static __forceinline vuint16 load(unsigned short* ptr) { return _mm512_cvtepu16_epi32(*(__m256i*)ptr); } + + + static __forceinline void store(void* ptr, const vuint16& v) { + _mm512_store_si512(ptr,v); + } + + static __forceinline void storeu(void* ptr, const vuint16& v) { + _mm512_storeu_si512(ptr,v); + } + + static __forceinline void storeu(const vboolf16& mask, void* ptr, const vuint16& f) { + _mm512_mask_storeu_epi32(ptr,mask,f); + } + + static __forceinline void store(const vboolf16& mask, void* addr, const vuint16& v2) { + _mm512_mask_store_epi32(addr,mask,v2); + } + + /* pass by value to avoid compiler generating inefficient code */ + static __forceinline void storeu_compact(const vboolf16 mask, void* addr, const vuint16 reg) { + _mm512_mask_compressstoreu_epi32(addr,mask,reg); + } + + static __forceinline void storeu_compact_single(const vboolf16 mask, void* addr, vuint16 reg) { + //_mm512_mask_compressstoreu_epi32(addr,mask,reg); + *(float*)addr = mm512_cvtss_f32(_mm512_mask_compress_ps(_mm512_castsi512_ps(reg),mask,_mm512_castsi512_ps(reg))); + } + + static __forceinline vuint16 compact64bit(const vboolf16& mask, vuint16& v) { + return _mm512_mask_compress_epi64(v,mask,v); + } + + static __forceinline vuint16 compact(const vboolf16& mask, vuint16& v) { + return _mm512_mask_compress_epi32(v,mask,v); + } + + static __forceinline vuint16 compact(const vboolf16& mask, const vuint16& a, vuint16& b) { + return _mm512_mask_compress_epi32(a,mask,b); + } + + static __forceinline vuint16 expand(const vboolf16& mask, const vuint16& a, vuint16& b) { + return _mm512_mask_expand_epi32(b,mask,a); + } + + template<int scale> + static __forceinline vuint16 gather(const unsigned int* ptr, const vint16& index) { + return _mm512_i32gather_epi32(index,ptr,scale); + } + + template<int scale> + static __forceinline vuint16 gather(const vboolf16& mask, const unsigned int* ptr, const vint16& index) { + return _mm512_mask_i32gather_epi32(_mm512_undefined_epi32(),mask,index,ptr,scale); + } + + template<int scale> + static __forceinline vuint16 gather(const vboolf16& mask, vuint16& dest, const unsigned int* ptr, const
vint16& index) { + return _mm512_mask_i32gather_epi32(dest,mask,index,ptr,scale); + } + + template<int scale> + static __forceinline void scatter(unsigned int* ptr, const vint16& index, const vuint16& v) { + _mm512_i32scatter_epi32((int*)ptr,index,v,scale); + } + + template<int scale> + static __forceinline void scatter(const vboolf16& mask, unsigned int* ptr, const vint16& index, const vuint16& v) { + _mm512_mask_i32scatter_epi32((int*)ptr,mask,index,v,scale); + } + + static __forceinline vuint16 broadcast64bit(size_t v) { + return _mm512_set1_epi64(v); + } + + static __forceinline size_t extract64bit(const vuint16& v) + { + return _mm_cvtsi128_si64(_mm512_castsi512_si128(v)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline unsigned int& operator [](size_t index) { assert(index < 16); return i[index]; } + __forceinline const unsigned int& operator [](size_t index) const { assert(index < 16); return i[index]; } + + __forceinline unsigned int uint (size_t index) const { assert(index < 16); return ((unsigned int*)i)[index]; } + __forceinline size_t& uint64_t(size_t index) const { assert(index < 8); return ((size_t*)i)[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf16 asBool(const vuint16& a) { return _mm512_movepi32_mask(a); } + + __forceinline vuint16 operator +(const vuint16& a) { return a; } + __forceinline vuint16 operator -(const vuint16& a) { return _mm512_sub_epi32(_mm512_setzero_epi32(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint16 operator +(const vuint16& a, const vuint16& b) { return _mm512_add_epi32(a, b); } + __forceinline vuint16 operator +(const vuint16& a, unsigned int b) { return a + vuint16(b); } + __forceinline vuint16 operator +(unsigned int a, const vuint16& b) { return vuint16(a) + b; } + + __forceinline vuint16 operator -(const vuint16& a, const vuint16& b) { return _mm512_sub_epi32(a, b); } + __forceinline vuint16 operator -(const vuint16& a, unsigned int b) { return a - vuint16(b); } + __forceinline vuint16 operator -(unsigned int a, const vuint16& b) { return vuint16(a) - b; } + + __forceinline vuint16 operator *(const vuint16& a, const vuint16& b) { return _mm512_mul_epu32(a, b); } + __forceinline vuint16 operator *(const vuint16& a, unsigned int b) { return a * vuint16(b); } + __forceinline vuint16 operator *(unsigned int a, const vuint16& b) { return vuint16(a) * b; } + + __forceinline vuint16 operator &(const vuint16& a, const vuint16& b) { return _mm512_and_epi32(a, b); } + __forceinline vuint16 operator &(const vuint16& a, unsigned int b) { return a & vuint16(b); } + __forceinline vuint16 operator &(unsigned int a, const vuint16& b) { return vuint16(a) & b; } + + __forceinline vuint16 operator |(const vuint16& a, const vuint16& b) { return _mm512_or_epi32(a, b); } + __forceinline vuint16 operator |(const vuint16& a, unsigned int b) { return a | vuint16(b); } + __forceinline vuint16 operator |(unsigned int a, const vuint16& b) { return vuint16(a) | b; } + + __forceinline vuint16 operator ^(const vuint16& a, const vuint16& b) { return _mm512_xor_epi32(a, b); } + 
__forceinline vuint16 operator ^(const vuint16& a, unsigned int b) { return a ^ vuint16(b); } + __forceinline vuint16 operator ^(unsigned int a, const vuint16& b) { return vuint16(a) ^ b; } + + __forceinline vuint16 operator <<(const vuint16& a, unsigned int n) { return _mm512_slli_epi32(a, n); } + __forceinline vuint16 operator >>(const vuint16& a, unsigned int n) { return _mm512_srli_epi32(a, n); } + + __forceinline vuint16 operator <<(const vuint16& a, const vuint16& n) { return _mm512_sllv_epi32(a, n); } + __forceinline vuint16 operator >>(const vuint16& a, const vuint16& n) { return _mm512_srlv_epi32(a, n); } + + __forceinline vuint16 sll (const vuint16& a, unsigned int b) { return _mm512_slli_epi32(a, b); } + __forceinline vuint16 sra (const vuint16& a, unsigned int b) { return _mm512_srai_epi32(a, b); } + __forceinline vuint16 srl (const vuint16& a, unsigned int b) { return _mm512_srli_epi32(a, b); } + + __forceinline vuint16 min(const vuint16& a, const vuint16& b) { return _mm512_min_epu32(a, b); } + __forceinline vuint16 min(const vuint16& a, unsigned int b) { return min(a,vuint16(b)); } + __forceinline vuint16 min(unsigned int a, const vuint16& b) { return min(vuint16(a),b); } + + __forceinline vuint16 max(const vuint16& a, const vuint16& b) { return _mm512_max_epu32(a, b); } + __forceinline vuint16 max(const vuint16& a, unsigned int b) { return max(a,vuint16(b)); } + __forceinline vuint16 max(unsigned int a, const vuint16& b) { return max(vuint16(a),b); } + + __forceinline vuint16 mask_add(const vboolf16& mask, vuint16& c, const vuint16& a, const vuint16& b) { return _mm512_mask_add_epi32(c,mask,a,b); } + __forceinline vuint16 mask_sub(const vboolf16& mask, vuint16& c, const vuint16& a, const vuint16& b) { return _mm512_mask_sub_epi32(c,mask,a,b); } + + __forceinline vuint16 mask_and(const vboolf16& m, vuint16& c, const vuint16& a, const vuint16& b) { return _mm512_mask_and_epi32(c,m,a,b); } + __forceinline vuint16 mask_or (const vboolf16& m, vuint16& c, const vuint16& a, const vuint16& b) { return _mm512_mask_or_epi32(c,m,a,b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint16& operator +=(vuint16& a, const vuint16& b) { return a = a + b; } + __forceinline vuint16& operator +=(vuint16& a, unsigned int b) { return a = a + b; } + + __forceinline vuint16& operator -=(vuint16& a, const vuint16& b) { return a = a - b; } + __forceinline vuint16& operator -=(vuint16& a, unsigned int b) { return a = a - b; } + + __forceinline vuint16& operator *=(vuint16& a, const vuint16& b) { return a = a * b; } + __forceinline vuint16& operator *=(vuint16& a, unsigned int b) { return a = a * b; } + + __forceinline vuint16& operator &=(vuint16& a, const vuint16& b) { return a = a & b; } + __forceinline vuint16& operator &=(vuint16& a, unsigned int b) { return a = a & b; } + + __forceinline vuint16& operator |=(vuint16& a, const vuint16& b) { return a = a | b; } + __forceinline vuint16& operator |=(vuint16& a, unsigned int b) { return a = a | b; } + + __forceinline vuint16& operator <<=(vuint16& a, unsigned int b) { return a = a << b; } + __forceinline vuint16& operator >>=(vuint16& a, unsigned int b) { return a = a >> b; } + + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + 
__forceinline vboolf16 operator ==(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 operator ==(const vuint16& a, unsigned int b) { return a == vuint16(b); } + __forceinline vboolf16 operator ==(unsigned int a, const vuint16& b) { return vuint16(a) == b; } + + __forceinline vboolf16 operator !=(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 operator !=(const vuint16& a, unsigned int b) { return a != vuint16(b); } + __forceinline vboolf16 operator !=(unsigned int a, const vuint16& b) { return vuint16(a) != b; } + + __forceinline vboolf16 operator < (const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 operator < (const vuint16& a, unsigned int b) { return a < vuint16(b); } + __forceinline vboolf16 operator < (unsigned int a, const vuint16& b) { return vuint16(a) < b; } + + __forceinline vboolf16 operator >=(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 operator >=(const vuint16& a, unsigned int b) { return a >= vuint16(b); } + __forceinline vboolf16 operator >=(unsigned int a, const vuint16& b) { return vuint16(a) >= b; } + + __forceinline vboolf16 operator > (const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 operator > (const vuint16& a, unsigned int b) { return a > vuint16(b); } + __forceinline vboolf16 operator > (unsigned int a, const vuint16& b) { return vuint16(a) > b; } + + __forceinline vboolf16 operator <=(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_LE); } + __forceinline vboolf16 operator <=(const vuint16& a, unsigned int b) { return a <= vuint16(b); } + __forceinline vboolf16 operator <=(unsigned int a, const vuint16& b) { return vuint16(a) <= b; } + + __forceinline vboolf16 eq(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 ne(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 lt(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 ge(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 gt(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 le(const vuint16& a, const vuint16& b) { return _mm512_cmp_epu32_mask(a,b,_MM_CMPINT_LE); } + + __forceinline vboolf16 eq(const vboolf16 mask, const vuint16& a, const vuint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_EQ); } + __forceinline vboolf16 ne(const vboolf16 mask, const vuint16& a, const vuint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_NE); } + __forceinline vboolf16 lt(const vboolf16 mask, const vuint16& a, const vuint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_LT); } + __forceinline vboolf16 ge(const vboolf16 mask, const vuint16& a, const vuint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_GE); } + __forceinline vboolf16 gt(const vboolf16 mask, const vuint16& a, const vuint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_GT); } + __forceinline vboolf16 le(const vboolf16 mask, const vuint16& a, const vuint16& b) { return _mm512_mask_cmp_epu32_mask(mask,a,b,_MM_CMPINT_LE); } + + + 
__forceinline vuint16 select(const vboolf16& m, const vuint16& t, const vuint16& f) { + return _mm512_mask_or_epi32(f,m,t,t); + } + + __forceinline void xchg(const vboolf16& m, vuint16& a, vuint16& b) { + const vuint16 c = a; a = select(m,b,a); b = select(m,c,b); + } + + __forceinline vboolf16 test(const vboolf16& m, const vuint16& a, const vuint16& b) { + return _mm512_mask_test_epi32_mask(m,a,b); + } + + __forceinline vboolf16 test(const vuint16& a, const vuint16& b) { + return _mm512_test_epi32_mask(a,b); + } + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + template<int i> + __forceinline vuint16 shuffle(const vuint16& v) { + return _mm512_castps_si512(_mm512_permute_ps(_mm512_castsi512_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vuint16 shuffle(const vuint16& v) { + return _mm512_castps_si512(_mm512_permute_ps(_mm512_castsi512_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<int i> + __forceinline vuint16 shuffle4(const vuint16& v) { + return _mm512_castps_si512(_mm512_shuffle_f32x4(_mm512_castsi512_ps(v), _mm512_castsi512_ps(v) ,_MM_SHUFFLE(i, i, i, i))); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vuint16 shuffle4(const vuint16& v) { + return _mm512_castps_si512(_mm512_shuffle_f32x4(_mm512_castsi512_ps(v), _mm512_castsi512_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<int i> + __forceinline vuint16 align_shift_right(const vuint16& a, const vuint16& b) { + return _mm512_alignr_epi32(a, b, i); + }; + + __forceinline unsigned int toScalar(const vuint16& v) { + return _mm_cvtsi128_si32(_mm512_castsi512_si128(v)); + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint16 vreduce_min2(vuint16 x) { return min(x, shuffle<1,0,3,2>(x)); } + __forceinline vuint16 vreduce_min4(vuint16 x) { x = vreduce_min2(x); return min(x, shuffle<2,3,0,1>(x)); } + __forceinline vuint16 vreduce_min8(vuint16 x) { x = vreduce_min4(x); return min(x, shuffle4<1,0,3,2>(x)); } + __forceinline vuint16 vreduce_min (vuint16 x) { x = vreduce_min8(x); return min(x, shuffle4<2,3,0,1>(x)); } + + __forceinline vuint16 vreduce_max2(vuint16 x) { return max(x, shuffle<1,0,3,2>(x)); } + __forceinline vuint16 vreduce_max4(vuint16 x) { x = vreduce_max2(x); return max(x, shuffle<2,3,0,1>(x)); } + __forceinline vuint16 vreduce_max8(vuint16 x) { x = vreduce_max4(x); return max(x, shuffle4<1,0,3,2>(x)); } + __forceinline vuint16 vreduce_max (vuint16 x) { x = vreduce_max8(x); return max(x, shuffle4<2,3,0,1>(x)); } + + __forceinline vuint16 vreduce_and2(vuint16 x) { return x & shuffle<1,0,3,2>(x); } + __forceinline vuint16 vreduce_and4(vuint16 x) { x = vreduce_and2(x); return x & shuffle<2,3,0,1>(x); } + __forceinline vuint16 vreduce_and8(vuint16 x) { x = vreduce_and4(x); return x & shuffle4<1,0,3,2>(x); } + __forceinline vuint16 vreduce_and (vuint16 x) { x = vreduce_and8(x); return x & shuffle4<2,3,0,1>(x); } + + __forceinline vuint16 vreduce_or2(vuint16 x) { return x | shuffle<1,0,3,2>(x); } + __forceinline vuint16 vreduce_or4(vuint16 x) { x = vreduce_or2(x); return x | shuffle<2,3,0,1>(x); } + __forceinline vuint16 vreduce_or8(vuint16 x) { x = vreduce_or4(x); return x | shuffle4<1,0,3,2>(x); } + __forceinline vuint16 vreduce_or (vuint16 x) { x = vreduce_or8(x); return x | shuffle4<2,3,0,1>(x); } + + 
__forceinline vuint16 vreduce_add2(vuint16 x) { return x + shuffle<1,0,3,2>(x); } + __forceinline vuint16 vreduce_add4(vuint16 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); } + __forceinline vuint16 vreduce_add8(vuint16 x) { x = vreduce_add4(x); return x + shuffle4<1,0,3,2>(x); } + __forceinline vuint16 vreduce_add (vuint16 x) { x = vreduce_add8(x); return x + shuffle4<2,3,0,1>(x); } + + __forceinline unsigned int reduce_min(const vuint16& v) { return toScalar(vreduce_min(v)); } + __forceinline unsigned int reduce_max(const vuint16& v) { return toScalar(vreduce_max(v)); } + __forceinline unsigned int reduce_and(const vuint16& v) { return toScalar(vreduce_and(v)); } + __forceinline unsigned int reduce_or (const vuint16& v) { return toScalar(vreduce_or (v)); } + __forceinline unsigned int reduce_add(const vuint16& v) { return toScalar(vreduce_add(v)); } + + //////////////////////////////////////////////////////////////////////////////// + /// Memory load and store operations + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint16 permute(vuint16 v, vuint16 index) { + return _mm512_permutexvar_epi32(index,v); + } + + __forceinline vuint16 reverse(const vuint16& a) { + return permute(a,vuint16(reverse_step)); + } + + __forceinline vuint16 prefix_sum(const vuint16& a) + { + const vuint16 z(zero); + vuint16 v = a; + v = v + align_shift_right<16-1>(v,z); + v = v + align_shift_right<16-2>(v,z); + v = v + align_shift_right<16-4>(v,z); + v = v + align_shift_right<16-8>(v,z); + return v; + } + + __forceinline vuint16 reverse_prefix_sum(const vuint16& a) + { + const vuint16 z(zero); + vuint16 v = a; + v = v + align_shift_right<1>(z,v); + v = v + align_shift_right<2>(z,v); + v = v + align_shift_right<4>(z,v); + v = v + align_shift_right<8>(z,v); + return v; + } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vuint16& v) + { + cout << "<" << v[0]; + for (int i=1; i<16; i++) cout << ", " << v[i]; + cout << ">"; + return cout; + } +} diff --git a/thirdparty/embree/common/simd/vuint4_sse2.h b/thirdparty/embree/common/simd/vuint4_sse2.h new file mode 100644 index 000000000000..396eb45d5dfe --- /dev/null +++ b/thirdparty/embree/common/simd/vuint4_sse2.h @@ -0,0 +1,428 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../math/math.h" + +namespace embree +{ + /* 4-wide SSE integer type */ + template<> + struct vuint<4> + { + ALIGNED_STRUCT_(16); + + typedef vboolf4 Bool; + typedef vuint4 Int; + typedef vfloat4 Float; + + enum { size = 4 }; // number of SIMD elements + union { __m128i v; unsigned int i[4]; }; // data + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint() {} + __forceinline vuint(const vuint4& a) { v = a.v; } + __forceinline vuint4& operator =(const vuint4& a) { v = a.v; return *this; } + + __forceinline vuint(const __m128i a) : v(a) {} + __forceinline operator const __m128i&() const { return v; } + __forceinline operator __m128i&() { return v; } + + + __forceinline vuint(unsigned int a) : v(_mm_set1_epi32(a)) {} + __forceinline vuint(unsigned int a, unsigned int b, unsigned 
int c, unsigned int d) : v(_mm_set_epi32(d, c, b, a)) {} + +#if defined(__AVX512VL__) + __forceinline explicit vuint(__m128 a) : v(_mm_cvtps_epu32(a)) {} +#endif + +#if defined(__AVX512VL__) + __forceinline explicit vuint(const vboolf4& a) : v(_mm_movm_epi32(a)) {} +#else + __forceinline explicit vuint(const vboolf4& a) : v(_mm_castps_si128((__m128)a)) {} +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint(ZeroTy) : v(_mm_setzero_si128()) {} + __forceinline vuint(OneTy) : v(_mm_set1_epi32(1)) {} + __forceinline vuint(PosInfTy) : v(_mm_set1_epi32(unsigned(pos_inf))) {} + __forceinline vuint(StepTy) : v(_mm_set_epi32(3, 2, 1, 0)) {} + __forceinline vuint(TrueTy) { v = _mm_cmpeq_epi32(v,v); } + __forceinline vuint(UndefinedTy) : v(_mm_castps_si128(_mm_undefined_ps())) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vuint4 load (const void* a) { return _mm_load_si128((__m128i*)a); } + static __forceinline vuint4 loadu(const void* a) { return _mm_loadu_si128((__m128i*)a); } + + static __forceinline void store (void* ptr, const vuint4& v) { _mm_store_si128((__m128i*)ptr,v); } + static __forceinline void storeu(void* ptr, const vuint4& v) { _mm_storeu_si128((__m128i*)ptr,v); } + +#if defined(__AVX512VL__) + static __forceinline vuint4 load (const vboolf4& mask, const void* ptr) { return _mm_mask_load_epi32 (_mm_setzero_si128(),mask,ptr); } + static __forceinline vuint4 loadu(const vboolf4& mask, const void* ptr) { return _mm_mask_loadu_epi32(_mm_setzero_si128(),mask,ptr); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vuint4& v) { _mm_mask_store_epi32 (ptr,mask,v); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vuint4& v) { _mm_mask_storeu_epi32(ptr,mask,v); } +#elif defined(__AVX__) + static __forceinline vuint4 load (const vbool4& mask, const void* a) { return _mm_castps_si128(_mm_maskload_ps((float*)a,mask)); } + static __forceinline vuint4 loadu(const vbool4& mask, const void* a) { return _mm_castps_si128(_mm_maskload_ps((float*)a,mask)); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vuint4& i) { _mm_maskstore_ps((float*)ptr,(__m128i)mask,_mm_castsi128_ps(i)); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vuint4& i) { _mm_maskstore_ps((float*)ptr,(__m128i)mask,_mm_castsi128_ps(i)); } +#else + static __forceinline vuint4 load (const vbool4& mask, const void* a) { return _mm_and_si128(_mm_load_si128 ((__m128i*)a),mask); } + static __forceinline vuint4 loadu(const vbool4& mask, const void* a) { return _mm_and_si128(_mm_loadu_si128((__m128i*)a),mask); } + + static __forceinline void store (const vboolf4& mask, void* ptr, const vuint4& i) { store (ptr,select(mask,i,load (ptr))); } + static __forceinline void storeu(const vboolf4& mask, void* ptr, const vuint4& i) { storeu(ptr,select(mask,i,loadu(ptr))); } +#endif + +#if defined(__SSE4_1__) + static __forceinline vuint4 load(const unsigned char* ptr) { + return _mm_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); + } + + static __forceinline vuint4 loadu(const unsigned char* ptr) { + return _mm_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); + } + +#endif + + static __forceinline vuint4 load(const unsigned 
short* ptr) { +#if defined (__SSE4_1__) + return _mm_cvtepu16_epi32(_mm_loadu_si128((__m128i*)ptr)); +#else + return vuint4(ptr[0],ptr[1],ptr[2],ptr[3]); +#endif + } + + static __forceinline void store_uchar(unsigned char* ptr, const vuint4& v) { +#if defined(__SSE4_1__) + __m128i x = v; + x = _mm_packus_epi32(x, x); + x = _mm_packus_epi16(x, x); + *(unsigned*)ptr = _mm_cvtsi128_si32(x); +#else + for (size_t i=0;i<4;i++) + ptr[i] = (unsigned char)v[i]; +#endif + } + + static __forceinline void store_uchar(unsigned short* ptr, const vuint4& v) { + for (size_t i=0;i<4;i++) + ptr[i] = (unsigned short)v[i]; + } + + static __forceinline vuint4 load_nt(void* ptr) { +#if defined(__SSE4_1__) + return _mm_stream_load_si128((__m128i*)ptr); +#else + return _mm_load_si128((__m128i*)ptr); +#endif + } + + static __forceinline void store_nt(void* ptr, const vuint4& v) { +#if defined(__SSE4_1__) + _mm_stream_ps((float*)ptr,_mm_castsi128_ps(v)); +#else + _mm_store_si128((__m128i*)ptr,v); +#endif + } + + template<int scale> + static __forceinline vuint4 gather(const unsigned int* ptr, const vint4& index) { +#if defined(__AVX2__) + return _mm_i32gather_epi32((const int*)ptr, index, scale); +#else + return vuint4( + *(unsigned int*)(((char*)ptr)+scale*index[0]), + *(unsigned int*)(((char*)ptr)+scale*index[1]), + *(unsigned int*)(((char*)ptr)+scale*index[2]), + *(unsigned int*)(((char*)ptr)+scale*index[3])); +#endif + } + + template<int scale> + static __forceinline vuint4 gather(const vboolf4& mask, const unsigned int* ptr, const vint4& index) { + vuint4 r = zero; +#if defined(__AVX512VL__) + return _mm_mmask_i32gather_epi32(r, mask, index, ptr, scale); +#elif defined(__AVX2__) + return _mm_mask_i32gather_epi32(r, (const int*)ptr, index, mask, scale); +#else + if (likely(mask[0])) r[0] = *(unsigned int*)(((char*)ptr)+scale*index[0]); + if (likely(mask[1])) r[1] = *(unsigned int*)(((char*)ptr)+scale*index[1]); + if (likely(mask[2])) r[2] = *(unsigned int*)(((char*)ptr)+scale*index[2]); + if (likely(mask[3])) r[3] = *(unsigned int*)(((char*)ptr)+scale*index[3]); + return r; +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const unsigned int& operator [](size_t index) const { assert(index < 4); return i[index]; } + __forceinline unsigned int& operator [](size_t index) { assert(index < 4); return i[index]; } + + friend __forceinline vuint4 select(const vboolf4& m, const vuint4& t, const vuint4& f) { +#if defined(__AVX512VL__) + return _mm_mask_blend_epi32(m, (__m128i)f, (__m128i)t); +#elif defined(__SSE4_1__) + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(f), _mm_castsi128_ps(t), m)); +#else + return _mm_or_si128(_mm_and_si128(m, t), _mm_andnot_si128(m, f)); +#endif + } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf4 asBool(const vuint4& a) { return _mm_movepi32_mask(a); } +#else + __forceinline vboolf4 asBool(const vuint4& a) { return _mm_castsi128_ps(a); } +#endif + + __forceinline vuint4 operator +(const vuint4& a) { return a; } + __forceinline vuint4 operator -(const vuint4& a) { return _mm_sub_epi32(_mm_setzero_si128(), a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + 
//////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint4 operator +(const vuint4& a, const vuint4& b) { return _mm_add_epi32(a, b); } + __forceinline vuint4 operator +(const vuint4& a, unsigned int b) { return a + vuint4(b); } + __forceinline vuint4 operator +(unsigned int a, const vuint4& b) { return vuint4(a) + b; } + + __forceinline vuint4 operator -(const vuint4& a, const vuint4& b) { return _mm_sub_epi32(a, b); } + __forceinline vuint4 operator -(const vuint4& a, unsigned int b) { return a - vuint4(b); } + __forceinline vuint4 operator -(unsigned int a, const vuint4& b) { return vuint4(a) - b; } + +//#if defined(__SSE4_1__) +// __forceinline vuint4 operator *(const vuint4& a, const vuint4& b) { return _mm_mullo_epu32(a, b); } +//#else +// __forceinline vuint4 operator *(const vuint4& a, const vuint4& b) { return vuint4(a[0]*b[0],a[1]*b[1],a[2]*b[2],a[3]*b[3]); } +//#endif +// __forceinline vuint4 operator *(const vuint4& a, unsigned int b) { return a * vuint4(b); } +// __forceinline vuint4 operator *(unsigned int a, const vuint4& b) { return vuint4(a) * b; } + + __forceinline vuint4 operator &(const vuint4& a, const vuint4& b) { return _mm_and_si128(a, b); } + __forceinline vuint4 operator &(const vuint4& a, unsigned int b) { return a & vuint4(b); } + __forceinline vuint4 operator &(unsigned int a, const vuint4& b) { return vuint4(a) & b; } + + __forceinline vuint4 operator |(const vuint4& a, const vuint4& b) { return _mm_or_si128(a, b); } + __forceinline vuint4 operator |(const vuint4& a, unsigned int b) { return a | vuint4(b); } + __forceinline vuint4 operator |(unsigned int a, const vuint4& b) { return vuint4(a) | b; } + + __forceinline vuint4 operator ^(const vuint4& a, const vuint4& b) { return _mm_xor_si128(a, b); } + __forceinline vuint4 operator ^(const vuint4& a, unsigned int b) { return a ^ vuint4(b); } + __forceinline vuint4 operator ^(unsigned int a, const vuint4& b) { return vuint4(a) ^ b; } + + __forceinline vuint4 operator <<(const vuint4& a, unsigned int n) { return _mm_slli_epi32(a, n); } + __forceinline vuint4 operator >>(const vuint4& a, unsigned int n) { return _mm_srli_epi32(a, n); } + + __forceinline vuint4 sll (const vuint4& a, unsigned int b) { return _mm_slli_epi32(a, b); } + __forceinline vuint4 sra (const vuint4& a, unsigned int b) { return _mm_srai_epi32(a, b); } + __forceinline vuint4 srl (const vuint4& a, unsigned int b) { return _mm_srli_epi32(a, b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint4& operator +=(vuint4& a, const vuint4& b) { return a = a + b; } + __forceinline vuint4& operator +=(vuint4& a, unsigned int b) { return a = a + b; } + + __forceinline vuint4& operator -=(vuint4& a, const vuint4& b) { return a = a - b; } + __forceinline vuint4& operator -=(vuint4& a, unsigned int b) { return a = a - b; } + +//#if defined(__SSE4_1__) +// __forceinline vuint4& operator *=(vuint4& a, const vuint4& b) { return a = a * b; } +// __forceinline vuint4& operator *=(vuint4& a, unsigned int b) { return a = a * b; } +//#endif + + __forceinline vuint4& operator &=(vuint4& a, const vuint4& b) { return a = a & b; } + __forceinline vuint4& operator &=(vuint4& a, unsigned int b) { return a = a & b; } + + __forceinline vuint4& operator |=(vuint4& a, const vuint4& b) { return a = a | b; } + __forceinline vuint4& operator |=(vuint4& a, unsigned 
int b) { return a = a | b; } + + __forceinline vuint4& operator <<=(vuint4& a, unsigned int b) { return a = a << b; } + __forceinline vuint4& operator >>=(vuint4& a, unsigned int b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf4 operator ==(const vuint4& a, const vuint4& b) { return _mm_cmp_epu32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf4 operator !=(const vuint4& a, const vuint4& b) { return _mm_cmp_epu32_mask(a,b,_MM_CMPINT_NE); } + //__forceinline vboolf4 operator < (const vuint4& a, const vuint4& b) { return _mm_cmp_epu32_mask(a,b,_MM_CMPINT_LT); } + //__forceinline vboolf4 operator >=(const vuint4& a, const vuint4& b) { return _mm_cmp_epu32_mask(a,b,_MM_CMPINT_GE); } + //__forceinline vboolf4 operator > (const vuint4& a, const vuint4& b) { return _mm_cmp_epu32_mask(a,b,_MM_CMPINT_GT); } + //__forceinline vboolf4 operator <=(const vuint4& a, const vuint4& b) { return _mm_cmp_epu32_mask(a,b,_MM_CMPINT_LE); } +#else + __forceinline vboolf4 operator ==(const vuint4& a, const vuint4& b) { return _mm_castsi128_ps(_mm_cmpeq_epi32(a, b)); } + __forceinline vboolf4 operator !=(const vuint4& a, const vuint4& b) { return !(a == b); } + //__forceinline vboolf4 operator < (const vuint4& a, const vuint4& b) { return _mm_castsi128_ps(_mm_cmplt_epu32(a, b)); } + //__forceinline vboolf4 operator >=(const vuint4& a, const vuint4& b) { return !(a < b); } + //__forceinline vboolf4 operator > (const vuint4& a, const vuint4& b) { return _mm_castsi128_ps(_mm_cmpgt_epu32(a, b)); } + //__forceinline vboolf4 operator <=(const vuint4& a, const vuint4& b) { return !(a > b); } +#endif + + __forceinline vboolf4 operator ==(const vuint4& a, unsigned int b) { return a == vuint4(b); } + __forceinline vboolf4 operator ==(unsigned int a, const vuint4& b) { return vuint4(a) == b; } + + __forceinline vboolf4 operator !=(const vuint4& a, unsigned int b) { return a != vuint4(b); } + __forceinline vboolf4 operator !=(unsigned int a, const vuint4& b) { return vuint4(a) != b; } + + //__forceinline vboolf4 operator < (const vuint4& a, unsigned int b) { return a < vuint4(b); } + //__forceinline vboolf4 operator < (unsigned int a, const vuint4& b) { return vuint4(a) < b; } + + //__forceinline vboolf4 operator >=(const vuint4& a, unsigned int b) { return a >= vuint4(b); } + //__forceinline vboolf4 operator >=(unsigned int a, const vuint4& b) { return vuint4(a) >= b; } + + //__forceinline vboolf4 operator > (const vuint4& a, unsigned int b) { return a > vuint4(b); } + //__forceinline vboolf4 operator > (unsigned int a, const vuint4& b) { return vuint4(a) > b; } + + //__forceinline vboolf4 operator <=(const vuint4& a, unsigned int b) { return a <= vuint4(b); } + //__forceinline vboolf4 operator <=(unsigned int a, const vuint4& b) { return vuint4(a) <= b; } + + __forceinline vboolf4 eq(const vuint4& a, const vuint4& b) { return a == b; } + __forceinline vboolf4 ne(const vuint4& a, const vuint4& b) { return a != b; } + //__forceinline vboolf4 lt(const vuint4& a, const vuint4& b) { return a < b; } + //__forceinline vboolf4 ge(const vuint4& a, const vuint4& b) { return a >= b; } + //__forceinline vboolf4 gt(const vuint4& a, const vuint4& b) { return a > b; } + //__forceinline vboolf4 le(const vuint4& a, const vuint4& b) { return a <= b; } + +#if defined(__AVX512VL__) + __forceinline vboolf4 eq(const 
vboolf4& mask, const vuint4& a, const vuint4& b) { return _mm_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_EQ); } + __forceinline vboolf4 ne(const vboolf4& mask, const vuint4& a, const vuint4& b) { return _mm_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_NE); } + //__forceinline vboolf4 lt(const vboolf4& mask, const vuint4& a, const vuint4& b) { return _mm_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_LT); } + //__forceinline vboolf4 ge(const vboolf4& mask, const vuint4& a, const vuint4& b) { return _mm_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_GE); } + //__forceinline vboolf4 gt(const vboolf4& mask, const vuint4& a, const vuint4& b) { return _mm_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_GT); } + //__forceinline vboolf4 le(const vboolf4& mask, const vuint4& a, const vuint4& b) { return _mm_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_LE); } +#else + __forceinline vboolf4 eq(const vboolf4& mask, const vuint4& a, const vuint4& b) { return mask & (a == b); } + __forceinline vboolf4 ne(const vboolf4& mask, const vuint4& a, const vuint4& b) { return mask & (a != b); } + //__forceinline vboolf4 lt(const vboolf4& mask, const vuint4& a, const vuint4& b) { return mask & (a < b); } + //__forceinline vboolf4 ge(const vboolf4& mask, const vuint4& a, const vuint4& b) { return mask & (a >= b); } + //__forceinline vboolf4 gt(const vboolf4& mask, const vuint4& a, const vuint4& b) { return mask & (a > b); } + //__forceinline vboolf4 le(const vboolf4& mask, const vuint4& a, const vuint4& b) { return mask & (a <= b); } +#endif + + template<int mask> + __forceinline vuint4 select(const vuint4& t, const vuint4& f) { +#if defined(__SSE4_1__) + return _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(f), _mm_castsi128_ps(t), mask)); +#else + return select(vboolf4(mask), t, f); +#endif + } + +/*#if defined(__SSE4_1__) + __forceinline vuint4 min(const vuint4& a, const vuint4& b) { return _mm_min_epu32(a, b); } + __forceinline vuint4 max(const vuint4& a, const vuint4& b) { return _mm_max_epu32(a, b); } + +#else + __forceinline vuint4 min(const vuint4& a, const vuint4& b) { return select(a < b,a,b); } + __forceinline vuint4 max(const vuint4& a, const vuint4& b) { return select(a < b,b,a); } +#endif + + __forceinline vuint4 min(const vuint4& a, unsigned int b) { return min(a,vuint4(b)); } + __forceinline vuint4 min(unsigned int a, const vuint4& b) { return min(vuint4(a),b); } + __forceinline vuint4 max(const vuint4& a, unsigned int b) { return max(a,vuint4(b)); } + __forceinline vuint4 max(unsigned int a, const vuint4& b) { return max(vuint4(a),b); }*/ + + //////////////////////////////////////////////////////////////////////////////// + // Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint4 unpacklo(const vuint4& a, const vuint4& b) { return _mm_castps_si128(_mm_unpacklo_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b))); } + __forceinline vuint4 unpackhi(const vuint4& a, const vuint4& b) { return _mm_castps_si128(_mm_unpackhi_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b))); } + + template<int i0, int i1, int i2, int i3> + __forceinline vuint4 shuffle(const vuint4& v) { + return _mm_shuffle_epi32(v, _MM_SHUFFLE(i3, i2, i1, i0)); + } + + template<int i0, int i1, int i2, int i3> + __forceinline vuint4 shuffle(const vuint4& a, const vuint4& b) { + return _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), _MM_SHUFFLE(i3, i2, i1, i0))); + } + +#if defined(__SSE3__) + template<> __forceinline vuint4 shuffle<0, 0, 2, 2>(const vuint4& v) { return
_mm_castps_si128(_mm_moveldup_ps(_mm_castsi128_ps(v))); } + template<> __forceinline vuint4 shuffle<1, 1, 3, 3>(const vuint4& v) { return _mm_castps_si128(_mm_movehdup_ps(_mm_castsi128_ps(v))); } + template<> __forceinline vuint4 shuffle<0, 1, 0, 1>(const vuint4& v) { return _mm_castpd_si128(_mm_movedup_pd (_mm_castsi128_pd(v))); } +#endif + + template<int i> + __forceinline vuint4 shuffle(const vuint4& v) { + return shuffle<i,i,i,i>(v); + } + +#if defined(__SSE4_1__) + template<int src> __forceinline unsigned int extract(const vuint4& b) { return _mm_extract_epi32(b, src); } + template<int dst> __forceinline vuint4 insert(const vuint4& a, const unsigned b) { return _mm_insert_epi32(a, b, dst); } +#else + template<int src> __forceinline unsigned int extract(const vuint4& b) { return b[src&3]; } + template<int dst> __forceinline vuint4 insert(const vuint4& a, const unsigned b) { vuint4 c = a; c[dst&3] = b; return c; } +#endif + + + template<> __forceinline unsigned int extract<0>(const vuint4& b) { return _mm_cvtsi128_si32(b); } + + __forceinline unsigned int toScalar(const vuint4& v) { return _mm_cvtsi128_si32(v); } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + +#if 0 +#if defined(__SSE4_1__) + + __forceinline vuint4 vreduce_min(const vuint4& v) { vuint4 h = min(shuffle<1,0,3,2>(v),v); return min(shuffle<2,3,0,1>(h),h); } + __forceinline vuint4 vreduce_max(const vuint4& v) { vuint4 h = max(shuffle<1,0,3,2>(v),v); return max(shuffle<2,3,0,1>(h),h); } + __forceinline vuint4 vreduce_add(const vuint4& v) { vuint4 h = shuffle<1,0,3,2>(v) + v ; return shuffle<2,3,0,1>(h) + h ; } + + __forceinline unsigned int reduce_min(const vuint4& v) { return toScalar(vreduce_min(v)); } + __forceinline unsigned int reduce_max(const vuint4& v) { return toScalar(vreduce_max(v)); } + __forceinline unsigned int reduce_add(const vuint4& v) { return toScalar(vreduce_add(v)); } + + __forceinline size_t select_min(const vuint4& v) { return bsf(movemask(v == vreduce_min(v))); } + __forceinline size_t select_max(const vuint4& v) { return bsf(movemask(v == vreduce_max(v))); } + + //__forceinline size_t select_min(const vboolf4& valid, const vuint4& v) { const vuint4 a = select(valid,v,vuint4(pos_inf)); return bsf(movemask(valid & (a == vreduce_min(a)))); } + //__forceinline size_t select_max(const vboolf4& valid, const vuint4& v) { const vuint4 a = select(valid,v,vuint4(neg_inf)); return bsf(movemask(valid & (a == vreduce_max(a)))); } + +#else + + __forceinline unsigned int reduce_min(const vuint4& v) { return min(v[0],v[1],v[2],v[3]); } + __forceinline unsigned int reduce_max(const vuint4& v) { return max(v[0],v[1],v[2],v[3]); } + __forceinline unsigned int reduce_add(const vuint4& v) { return v[0]+v[1]+v[2]+v[3]; } + +#endif +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vuint4& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ">"; + } +} + diff --git a/thirdparty/embree/common/simd/vuint8_avx.h b/thirdparty/embree/common/simd/vuint8_avx.h new file mode 100644 index 000000000000..437e73c7fbf3 --- /dev/null +++ b/thirdparty/embree/common/simd/vuint8_avx.h @@ -0,0 +1,375 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + 
+namespace embree +{ + /* 8-wide AVX integer type */ + template<> + struct vuint<8> + { + ALIGNED_STRUCT_(32); + + typedef vboolf8 Bool; + typedef vuint8 Int; + typedef vfloat8 Float; + + enum { size = 8 }; // number of SIMD elements + union { // data + __m256i v; + struct { __m128i vl,vh; }; + unsigned int i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint() {} + __forceinline vuint(const vuint8& a) { v = a.v; } + __forceinline vuint8& operator =(const vuint8& a) { v = a.v; return *this; } + + __forceinline vuint(__m256i a) : v(a) {} + __forceinline operator const __m256i&() const { return v; } + __forceinline operator __m256i&() { return v; } + + __forceinline explicit vuint(const vuint4& a) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),a,1)) {} + __forceinline vuint(const vuint4& a, const vuint4& b) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),b,1)) {} + __forceinline vuint(const __m128i& a, const __m128i& b) : vl(a), vh(b) {} + + __forceinline explicit vuint(const unsigned int* a) : v(_mm256_castps_si256(_mm256_loadu_ps((const float*)a))) {} + __forceinline vuint(unsigned int a) : v(_mm256_set1_epi32(a)) {} + __forceinline vuint(unsigned int a, unsigned int b) : v(_mm256_set_epi32(b, a, b, a, b, a, b, a)) {} + __forceinline vuint(unsigned int a, unsigned int b, unsigned int c, unsigned int d) : v(_mm256_set_epi32(d, c, b, a, d, c, b, a)) {} + __forceinline vuint(unsigned int a, unsigned int b, unsigned int c, unsigned int d, unsigned int e, unsigned int f, unsigned int g, unsigned int vh) : v(_mm256_set_epi32(vh, g, f, e, d, c, b, a)) {} + + __forceinline explicit vuint(__m256 a) : v(_mm256_cvtps_epi32(a)) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint(ZeroTy) : v(_mm256_setzero_si256()) {} + __forceinline vuint(OneTy) : v(_mm256_set1_epi32(1)) {} + __forceinline vuint(PosInfTy) : v(_mm256_set1_epi32(0xFFFFFFFF)) {} + __forceinline vuint(StepTy) : v(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)) {} + __forceinline vuint(UndefinedTy) : v(_mm256_undefined_si256()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vuint8 load (const void* a) { return _mm256_castps_si256(_mm256_load_ps((float*)a)); } + static __forceinline vuint8 loadu(const void* a) { return _mm256_castps_si256(_mm256_loadu_ps((float*)a)); } + + static __forceinline vuint8 load (const vboolf8& mask, const void* a) { return _mm256_castps_si256(_mm256_maskload_ps((float*)a,mask)); } + static __forceinline vuint8 loadu(const vboolf8& mask, const void* a) { return _mm256_castps_si256(_mm256_maskload_ps((float*)a,mask)); } + + static __forceinline void store (void* ptr, const vuint8& f) { _mm256_store_ps((float*)ptr,_mm256_castsi256_ps(f)); } + static __forceinline void storeu(void* ptr, const vuint8& f) { _mm256_storeu_ps((float*)ptr,_mm256_castsi256_ps(f)); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vuint8& f) { _mm256_maskstore_ps((float*)ptr,(__m256i)mask,_mm256_castsi256_ps(f)); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const 
vuint8& f) { _mm256_maskstore_ps((float*)ptr,(__m256i)mask,_mm256_castsi256_ps(f)); } + + static __forceinline void store_nt(void* ptr, const vuint8& v) { + _mm256_stream_ps((float*)ptr,_mm256_castsi256_ps(v)); + } + + static __forceinline vuint8 load(const unsigned char* ptr) { + vuint4 il = vuint4::load(ptr+0); + vuint4 ih = vuint4::load(ptr+4); + return vuint8(il,ih); + } + + static __forceinline vuint8 loadu(const unsigned char* ptr) { + vuint4 il = vuint4::loadu(ptr+0); + vuint4 ih = vuint4::loadu(ptr+4); + return vuint8(il,ih); + } + + static __forceinline vuint8 load(const unsigned short* ptr) { + vuint4 il = vuint4::load(ptr+0); + vuint4 ih = vuint4::load(ptr+4); + return vuint8(il,ih); + } + + static __forceinline vuint8 loadu(const unsigned short* ptr) { + vuint4 il = vuint4::loadu(ptr+0); + vuint4 ih = vuint4::loadu(ptr+4); + return vuint8(il,ih); + } + + static __forceinline void store(unsigned char* ptr, const vuint8& i) { + vuint4 il(i.vl); + vuint4 ih(i.vh); + vuint4::store(ptr + 0,il); + vuint4::store(ptr + 4,ih); + } + + static __forceinline void store(unsigned short* ptr, const vuint8& v) { + for (size_t i=0;i<8;i++) + ptr[i] = (unsigned short)v[i]; + } + + template<int scale> + static __forceinline vuint8 gather(const unsigned int* ptr, const vint8& index) { + return vuint8( + *(unsigned int*)(((char*)ptr)+scale*index[0]), + *(unsigned int*)(((char*)ptr)+scale*index[1]), + *(unsigned int*)(((char*)ptr)+scale*index[2]), + *(unsigned int*)(((char*)ptr)+scale*index[3]), + *(unsigned int*)(((char*)ptr)+scale*index[4]), + *(unsigned int*)(((char*)ptr)+scale*index[5]), + *(unsigned int*)(((char*)ptr)+scale*index[6]), + *(unsigned int*)(((char*)ptr)+scale*index[7])); + } + + template<int scale> + static __forceinline vuint8 gather(const vboolf8& mask, const unsigned int* ptr, const vint8& index) { + vuint8 r = zero; + if (likely(mask[0])) r[0] = *(unsigned int*)(((char*)ptr)+scale*index[0]); + if (likely(mask[1])) r[1] = *(unsigned int*)(((char*)ptr)+scale*index[1]); + if (likely(mask[2])) r[2] = *(unsigned int*)(((char*)ptr)+scale*index[2]); + if (likely(mask[3])) r[3] = *(unsigned int*)(((char*)ptr)+scale*index[3]); + if (likely(mask[4])) r[4] = *(unsigned int*)(((char*)ptr)+scale*index[4]); + if (likely(mask[5])) r[5] = *(unsigned int*)(((char*)ptr)+scale*index[5]); + if (likely(mask[6])) r[6] = *(unsigned int*)(((char*)ptr)+scale*index[6]); + if (likely(mask[7])) r[7] = *(unsigned int*)(((char*)ptr)+scale*index[7]); + return r; + } + + template<int scale> + static __forceinline void scatter(void* ptr, const vint8& ofs, const vuint8& v) + { + *(unsigned int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + *(unsigned int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + *(unsigned int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + *(unsigned int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + *(unsigned int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + *(unsigned int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + *(unsigned int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + *(unsigned int*)(((char*)ptr)+scale*ofs[7]) = v[7]; + } + + template<int scale> + static __forceinline void scatter(const vboolf8& mask, void* ptr, const vint8& ofs, const vuint8& v) + { + if (likely(mask[0])) *(unsigned int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + if (likely(mask[1])) *(unsigned int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + if (likely(mask[2])) *(unsigned int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + if (likely(mask[3])) *(unsigned int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + if (likely(mask[4])) *(unsigned int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + if (likely(mask[5])) *(unsigned
int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + if (likely(mask[6])) *(unsigned int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + if (likely(mask[7])) *(unsigned int*)(((char*)ptr)+scale*ofs[7]) = v[7]; + } + + + static __forceinline vuint8 broadcast64(const long long& a) { return _mm256_set1_epi64x(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const unsigned int& operator [](size_t index) const { assert(index < 8); return i[index]; } + __forceinline unsigned int& operator [](size_t index) { assert(index < 8); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 asBool(const vuint8& a) { return _mm256_castsi256_ps(a); } + + __forceinline vuint8 operator +(const vuint8& a) { return a; } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint8 operator +(const vuint8& a, const vuint8& b) { return vuint8(_mm_add_epi32(a.vl, b.vl), _mm_add_epi32(a.vh, b.vh)); } + __forceinline vuint8 operator +(const vuint8& a, unsigned int b) { return a + vuint8(b); } + __forceinline vuint8 operator +(unsigned int a, const vuint8& b) { return vuint8(a) + b; } + + __forceinline vuint8 operator -(const vuint8& a, const vuint8& b) { return vuint8(_mm_sub_epi32(a.vl, b.vl), _mm_sub_epi32(a.vh, b.vh)); } + __forceinline vuint8 operator -(const vuint8& a, unsigned int b) { return a - vuint8(b); } + __forceinline vuint8 operator -(unsigned int a, const vuint8& b) { return vuint8(a) - b; } + + //__forceinline vuint8 operator *(const vuint8& a, const vuint8& b) { return vuint8(_mm_mullo_epu32(a.vl, b.vl), _mm_mullo_epu32(a.vh, b.vh)); } + //__forceinline vuint8 operator *(const vuint8& a, unsigned int b) { return a * vuint8(b); } + //__forceinline vuint8 operator *(unsigned int a, const vuint8& b) { return vuint8(a) * b; } + + __forceinline vuint8 operator &(const vuint8& a, const vuint8& b) { return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vuint8 operator &(const vuint8& a, unsigned int b) { return a & vuint8(b); } + __forceinline vuint8 operator &(unsigned int a, const vuint8& b) { return vuint8(a) & b; } + + __forceinline vuint8 operator |(const vuint8& a, const vuint8& b) { return _mm256_castps_si256(_mm256_or_ps (_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vuint8 operator |(const vuint8& a, unsigned int b) { return a | vuint8(b); } + __forceinline vuint8 operator |(unsigned int a, const vuint8& b) { return vuint8(a) | b; } + + __forceinline vuint8 operator ^(const vuint8& a, const vuint8& b) { return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vuint8 operator ^(const vuint8& a, unsigned int b) { return a ^ vuint8(b); } + __forceinline vuint8 operator ^(unsigned int a, const vuint8& b) { return vuint8(a) ^ b; } + + __forceinline vuint8 operator <<(const vuint8& a, unsigned int n) { return vuint8(_mm_slli_epi32(a.vl, n), _mm_slli_epi32(a.vh, n)); } + __forceinline vuint8 operator >>(const vuint8& a, unsigned int n) { return vuint8(_mm_srai_epi32(a.vl, n), _mm_srli_epi32(a.vh, n)); } + 
+ __forceinline vuint8 sll (const vuint8& a, unsigned int b) { return vuint8(_mm_slli_epi32(a.vl, b), _mm_slli_epi32(a.vh, b)); } + __forceinline vuint8 sra (const vuint8& a, unsigned int b) { return vuint8(_mm_srai_epi32(a.vl, b), _mm_srai_epi32(a.vh, b)); } + __forceinline vuint8 srl (const vuint8& a, unsigned int b) { return vuint8(_mm_srli_epi32(a.vl, b), _mm_srli_epi32(a.vh, b)); } + + __forceinline vuint8 min(const vuint8& a, const vuint8& b) { return vuint8(_mm_min_epu32(a.vl, b.vl), _mm_min_epu32(a.vh, b.vh)); } + __forceinline vuint8 min(const vuint8& a, unsigned int b) { return min(a,vuint8(b)); } + __forceinline vuint8 min(unsigned int a, const vuint8& b) { return min(vuint8(a),b); } + + __forceinline vuint8 max(const vuint8& a, const vuint8& b) { return vuint8(_mm_max_epu32(a.vl, b.vl), _mm_max_epu32(a.vh, b.vh)); } + __forceinline vuint8 max(const vuint8& a, unsigned int b) { return max(a,vuint8(b)); } + __forceinline vuint8 max(unsigned int a, const vuint8& b) { return max(vuint8(a),b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint8& operator +=(vuint8& a, const vuint8& b) { return a = a + b; } + __forceinline vuint8& operator +=(vuint8& a, unsigned int b) { return a = a + b; } + + __forceinline vuint8& operator -=(vuint8& a, const vuint8& b) { return a = a - b; } + __forceinline vuint8& operator -=(vuint8& a, unsigned int b) { return a = a - b; } + + //__forceinline vuint8& operator *=(vuint8& a, const vuint8& b) { return a = a * b; } + //__forceinline vuint8& operator *=(vuint8& a, unsigned int b) { return a = a * b; } + + __forceinline vuint8& operator &=(vuint8& a, const vuint8& b) { return a = a & b; } + __forceinline vuint8& operator &=(vuint8& a, unsigned int b) { return a = a & b; } + + __forceinline vuint8& operator |=(vuint8& a, const vuint8& b) { return a = a | b; } + __forceinline vuint8& operator |=(vuint8& a, unsigned int b) { return a = a | b; } + + __forceinline vuint8& operator <<=(vuint8& a, unsigned int b) { return a = a << b; } + __forceinline vuint8& operator >>=(vuint8& a, unsigned int b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vboolf8 operator ==(const vuint8& a, const vuint8& b) { return vboolf8(_mm_castsi128_ps(_mm_cmpeq_epi32 (a.vl, b.vl)), + _mm_castsi128_ps(_mm_cmpeq_epi32 (a.vh, b.vh))); } + __forceinline vboolf8 operator ==(const vuint8& a, unsigned int b) { return a == vuint8(b); } + __forceinline vboolf8 operator ==(unsigned int a, const vuint8& b) { return vuint8(a) == b; } + + __forceinline vboolf8 operator !=(const vuint8& a, const vuint8& b) { return !(a == b); } + __forceinline vboolf8 operator !=(const vuint8& a, unsigned int b) { return a != vuint8(b); } + __forceinline vboolf8 operator !=(unsigned int a, const vuint8& b) { return vuint8(a) != b; } + + //__forceinline vboolf8 operator < (const vuint8& a, const vuint8& b) { return vboolf8(_mm_castsi128_ps(_mm_cmplt_epu32 (a.vl, b.vl)), + // _mm_castsi128_ps(_mm_cmplt_epu32 (a.vh, b.vh))); } + //__forceinline vboolf8 operator < (const vuint8& a, unsigned int b) { return a < vuint8(b); } + //__forceinline vboolf8 operator < (unsigned int a, const vuint8& b) { return vuint8(a) < b; } + + //__forceinline vboolf8 
operator >=(const vuint8& a, const vuint8& b) { return !(a < b); } + //__forceinline vboolf8 operator >=(const vuint8& a, unsigned int b) { return a >= vuint8(b); } + //__forceinline vboolf8 operator >=(unsigned int a, const vuint8& b) { return vuint8(a) >= b; } + + //__forceinline vboolf8 operator > (const vuint8& a, const vuint8& b) { return vboolf8(_mm_castsi128_ps(_mm_cmpgt_epu32 (a.vl, b.vl)), + // _mm_castsi128_ps(_mm_cmpgt_epu32 (a.vh, b.vh))); } + //__forceinline vboolf8 operator > (const vuint8& a, unsigned int b) { return a > vuint8(b); } + //__forceinline vboolf8 operator > (unsigned int a, const vuint8& b) { return vuint8(a) > b; } + + //__forceinline vboolf8 operator <=(const vuint8& a, const vuint8& b) { return !(a > b); } + //__forceinline vboolf8 operator <=(const vuint8& a, unsigned int b) { return a <= vuint8(b); } + //__forceinline vboolf8 operator <=(unsigned int a, const vuint8& b) { return vuint8(a) <= b; } + + __forceinline vboolf8 eq(const vuint8& a, const vuint8& b) { return a == b; } + __forceinline vboolf8 ne(const vuint8& a, const vuint8& b) { return a != b; } + + __forceinline vboolf8 eq(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a == b); } + __forceinline vboolf8 ne(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a != b); } + + __forceinline vuint8 select(const vboolf8& m, const vuint8& t, const vuint8& f) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(f), _mm256_castsi256_ps(t), m)); + } + + __forceinline vuint8 notand(const vboolf8& m, const vuint8& f) { + return _mm256_castps_si256(_mm256_andnot_ps(m, _mm256_castsi256_ps(f))); + } + + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint8 unpacklo(const vuint8& a, const vuint8& b) { return _mm256_castps_si256(_mm256_unpacklo_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + __forceinline vuint8 unpackhi(const vuint8& a, const vuint8& b) { return _mm256_castps_si256(_mm256_unpackhi_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b))); } + + template + __forceinline vuint8 shuffle(const vuint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template + __forceinline vuint8 shuffle4(const vuint8& v) { + return _mm256_permute2f128_si256(v, v, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vuint8 shuffle4(const vuint8& a, const vuint8& b) { + return _mm256_permute2f128_si256(a, b, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vuint8 shuffle(const vuint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template + __forceinline vuint8 shuffle(const vuint8& a, const vuint8& b) { + return _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<> __forceinline vuint8 shuffle<0, 0, 2, 2>(const vuint8& v) { return _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vuint8 shuffle<1, 1, 3, 3>(const vuint8& v) { return _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vuint8 shuffle<0, 1, 0, 1>(const vuint8& v) { return _mm256_castps_si256(_mm256_castpd_ps(_mm256_movedup_pd(_mm256_castps_pd(_mm256_castsi256_ps(v))))); } + + __forceinline vuint8 
broadcast(const unsigned int* ptr) { return _mm256_castps_si256(_mm256_broadcast_ss((const float*)ptr)); } + template __forceinline vuint8 insert4(const vuint8& a, const vuint4& b) { return _mm256_insertf128_si256(a, b, i); } + template __forceinline vuint4 extract4(const vuint8& a) { return _mm256_extractf128_si256(a, i); } + template<> __forceinline vuint4 extract4<0>(const vuint8& a) { return _mm256_castsi256_si128(a); } + + __forceinline int toScalar(const vuint8& v) { return _mm_cvtsi128_si32(_mm256_castsi256_si128(v)); } + + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + //__forceinline vuint8 vreduce_min2(const vuint8& v) { return min(v,shuffle<1,0,3,2>(v)); } + //__forceinline vuint8 vreduce_min4(const vuint8& v) { vuint8 v1 = vreduce_min2(v); return min(v1,shuffle<2,3,0,1>(v1)); } + //__forceinline vuint8 vreduce_min (const vuint8& v) { vuint8 v1 = vreduce_min4(v); return min(v1,shuffle4<1,0>(v1)); } + + //__forceinline vuint8 vreduce_max2(const vuint8& v) { return max(v,shuffle<1,0,3,2>(v)); } + //__forceinline vuint8 vreduce_max4(const vuint8& v) { vuint8 v1 = vreduce_max2(v); return max(v1,shuffle<2,3,0,1>(v1)); } + //__forceinline vuint8 vreduce_max (const vuint8& v) { vuint8 v1 = vreduce_max4(v); return max(v1,shuffle4<1,0>(v1)); } + + __forceinline vuint8 vreduce_add2(const vuint8& v) { return v + shuffle<1,0,3,2>(v); } + __forceinline vuint8 vreduce_add4(const vuint8& v) { vuint8 v1 = vreduce_add2(v); return v1 + shuffle<2,3,0,1>(v1); } + __forceinline vuint8 vreduce_add (const vuint8& v) { vuint8 v1 = vreduce_add4(v); return v1 + shuffle4<1,0>(v1); } + + //__forceinline int reduce_min(const vuint8& v) { return toScalar(vreduce_min(v)); } + //__forceinline int reduce_max(const vuint8& v) { return toScalar(vreduce_max(v)); } + __forceinline int reduce_add(const vuint8& v) { return toScalar(vreduce_add(v)); } + + //__forceinline size_t select_min(const vuint8& v) { return bsf(movemask(v == vreduce_min(v))); } + //__forceinline size_t select_max(const vuint8& v) { return bsf(movemask(v == vreduce_max(v))); } + + //__forceinline size_t select_min(const vboolf8& valid, const vuint8& v) { const vuint8 a = select(valid,v,vuint8(pos_inf)); return bsf(movemask(valid & (a == vreduce_min(a)))); } + //__forceinline size_t select_max(const vboolf8& valid, const vuint8& v) { const vuint8 a = select(valid,v,vuint8(neg_inf)); return bsf(movemask(valid & (a == vreduce_max(a)))); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vuint8& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/simd/vuint8_avx2.h b/thirdparty/embree/common/simd/vuint8_avx2.h new file mode 100644 index 000000000000..ae243ddfb173 --- /dev/null +++ b/thirdparty/embree/common/simd/vuint8_avx2.h @@ -0,0 +1,434 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* 8-wide AVX integer type */ + template<> + struct vuint<8> + { + ALIGNED_STRUCT_(32); + + typedef vboolf8 Bool; + typedef vuint8 Int; + typedef vfloat8 Float; + + enum { size = 8 }; // number of 
SIMD elements + union { // data + __m256i v; + unsigned int i[8]; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint() {} + __forceinline vuint(const vuint8& a) { v = a.v; } + __forceinline vuint8& operator =(const vuint8& a) { v = a.v; return *this; } + + __forceinline vuint(__m256i a) : v(a) {} + __forceinline operator const __m256i&() const { return v; } + __forceinline operator __m256i&() { return v; } + + __forceinline explicit vuint(const vuint4& a) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),a,1)) {} + __forceinline vuint(const vuint4& a, const vuint4& b) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),b,1)) {} + __forceinline vuint(const __m128i& a, const __m128i& b) : v(_mm256_insertf128_si256(_mm256_castsi128_si256(a),b,1)) {} + + __forceinline explicit vuint(const unsigned int* a) : v(_mm256_castps_si256(_mm256_loadu_ps((const float*)a))) {} + __forceinline vuint(unsigned int a) : v(_mm256_set1_epi32(a)) {} + __forceinline vuint(unsigned int a, unsigned int b) : v(_mm256_set_epi32(b, a, b, a, b, a, b, a)) {} + __forceinline vuint(unsigned int a, unsigned int b, unsigned int c, unsigned int d) : v(_mm256_set_epi32(d, c, b, a, d, c, b, a)) {} + __forceinline vuint(unsigned int a, unsigned int b, unsigned int c, unsigned int d, unsigned int e, unsigned int f, unsigned int g, unsigned int h) : v(_mm256_set_epi32(h, g, f, e, d, c, b, a)) {} + + __forceinline explicit vuint(__m256 a) : v(_mm256_cvtps_epi32(a)) {} + +#if defined(__AVX512VL__) + __forceinline explicit vuint(const vboolf8& a) : v(_mm256_movm_epi32(a)) {} +#else + __forceinline explicit vuint(const vboolf8& a) : v(_mm256_castps_si256((__m256)a)) {} +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Constants + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint(ZeroTy) : v(_mm256_setzero_si256()) {} + __forceinline vuint(OneTy) : v(_mm256_set1_epi32(1)) {} + __forceinline vuint(PosInfTy) : v(_mm256_set1_epi32(pos_inf)) {} + __forceinline vuint(NegInfTy) : v(_mm256_set1_epi32(neg_inf)) {} + __forceinline vuint(StepTy) : v(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)) {} + __forceinline vuint(UndefinedTy) : v(_mm256_undefined_si256()) {} + + //////////////////////////////////////////////////////////////////////////////// + /// Loads and Stores + //////////////////////////////////////////////////////////////////////////////// + + static __forceinline vuint8 load(const unsigned char* ptr) { return _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); } + static __forceinline vuint8 loadu(const unsigned char* ptr) { return _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)ptr)); } + static __forceinline vuint8 load(const unsigned short* ptr) { return _mm256_cvtepu16_epi32(_mm_load_si128((__m128i*)ptr)); } + static __forceinline vuint8 loadu(const unsigned short* ptr) { return _mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i*)ptr)); } + + static __forceinline vuint8 load(const void* ptr) { return _mm256_load_si256((__m256i*)ptr); } + static __forceinline vuint8 loadu(const void* ptr) { return _mm256_loadu_si256((__m256i*)ptr); } + + static __forceinline void store (void* ptr, const vuint8& v) { _mm256_store_si256((__m256i*)ptr,v); } + static __forceinline void storeu(void* ptr, const vuint8& v) { 
_mm256_storeu_ps((float*)ptr,_mm256_castsi256_ps(v)); } + +#if defined(__AVX512VL__) + + static __forceinline vuint8 compact(const vboolf8& mask, vuint8 &v) { + return _mm256_mask_compress_epi32(v, mask, v); + } + static __forceinline vuint8 compact(const vboolf8& mask, vuint8 &a, const vuint8& b) { + return _mm256_mask_compress_epi32(a, mask, b); + } + + static __forceinline vuint8 load (const vboolf8& mask, const void* ptr) { return _mm256_mask_load_epi32 (_mm256_setzero_si256(),mask,ptr); } + static __forceinline vuint8 loadu(const vboolf8& mask, const void* ptr) { return _mm256_mask_loadu_epi32(_mm256_setzero_si256(),mask,ptr); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vuint8& v) { _mm256_mask_store_epi32 (ptr,mask,v); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vuint8& v) { _mm256_mask_storeu_epi32(ptr,mask,v); } +#else + static __forceinline vuint8 load (const vboolf8& mask, const void* ptr) { return _mm256_castps_si256(_mm256_maskload_ps((float*)ptr,mask)); } + static __forceinline vuint8 loadu(const vboolf8& mask, const void* ptr) { return _mm256_castps_si256(_mm256_maskload_ps((float*)ptr,mask)); } + + static __forceinline void store (const vboolf8& mask, void* ptr, const vuint8& v) { _mm256_maskstore_epi32((int*)ptr,mask,v); } + static __forceinline void storeu(const vboolf8& mask, void* ptr, const vuint8& v) { _mm256_maskstore_epi32((int*)ptr,mask,v); } +#endif + + static __forceinline vuint8 load_nt(void* ptr) { + return _mm256_stream_load_si256((__m256i*)ptr); + } + + static __forceinline void store_nt(void* ptr, const vuint8& v) { + _mm256_stream_ps((float*)ptr,_mm256_castsi256_ps(v)); + } + + static __forceinline void store(unsigned char* ptr, const vuint8& i) + { + for (size_t j=0; j<8; j++) + ptr[j] = i[j]; + } + + static __forceinline void store(unsigned short* ptr, const vuint8& v) { + for (size_t i=0;i<8;i++) + ptr[i] = (unsigned short)v[i]; + } + + template + static __forceinline vuint8 gather(const unsigned int *const ptr, const vint8& index) { + return _mm256_i32gather_epi32((const int*) ptr, index, scale); + } + + template + static __forceinline vuint8 gather(const vboolf8& mask, const unsigned int *const ptr, const vint8& index) { + vuint8 r = zero; +#if defined(__AVX512VL__) + return _mm256_mmask_i32gather_epi32(r, mask, index, (const int*) ptr, scale); +#else + return _mm256_mask_i32gather_epi32(r, (const int*) ptr, index, mask, scale); +#endif + } + + template + static __forceinline void scatter(void* ptr, const vint8& ofs, const vuint8& v) + { +#if defined(__AVX512VL__) + _mm256_i32scatter_epi32((int*)ptr, ofs, v, scale); +#else + *(unsigned int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + *(unsigned int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + *(unsigned int*)(((char*)ptr)+scale*ofs[2]) = v[2]; + *(unsigned int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + *(unsigned int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + *(unsigned int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + *(unsigned int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + *(unsigned int*)(((char*)ptr)+scale*ofs[7]) = v[7]; +#endif + } + + template + static __forceinline void scatter(const vboolf8& mask, void* ptr, const vint8& ofs, const vuint8& v) + { +#if defined(__AVX512VL__) + _mm256_mask_i32scatter_epi32((int*)ptr, mask, ofs, v, scale); +#else + if (likely(mask[0])) *(unsigned int*)(((char*)ptr)+scale*ofs[0]) = v[0]; + if (likely(mask[1])) *(unsigned int*)(((char*)ptr)+scale*ofs[1]) = v[1]; + if (likely(mask[2])) *(unsigned int*)(((char*)ptr)+scale*ofs[2]) = 
v[2]; + if (likely(mask[3])) *(unsigned int*)(((char*)ptr)+scale*ofs[3]) = v[3]; + if (likely(mask[4])) *(unsigned int*)(((char*)ptr)+scale*ofs[4]) = v[4]; + if (likely(mask[5])) *(unsigned int*)(((char*)ptr)+scale*ofs[5]) = v[5]; + if (likely(mask[6])) *(unsigned int*)(((char*)ptr)+scale*ofs[6]) = v[6]; + if (likely(mask[7])) *(unsigned int*)(((char*)ptr)+scale*ofs[7]) = v[7]; +#endif + } + + static __forceinline vuint8 broadcast64(const long long &a) { return _mm256_set1_epi64x(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Array Access + //////////////////////////////////////////////////////////////////////////////// + + __forceinline const unsigned int& operator [](size_t index) const { assert(index < 8); return i[index]; } + __forceinline unsigned int& operator [](size_t index) { assert(index < 8); return i[index]; } + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Unary Operators + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf8 asBool(const vuint8& a) { return _mm256_movepi32_mask(a); } +#else + __forceinline vboolf8 asBool(const vuint8& a) { return _mm256_castsi256_ps(a); } +#endif + + __forceinline vuint8 operator +(const vuint8& a) { return a; } + + //////////////////////////////////////////////////////////////////////////////// + /// Binary Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint8 operator +(const vuint8& a, const vuint8& b) { return _mm256_add_epi32(a, b); } + __forceinline vuint8 operator +(const vuint8& a, unsigned int b) { return a + vuint8(b); } + __forceinline vuint8 operator +(unsigned int a, const vuint8& b) { return vuint8(a) + b; } + + __forceinline vuint8 operator -(const vuint8& a, const vuint8& b) { return _mm256_sub_epi32(a, b); } + __forceinline vuint8 operator -(const vuint8& a, unsigned int b) { return a - vuint8(b); } + __forceinline vuint8 operator -(unsigned int a, const vuint8& b) { return vuint8(a) - b; } + + //__forceinline vuint8 operator *(const vuint8& a, const vuint8& b) { return _mm256_mullo_epu32(a, b); } + //__forceinline vuint8 operator *(const vuint8& a, unsigned int b) { return a * vuint8(b); } + //__forceinline vuint8 operator *(unsigned int a, const vuint8& b) { return vuint8(a) * b; } + + __forceinline vuint8 operator &(const vuint8& a, const vuint8& b) { return _mm256_and_si256(a, b); } + __forceinline vuint8 operator &(const vuint8& a, unsigned int b) { return a & vuint8(b); } + __forceinline vuint8 operator &(unsigned int a, const vuint8& b) { return vuint8(a) & b; } + + __forceinline vuint8 operator |(const vuint8& a, const vuint8& b) { return _mm256_or_si256(a, b); } + __forceinline vuint8 operator |(const vuint8& a, unsigned int b) { return a | vuint8(b); } + __forceinline vuint8 operator |(unsigned int a, const vuint8& b) { return vuint8(a) | b; } + + __forceinline vuint8 operator ^(const vuint8& a, const vuint8& b) { return _mm256_xor_si256(a, b); } + __forceinline vuint8 operator ^(const vuint8& a, unsigned int b) { return a ^ vuint8(b); } + __forceinline vuint8 operator ^(unsigned int a, const vuint8& b) { return vuint8(a) ^ b; } + + __forceinline vuint8 operator <<(const vuint8& a, unsigned int n) { return _mm256_slli_epi32(a, n); } + __forceinline vuint8 operator >>(const vuint8& a, unsigned int n) { return _mm256_srli_epi32(a, n); } + + __forceinline vuint8 operator <<(const 
vuint8& a, const vuint8& n) { return _mm256_sllv_epi32(a, n); } + __forceinline vuint8 operator >>(const vuint8& a, const vuint8& n) { return _mm256_srlv_epi32(a, n); } + + __forceinline vuint8 sll(const vuint8& a, unsigned int b) { return _mm256_slli_epi32(a, b); } + __forceinline vuint8 sra(const vuint8& a, unsigned int b) { return _mm256_srai_epi32(a, b); } + __forceinline vuint8 srl(const vuint8& a, unsigned int b) { return _mm256_srli_epi32(a, b); } + + __forceinline vuint8 sll(const vuint8& a, const vuint8& b) { return _mm256_sllv_epi32(a, b); } + __forceinline vuint8 sra(const vuint8& a, const vuint8& b) { return _mm256_srav_epi32(a, b); } + __forceinline vuint8 srl(const vuint8& a, const vuint8& b) { return _mm256_srlv_epi32(a, b); } + + __forceinline vuint8 min(const vuint8& a, const vuint8& b) { return _mm256_min_epu32(a, b); } + __forceinline vuint8 min(const vuint8& a, unsigned int b) { return min(a,vuint8(b)); } + __forceinline vuint8 min(unsigned int a, const vuint8& b) { return min(vuint8(a),b); } + + __forceinline vuint8 max(const vuint8& a, const vuint8& b) { return _mm256_max_epu32(a, b); } + __forceinline vuint8 max(const vuint8& a, unsigned int b) { return max(a,vuint8(b)); } + __forceinline vuint8 max(unsigned int a, const vuint8& b) { return max(vuint8(a),b); } + + //////////////////////////////////////////////////////////////////////////////// + /// Assignment Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint8& operator +=(vuint8& a, const vuint8& b) { return a = a + b; } + __forceinline vuint8& operator +=(vuint8& a, unsigned int b) { return a = a + b; } + + __forceinline vuint8& operator -=(vuint8& a, const vuint8& b) { return a = a - b; } + __forceinline vuint8& operator -=(vuint8& a, unsigned int b) { return a = a - b; } + + //__forceinline vuint8& operator *=(vuint8& a, const vuint8& b) { return a = a * b; } + //__forceinline vuint8& operator *=(vuint8& a, unsigned int b) { return a = a * b; } + + __forceinline vuint8& operator &=(vuint8& a, const vuint8& b) { return a = a & b; } + __forceinline vuint8& operator &=(vuint8& a, unsigned int b) { return a = a & b; } + + __forceinline vuint8& operator |=(vuint8& a, const vuint8& b) { return a = a | b; } + __forceinline vuint8& operator |=(vuint8& a, unsigned int b) { return a = a | b; } + + __forceinline vuint8& operator <<=(vuint8& a, const unsigned int b) { return a = a << b; } + __forceinline vuint8& operator >>=(vuint8& a, const unsigned int b) { return a = a >> b; } + + //////////////////////////////////////////////////////////////////////////////// + /// Comparison Operators + Select + //////////////////////////////////////////////////////////////////////////////// + +#if defined(__AVX512VL__) + __forceinline vboolf8 operator ==(const vuint8& a, const vuint8& b) { return _mm256_cmp_epu32_mask(a,b,_MM_CMPINT_EQ); } + __forceinline vboolf8 operator !=(const vuint8& a, const vuint8& b) { return _mm256_cmp_epu32_mask(a,b,_MM_CMPINT_NE); } + __forceinline vboolf8 operator < (const vuint8& a, const vuint8& b) { return _mm256_cmp_epu32_mask(a,b,_MM_CMPINT_LT); } + __forceinline vboolf8 operator >=(const vuint8& a, const vuint8& b) { return _mm256_cmp_epu32_mask(a,b,_MM_CMPINT_GE); } + __forceinline vboolf8 operator > (const vuint8& a, const vuint8& b) { return _mm256_cmp_epu32_mask(a,b,_MM_CMPINT_GT); } + __forceinline vboolf8 operator <=(const vuint8& a, const vuint8& b) { return _mm256_cmp_epu32_mask(a,b,_MM_CMPINT_LE); } + + __forceinline vuint8 
select(const vboolf8& m, const vuint8& t, const vuint8& f) { + return _mm256_mask_blend_epi32(m, (__m256i)f, (__m256i)t); + } +#else + __forceinline vboolf8 operator ==(const vuint8& a, const vuint8& b) { return _mm256_castsi256_ps(_mm256_cmpeq_epi32(a, b)); } + __forceinline vboolf8 operator !=(const vuint8& a, const vuint8& b) { return !(a == b); } + //__forceinline vboolf8 operator < (const vuint8& a, const vuint8& b) { return _mm256_castsi256_ps(_mm256_cmpgt_epu32(b, a)); } + //__forceinline vboolf8 operator >=(const vuint8& a, const vuint8& b) { return !(a < b); } + //__forceinline vboolf8 operator > (const vuint8& a, const vuint8& b) { return _mm256_castsi256_ps(_mm256_cmpgt_epu32(a, b)); } + //__forceinline vboolf8 operator <=(const vuint8& a, const vuint8& b) { return !(a > b); } + + __forceinline vuint8 select(const vboolf8& m, const vuint8& t, const vuint8& f) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(f), _mm256_castsi256_ps(t), m)); + } +#endif + + template + __forceinline vuint8 select(const vuint8& t, const vuint8& f) { + return _mm256_blend_epi32(f, t, mask); + } + + __forceinline vboolf8 operator ==(const vuint8& a, unsigned int b) { return a == vuint8(b); } + __forceinline vboolf8 operator ==(unsigned int a, const vuint8& b) { return vuint8(a) == b; } + + __forceinline vboolf8 operator !=(const vuint8& a, unsigned int b) { return a != vuint8(b); } + __forceinline vboolf8 operator !=(unsigned int a, const vuint8& b) { return vuint8(a) != b; } + + //__forceinline vboolf8 operator < (const vuint8& a, unsigned int b) { return a < vuint8(b); } + //__forceinline vboolf8 operator < (unsigned int a, const vuint8& b) { return vuint8(a) < b; } + + //__forceinline vboolf8 operator >=(const vuint8& a, unsigned int b) { return a >= vuint8(b); } + //__forceinline vboolf8 operator >=(unsigned int a, const vuint8& b) { return vuint8(a) >= b; } + + //__forceinline vboolf8 operator > (const vuint8& a, unsigned int b) { return a > vuint8(b); } + //__forceinline vboolf8 operator > (unsigned int a, const vuint8& b) { return vuint8(a) > b; } + + //__forceinline vboolf8 operator <=(const vuint8& a, unsigned int b) { return a <= vuint8(b); } + //__forceinline vboolf8 operator <=(unsigned int a, const vuint8& b) { return vuint8(a) <= b; } + + __forceinline vboolf8 eq(const vuint8& a, const vuint8& b) { return a == b; } + __forceinline vboolf8 ne(const vuint8& a, const vuint8& b) { return a != b; } + //__forceinline vboolf8 lt(const vuint8& a, const vuint8& b) { return a < b; } + //__forceinline vboolf8 ge(const vuint8& a, const vuint8& b) { return a >= b; } + //__forceinline vboolf8 gt(const vuint8& a, const vuint8& b) { return a > b; } + //__forceinline vboolf8 le(const vuint8& a, const vuint8& b) { return a <= b; } + +#if defined(__AVX512VL__) + __forceinline vboolf8 eq(const vboolf8& mask, const vuint8& a, const vuint8& b) { return _mm256_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_EQ); } + __forceinline vboolf8 ne(const vboolf8& mask, const vuint8& a, const vuint8& b) { return _mm256_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_NE); } + __forceinline vboolf8 lt(const vboolf8& mask, const vuint8& a, const vuint8& b) { return _mm256_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_LT); } + __forceinline vboolf8 ge(const vboolf8& mask, const vuint8& a, const vuint8& b) { return _mm256_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_GE); } + __forceinline vboolf8 gt(const vboolf8& mask, const vuint8& a, const vuint8& b) { return _mm256_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_GT); } + 
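// ---------------------------------------------------------------------------
// Editorial usage sketch (not part of the vendored header): branchless per-lane
// replacement built from the comparison and select() helpers defined earlier in
// this file. The function name replace_equal is hypothetical; its placement
// inside this #if block is incidental to where the sketch is shown.
static inline vuint8 replace_equal(const vuint8& v, unsigned int old_val, unsigned int new_val)
{
  const vboolf8 m = (v == old_val);      // mask of lanes where v matches old_val
  return select(m, vuint8(new_val), v);  // take new_val in matching lanes, keep v elsewhere
}
// ---------------------------------------------------------------------------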
__forceinline vboolf8 le(const vboolf8& mask, const vuint8& a, const vuint8& b) { return _mm256_mask_cmp_epu32_mask(mask, a, b, _MM_CMPINT_LE); } +#else + __forceinline vboolf8 eq(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a == b); } + __forceinline vboolf8 ne(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a != b); } + //__forceinline vboolf8 lt(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a < b); } + //__forceinline vboolf8 ge(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a >= b); } + //__forceinline vboolf8 gt(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a > b); } + //__forceinline vboolf8 le(const vboolf8& mask, const vuint8& a, const vuint8& b) { return mask & (a <= b); } +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Movement/Shifting/Shuffling Functions + //////////////////////////////////////////////////////////////////////////////// + + __forceinline vuint8 unpacklo(const vuint8& a, const vuint8& b) { return _mm256_unpacklo_epi32(a, b); } + __forceinline vuint8 unpackhi(const vuint8& a, const vuint8& b) { return _mm256_unpackhi_epi32(a, b); } + + template + __forceinline vuint8 shuffle(const vuint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i, i, i, i))); + } + + template + __forceinline vuint8 shuffle4(const vuint8& v) { + return _mm256_permute2f128_si256(v, v, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vuint8 shuffle4(const vuint8& a, const vuint8& b) { + return _mm256_permute2f128_si256(a, b, (i1 << 4) | (i0 << 0)); + } + + template + __forceinline vuint8 shuffle(const vuint8& v) { + return _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(v), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template + __forceinline vuint8 shuffle(const vuint8& a, const vuint8& b) { + return _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), _MM_SHUFFLE(i3, i2, i1, i0))); + } + + template<> __forceinline vuint8 shuffle<0, 0, 2, 2>(const vuint8& v) { return _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vuint8 shuffle<1, 1, 3, 3>(const vuint8& v) { return _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(v))); } + template<> __forceinline vuint8 shuffle<0, 1, 0, 1>(const vuint8& v) { return _mm256_castps_si256(_mm256_castpd_ps(_mm256_movedup_pd(_mm256_castps_pd(_mm256_castsi256_ps(v))))); } + + __forceinline vuint8 broadcast(const unsigned int* ptr) { return _mm256_castps_si256(_mm256_broadcast_ss((const float*)ptr)); } + + template __forceinline vuint8 insert4(const vuint8& a, const vuint4& b) { return _mm256_insertf128_si256(a, b, i); } + template __forceinline vuint4 extract4(const vuint8& a) { return _mm256_extractf128_si256(a, i); } + template<> __forceinline vuint4 extract4<0>(const vuint8& a) { return _mm256_castsi256_si128(a); } + + __forceinline int toScalar(const vuint8& v) { return _mm_cvtsi128_si32(_mm256_castsi256_si128(v)); } + + __forceinline vuint8 permute(const vuint8& v, const __m256i& index) { + return _mm256_permutevar8x32_epi32(v, index); + } + + __forceinline vuint8 shuffle(const vuint8& v, const __m256i& index) { + return _mm256_castps_si256(_mm256_permutevar_ps(_mm256_castsi256_ps(v), index)); + } + + template + __forceinline vuint8 align_shift_right(const vuint8& a, const vuint8& b) { +#if defined(__AVX512VL__) + return _mm256_alignr_epi32(a, 
b, i); +#else + return _mm256_alignr_epi8(a, b, 4*i); +#endif + } + + //////////////////////////////////////////////////////////////////////////////// + /// Reductions + //////////////////////////////////////////////////////////////////////////////// + + //__forceinline vuint8 vreduce_min2(const vuint8& v) { return min(v,shuffle<1,0,3,2>(v)); } + //__forceinline vuint8 vreduce_min4(const vuint8& v) { vuint8 v1 = vreduce_min2(v); return min(v1,shuffle<2,3,0,1>(v1)); } + //__forceinline vuint8 vreduce_min (const vuint8& v) { vuint8 v1 = vreduce_min4(v); return min(v1,shuffle4<1,0>(v1)); } + + //__forceinline vuint8 vreduce_max2(const vuint8& v) { return max(v,shuffle<1,0,3,2>(v)); } + //__forceinline vuint8 vreduce_max4(const vuint8& v) { vuint8 v1 = vreduce_max2(v); return max(v1,shuffle<2,3,0,1>(v1)); } + //__forceinline vuint8 vreduce_max (const vuint8& v) { vuint8 v1 = vreduce_max4(v); return max(v1,shuffle4<1,0>(v1)); } + + __forceinline vuint8 vreduce_add2(const vuint8& v) { return v + shuffle<1,0,3,2>(v); } + __forceinline vuint8 vreduce_add4(const vuint8& v) { vuint8 v1 = vreduce_add2(v); return v1 + shuffle<2,3,0,1>(v1); } + __forceinline vuint8 vreduce_add (const vuint8& v) { vuint8 v1 = vreduce_add4(v); return v1 + shuffle4<1,0>(v1); } + + //__forceinline int reduce_min(const vuint8& v) { return toScalar(vreduce_min(v)); } + //__forceinline int reduce_max(const vuint8& v) { return toScalar(vreduce_max(v)); } + __forceinline int reduce_add(const vuint8& v) { return toScalar(vreduce_add(v)); } + + //__forceinline size_t select_min(const vuint8& v) { return bsf(movemask(v == vreduce_min(v))); } + //__forceinline size_t select_max(const vuint8& v) { return bsf(movemask(v == vreduce_max(v))); } + + //__forceinline size_t select_min(const vboolf8& valid, const vuint8& v) { const vuint8 a = select(valid,v,vuint8(pos_inf)); return bsf(movemask(valid & (a == vreduce_min(a)))); } + //__forceinline size_t select_max(const vboolf8& valid, const vuint8& v) { const vuint8 a = select(valid,v,vuint8(neg_inf)); return bsf(movemask(valid & (a == vreduce_max(a)))); } + + __forceinline vuint8 assign(const vuint4& a) { return _mm256_castsi128_si256(a); } + + //////////////////////////////////////////////////////////////////////////////// + /// Output Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline embree_ostream operator <<(embree_ostream cout, const vuint8& a) { + return cout << "<" << a[0] << ", " << a[1] << ", " << a[2] << ", " << a[3] << ", " << a[4] << ", " << a[5] << ", " << a[6] << ", " << a[7] << ">"; + } +} diff --git a/thirdparty/embree/common/sys/alloc.cpp b/thirdparty/embree/common/sys/alloc.cpp new file mode 100644 index 000000000000..4e8928242eb8 --- /dev/null +++ b/thirdparty/embree/common/sys/alloc.cpp @@ -0,0 +1,306 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "alloc.h" +#include "intrinsics.h" +#include "sysinfo.h" +#include "mutex.h" + +//////////////////////////////////////////////////////////////////////////////// +/// All Platforms +//////////////////////////////////////////////////////////////////////////////// + +namespace embree +{ + void* alignedMalloc(size_t size, size_t align) + { + if (size == 0) + return nullptr; + + assert((align & (align-1)) == 0); + void* ptr = _mm_malloc(size,align); + + if (size != 0 && ptr == nullptr) + throw std::bad_alloc(); + + return ptr; + } + + void alignedFree(void* ptr) + { + if (ptr) + _mm_free(ptr); + } + + static bool 
huge_pages_enabled = false; + static MutexSys os_init_mutex; + + __forceinline bool isHugePageCandidate(const size_t bytes) + { + if (!huge_pages_enabled) + return false; + + /* use huge pages only when memory overhead is low */ + const size_t hbytes = (bytes+PAGE_SIZE_2M-1) & ~size_t(PAGE_SIZE_2M-1); + return 66*(hbytes-bytes) < bytes; // at most 1.5% overhead + } +} + +//////////////////////////////////////////////////////////////////////////////// +/// Windows Platform +//////////////////////////////////////////////////////////////////////////////// + +#ifdef _WIN32 + +#define WIN32_LEAN_AND_MEAN +#include +#include + +namespace embree +{ + bool win_enable_selockmemoryprivilege (bool verbose) + { + HANDLE hToken; + if (!OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES, &hToken)) { + if (verbose) std::cout << "WARNING: OpenProcessToken failed while trying to enable SeLockMemoryPrivilege: " << GetLastError() << std::endl; + return false; + } + + TOKEN_PRIVILEGES tp; + tp.PrivilegeCount = 1; + tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; + + if (!LookupPrivilegeValueW(nullptr, L"SeLockMemoryPrivilege", &tp.Privileges[0].Luid)) { + if (verbose) std::cout << "WARNING: LookupPrivilegeValue failed while trying to enable SeLockMemoryPrivilege: " << GetLastError() << std::endl; + return false; + } + + SetLastError(ERROR_SUCCESS); + if (!AdjustTokenPrivileges(hToken, FALSE, &tp, sizeof(tp), nullptr, 0)) { + if (verbose) std::cout << "WARNING: AdjustTokenPrivileges failed while trying to enable SeLockMemoryPrivilege" << std::endl; + return false; + } + + if (GetLastError() == ERROR_NOT_ALL_ASSIGNED) { + if (verbose) std::cout << "WARNING: AdjustTokenPrivileges failed to enable SeLockMemoryPrivilege: Add SeLockMemoryPrivilege for current user and run process in elevated mode (Run as administrator)." << std::endl; + return false; + } + + return true; + } + + bool os_init(bool hugepages, bool verbose) + { + Lock lock(os_init_mutex); + + if (!hugepages) { + huge_pages_enabled = false; + return true; + } + + if (GetLargePageMinimum() != PAGE_SIZE_2M) { + huge_pages_enabled = false; + return false; + } + + huge_pages_enabled = true; + return true; + } + + void* os_malloc(size_t bytes, bool& hugepages) + { + if (bytes == 0) { + hugepages = false; + return nullptr; + } + + /* try direct huge page allocation first */ + if (isHugePageCandidate(bytes)) + { + int flags = MEM_COMMIT | MEM_RESERVE | MEM_LARGE_PAGES; + char* ptr = (char*) VirtualAlloc(nullptr,bytes,flags,PAGE_READWRITE); + if (ptr != nullptr) { + hugepages = true; + return ptr; + } + } + + /* fall back to 4k pages */ + int flags = MEM_COMMIT | MEM_RESERVE; + char* ptr = (char*) VirtualAlloc(nullptr,bytes,flags,PAGE_READWRITE); + if (ptr == nullptr) throw std::bad_alloc(); + hugepages = false; + return ptr; + } + + size_t os_shrink(void* ptr, size_t bytesNew, size_t bytesOld, bool hugepages) + { + if (hugepages) // decommitting huge pages seems not to work under Windows + return bytesOld; + + const size_t pageSize = hugepages ? 
PAGE_SIZE_2M : PAGE_SIZE_4K; + bytesNew = (bytesNew+pageSize-1) & ~(pageSize-1); + bytesOld = (bytesOld+pageSize-1) & ~(pageSize-1); + if (bytesNew >= bytesOld) + return bytesOld; + + if (!VirtualFree((char*)ptr+bytesNew,bytesOld-bytesNew,MEM_DECOMMIT)) + throw std::bad_alloc(); + + return bytesNew; + } + + void os_free(void* ptr, size_t bytes, bool hugepages) + { + if (bytes == 0) + return; + + if (!VirtualFree(ptr,0,MEM_RELEASE)) + throw std::bad_alloc(); + } + + void os_advise(void *ptr, size_t bytes) + { + } +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Unix Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__UNIX__) + +#include +#include +#include +#include +#include + +#if defined(__MACOSX__) +#include +#endif + +namespace embree +{ + bool os_init(bool hugepages, bool verbose) + { + Lock lock(os_init_mutex); + + if (!hugepages) { + huge_pages_enabled = false; + return true; + } + +#if defined(__LINUX__) + + int hugepagesize = 0; + + std::ifstream file; + file.open("/proc/meminfo",std::ios::in); + if (!file.is_open()) { + if (verbose) std::cout << "WARNING: Could not open /proc/meminfo. Huge page support cannot get enabled!" << std::endl; + huge_pages_enabled = false; + return false; + } + + std::string line; + while (getline(file,line)) + { + std::stringstream sline(line); + while (!sline.eof() && sline.peek() == ' ') sline.ignore(); + std::string tag; getline(sline,tag,' '); + while (!sline.eof() && sline.peek() == ' ') sline.ignore(); + std::string val; getline(sline,val,' '); + while (!sline.eof() && sline.peek() == ' ') sline.ignore(); + std::string unit; getline(sline,unit,' '); + if (tag == "Hugepagesize:" && unit == "kB") { + hugepagesize = std::stoi(val)*1024; + break; + } + } + + if (hugepagesize != PAGE_SIZE_2M) + { + if (verbose) std::cout << "WARNING: Only 2MB huge pages supported. Huge page support cannot get enabled!" << std::endl; + huge_pages_enabled = false; + return false; + } +#endif + + huge_pages_enabled = true; + return true; + } + + void* os_malloc(size_t bytes, bool& hugepages) + { + if (bytes == 0) { + hugepages = false; + return nullptr; + } + + /* try direct huge page allocation first */ + if (isHugePageCandidate(bytes)) + { +#if defined(__MACOSX__) + void* ptr = mmap(0, bytes, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, VM_FLAGS_SUPERPAGE_SIZE_2MB, 0); + if (ptr != MAP_FAILED) { + hugepages = true; + return ptr; + } +#elif defined(MAP_HUGETLB) + void* ptr = mmap(0, bytes, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON | MAP_HUGETLB, -1, 0); + if (ptr != MAP_FAILED) { + hugepages = true; + return ptr; + } +#endif + } + + /* fallback to 4k pages */ + void* ptr = (char*) mmap(0, bytes, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); + if (ptr == MAP_FAILED) throw std::bad_alloc(); + hugepages = false; + + /* advise huge page hint for THP */ + os_advise(ptr,bytes); + return ptr; + } + + size_t os_shrink(void* ptr, size_t bytesNew, size_t bytesOld, bool hugepages) + { + const size_t pageSize = hugepages ? 
PAGE_SIZE_2M : PAGE_SIZE_4K; + bytesNew = (bytesNew+pageSize-1) & ~(pageSize-1); + bytesOld = (bytesOld+pageSize-1) & ~(pageSize-1); + if (bytesNew >= bytesOld) + return bytesOld; + + if (munmap((char*)ptr+bytesNew,bytesOld-bytesNew) == -1) + throw std::bad_alloc(); + + return bytesNew; + } + + void os_free(void* ptr, size_t bytes, bool hugepages) + { + if (bytes == 0) + return; + + /* for hugepages we need to also align the size */ + const size_t pageSize = hugepages ? PAGE_SIZE_2M : PAGE_SIZE_4K; + bytes = (bytes+pageSize-1) & ~(pageSize-1); + if (munmap(ptr,bytes) == -1) + throw std::bad_alloc(); + } + + /* hint for transparent huge pages (THP) */ + void os_advise(void* pptr, size_t bytes) + { +#if defined(MADV_HUGEPAGE) + madvise(pptr,bytes,MADV_HUGEPAGE); +#endif + } +} + +#endif diff --git a/thirdparty/embree/common/sys/alloc.h b/thirdparty/embree/common/sys/alloc.h new file mode 100644 index 000000000000..5898ecda7093 --- /dev/null +++ b/thirdparty/embree/common/sys/alloc.h @@ -0,0 +1,164 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" +#include +#include + +namespace embree +{ +#define ALIGNED_STRUCT_(align) \ + void* operator new(size_t size) { return alignedMalloc(size,align); } \ + void operator delete(void* ptr) { alignedFree(ptr); } \ + void* operator new[](size_t size) { return alignedMalloc(size,align); } \ + void operator delete[](void* ptr) { alignedFree(ptr); } + +#define ALIGNED_CLASS_(align) \ + public: \ + ALIGNED_STRUCT_(align) \ + private: + + /*! aligned allocation */ + void* alignedMalloc(size_t size, size_t align); + void alignedFree(void* ptr); + + /*! allocator that performs aligned allocations */ + template + struct aligned_allocator + { + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + + __forceinline pointer allocate( size_type n ) { + return (pointer) alignedMalloc(n*sizeof(value_type),alignment); + } + + __forceinline void deallocate( pointer p, size_type n ) { + return alignedFree(p); + } + + __forceinline void construct( pointer p, const_reference val ) { + new (p) T(val); + } + + __forceinline void destroy( pointer p ) { + p->~T(); + } + }; + + /*! allocates pages directly from OS */ + bool win_enable_selockmemoryprivilege(bool verbose); + bool os_init(bool hugepages, bool verbose); + void* os_malloc (size_t bytes, bool& hugepages); + size_t os_shrink (void* ptr, size_t bytesNew, size_t bytesOld, bool hugepages); + void os_free (void* ptr, size_t bytes, bool hugepages); + void os_advise (void* ptr, size_t bytes); + + /*! allocator that performs OS allocations */ + template + struct os_allocator + { + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + + __forceinline os_allocator () + : hugepages(false) {} + + __forceinline pointer allocate( size_type n ) { + return (pointer) os_malloc(n*sizeof(value_type),hugepages); + } + + __forceinline void deallocate( pointer p, size_type n ) { + return os_free(p,n*sizeof(value_type),hugepages); + } + + __forceinline void construct( pointer p, const_reference val ) { + new (p) T(val); + } + + __forceinline void destroy( pointer p ) { + p->~T(); + } + + bool hugepages; + }; + + /*! 
allocator for IDs */ + template + struct IDPool + { + typedef T value_type; + + IDPool () + : nextID(0) {} + + T allocate() + { + /* return ID from list */ + if (!IDs.empty()) + { + T id = *IDs.begin(); + IDs.erase(IDs.begin()); + return id; + } + + /* allocate new ID */ + else + { + if (size_t(nextID)+1 > max_id) + return -1; + + return nextID++; + } + } + + /* adds an ID provided by the user */ + bool add(T id) + { + if (id > max_id) + return false; + + /* check if ID should be in IDs set */ + if (id < nextID) { + auto p = IDs.find(id); + if (p == IDs.end()) return false; + IDs.erase(p); + return true; + } + + /* otherwise increase ID set */ + else + { + for (T i=nextID; i IDs; //!< stores deallocated IDs to be reused + T nextID; //!< next ID to use when IDs vector is empty + }; +} + diff --git a/thirdparty/embree/common/sys/array.h b/thirdparty/embree/common/sys/array.h new file mode 100644 index 000000000000..6f6f98eac816 --- /dev/null +++ b/thirdparty/embree/common/sys/array.h @@ -0,0 +1,222 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" +#include "alloc.h" + +namespace embree +{ + /*! static array with static size */ + template + class array_t + { + public: + + /********************** Iterators ****************************/ + + __forceinline T* begin() const { return items; }; + __forceinline T* end () const { return items+N; }; + + + /********************** Capacity ****************************/ + + __forceinline bool empty () const { return N == 0; } + __forceinline size_t size () const { return N; } + __forceinline size_t max_size () const { return N; } + + + /******************** Element access **************************/ + + __forceinline T& operator[](size_t i) { assert(i < N); return items[i]; } + __forceinline const T& operator[](size_t i) const { assert(i < N); return items[i]; } + + __forceinline T& at(size_t i) { assert(i < N); return items[i]; } + __forceinline const T& at(size_t i) const { assert(i < N); return items[i]; } + + __forceinline T& front() const { assert(N > 0); return items[0]; }; + __forceinline T& back () const { assert(N > 0); return items[N-1]; }; + + __forceinline T* data() { return items; }; + __forceinline const T* data() const { return items; }; + + private: + T items[N]; + }; + + /*! static array with dynamic size */ + template + class darray_t + { + public: + + __forceinline darray_t () : M(0) {} + + __forceinline darray_t (const T& v) : M(0) { + for (size_t i=0; i 0); return items[0]; }; + __forceinline T& back () const { assert(M > 0); return items[M-1]; }; + + __forceinline T* data() { return items; }; + __forceinline const T* data() const { return items; }; + + private: + size_t M; + T items[N]; + }; + + /*! 
dynamic sized array that is allocated on the stack */ +#define dynamic_large_stack_array(Ty,Name,N,max_stack_bytes) StackArray Name(N) + template + struct __aligned(64) StackArray + { + __forceinline StackArray (const size_t N) + : N(N) + { + if (N*sizeof(Ty) <= max_stack_bytes) + data = &arr[0]; + else + data = (Ty*) alignedMalloc(N*sizeof(Ty),64); + } + + __forceinline ~StackArray () { + if (data != &arr[0]) alignedFree(data); + } + + __forceinline operator Ty* () { return data; } + __forceinline operator const Ty* () const { return data; } + + __forceinline Ty& operator[](const int i) { assert(i>=0 && i=0 && i + struct __aligned(64) DynamicStackArray + { + __forceinline DynamicStackArray () + : data(&arr[0]) {} + + __forceinline ~DynamicStackArray () + { + if (!isStackAllocated()) + delete[] data; + } + + __forceinline bool isStackAllocated() const { + return data == &arr[0]; + } + + __forceinline size_t size() const + { + if (isStackAllocated()) return max_stack_elements; + else return max_total_elements; + } + + __forceinline void resize(size_t M) + { + assert(M <= max_total_elements); + if (likely(M <= max_stack_elements)) return; + if (likely(!isStackAllocated())) return; + + data = new Ty[max_total_elements]; + + for (size_t i=0; i=0 && ioperator[] (i) = other[i]; + } + + DynamicStackArray& operator= (const DynamicStackArray& other) + { + for (size_t i=0; ioperator[] (i) = other[i]; + + return *this; + } + + private: + Ty arr[max_stack_elements]; + Ty* data; + }; +} diff --git a/thirdparty/embree/common/sys/atomic.h b/thirdparty/embree/common/sys/atomic.h new file mode 100644 index 000000000000..ebfb8552c358 --- /dev/null +++ b/thirdparty/embree/common/sys/atomic.h @@ -0,0 +1,59 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "intrinsics.h" + +namespace embree +{ +/* compiler memory barriers */ +#if defined(__INTEL_COMPILER) +//#define __memory_barrier() __memory_barrier() +#elif defined(__GNUC__) || defined(__clang__) +# define __memory_barrier() asm volatile("" ::: "memory") +#elif defined(_MSC_VER) +# define __memory_barrier() _ReadWriteBarrier() +#endif + + template + struct atomic : public std::atomic + { + atomic () {} + + atomic (const T& a) + : std::atomic(a) {} + + atomic (const atomic& a) { + this->store(a.load()); + } + + atomic& operator=(const atomic& other) { + this->store(other.load()); + return *this; + } + }; + + template + __forceinline void atomic_min(std::atomic& aref, const T& bref) + { + const T b = bref.load(); + while (true) { + T a = aref.load(); + if (a <= b) break; + if (aref.compare_exchange_strong(a,b)) break; + } + } + + template + __forceinline void atomic_max(std::atomic& aref, const T& bref) + { + const T b = bref.load(); + while (true) { + T a = aref.load(); + if (a >= b) break; + if (aref.compare_exchange_strong(a,b)) break; + } + } +} diff --git a/thirdparty/embree/common/sys/barrier.cpp b/thirdparty/embree/common/sys/barrier.cpp new file mode 100644 index 000000000000..0061d18db255 --- /dev/null +++ b/thirdparty/embree/common/sys/barrier.cpp @@ -0,0 +1,289 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "barrier.h" +#include "condition.h" +#include "regression.h" +#include "thread.h" + +#if defined (__WIN32__) + +#define WIN32_LEAN_AND_MEAN +#include + +namespace embree +{ + struct BarrierSysImplementation + { + __forceinline BarrierSysImplementation (size_t N) + : i(0), enterCount(0), exitCount(0), barrierSize(0) + { + 
events[0] = CreateEvent(nullptr, TRUE, FALSE, nullptr); + events[1] = CreateEvent(nullptr, TRUE, FALSE, nullptr); + init(N); + } + + __forceinline ~BarrierSysImplementation () + { + CloseHandle(events[0]); + CloseHandle(events[1]); + } + + __forceinline void init(size_t N) + { + barrierSize = N; + enterCount.store(N); + exitCount.store(N); + } + + __forceinline void wait() + { + /* every thread entering the barrier decrements this count */ + size_t i0 = i; + size_t cnt0 = enterCount--; + + /* all threads except the last one are wait in the barrier */ + if (cnt0 > 1) + { + if (WaitForSingleObject(events[i0], INFINITE) != WAIT_OBJECT_0) + THROW_RUNTIME_ERROR("WaitForSingleObjects failed"); + } + + /* the last thread starts all threads waiting at the barrier */ + else + { + i = 1-i; + enterCount.store(barrierSize); + if (SetEvent(events[i0]) == 0) + THROW_RUNTIME_ERROR("SetEvent failed"); + } + + /* every thread leaving the barrier decrements this count */ + size_t cnt1 = exitCount--; + + /* the last thread that left the barrier resets the event again */ + if (cnt1 == 1) + { + exitCount.store(barrierSize); + if (ResetEvent(events[i0]) == 0) + THROW_RUNTIME_ERROR("ResetEvent failed"); + } + } + + public: + HANDLE events[2]; + atomic i; + atomic enterCount; + atomic exitCount; + size_t barrierSize; + }; +} + +#else + +namespace embree +{ + struct BarrierSysImplementation + { + __forceinline BarrierSysImplementation (size_t N) + : count(0), barrierSize(0) + { + init(N); + } + + __forceinline void init(size_t N) + { + assert(count == 0); + count = 0; + barrierSize = N; + } + + __forceinline void wait() + { + mutex.lock(); + count++; + + if (count == barrierSize) { + count = 0; + cond.notify_all(); + mutex.unlock(); + return; + } + + cond.wait(mutex); + mutex.unlock(); + return; + } + + public: + MutexSys mutex; + ConditionSys cond; + volatile size_t count; + volatile size_t barrierSize; + }; +} + +#endif + +namespace embree +{ + BarrierSys::BarrierSys (size_t N) { + opaque = new BarrierSysImplementation(N); + } + + BarrierSys::~BarrierSys () { + delete (BarrierSysImplementation*) opaque; + } + + void BarrierSys::init(size_t count) { + ((BarrierSysImplementation*) opaque)->init(count); + } + + void BarrierSys::wait() { + ((BarrierSysImplementation*) opaque)->wait(); + } + + LinearBarrierActive::LinearBarrierActive (size_t N) + : count0(nullptr), count1(nullptr), mode(0), flag0(0), flag1(0), threadCount(0) + { + if (N == 0) N = getNumberOfLogicalThreads(); + init(N); + } + + LinearBarrierActive::~LinearBarrierActive() + { + delete[] count0; + delete[] count1; + } + + void LinearBarrierActive::init(size_t N) + { + if (threadCount != N) { + threadCount = N; + if (count0) delete[] count0; count0 = new unsigned char[N]; + if (count1) delete[] count1; count1 = new unsigned char[N]; + } + mode = 0; + flag0 = 0; + flag1 = 0; + for (size_t i=0; i threadID; + std::atomic numFailed; + std::vector threadResults; + + barrier_sys_regression_test() + : RegressionTest("barrier_sys_regression_test"), threadID(0), numFailed(0) + { + registerRegressionTest(this); + } + + static void thread_alloc(barrier_sys_regression_test* This) + { + size_t tid = This->threadID++; + for (size_t j=0; j<1000; j++) + { + This->barrier.wait(); + This->threadResults[tid] = tid; + This->barrier.wait(); + } + } + + bool run () + { + threadID.store(0); + numFailed.store(0); + + size_t numThreads = getNumberOfLogicalThreads(); + threadResults.resize(numThreads); + barrier.init(numThreads+1); + + /* create threads */ + std::vector threads; 
+ for (size_t i=0; i cntr; + }; + + /*! fast active barrier that does not require initialization to some number of threads */ + struct BarrierActiveAutoReset + { + public: + BarrierActiveAutoReset () + : cntr0(0), cntr1(0) {} + + void wait (size_t threadCount) + { + cntr0.fetch_add(1); + while (cntr0 != threadCount) pause_cpu(); + cntr1.fetch_add(1); + while (cntr1 != threadCount) pause_cpu(); + cntr0.fetch_add(-1); + while (cntr0 != 0) pause_cpu(); + cntr1.fetch_add(-1); + while (cntr1 != 0) pause_cpu(); + } + + private: + std::atomic cntr0; + std::atomic cntr1; + }; + + class LinearBarrierActive + { + public: + + /*! construction and destruction */ + LinearBarrierActive (size_t threadCount = 0); + ~LinearBarrierActive(); + + private: + /*! class in non-copyable */ + LinearBarrierActive (const LinearBarrierActive& other) DELETED; // do not implement + LinearBarrierActive& operator= (const LinearBarrierActive& other) DELETED; // do not implement + + public: + /*! intializes the barrier with some number of threads */ + void init(size_t threadCount); + + /*! thread with threadIndex waits in the barrier */ + void wait (const size_t threadIndex); + + private: + volatile unsigned char* count0; + volatile unsigned char* count1; + volatile unsigned int mode; + volatile unsigned int flag0; + volatile unsigned int flag1; + volatile size_t threadCount; + }; +} + diff --git a/thirdparty/embree/common/sys/condition.cpp b/thirdparty/embree/common/sys/condition.cpp new file mode 100644 index 000000000000..0e7ca7af399d --- /dev/null +++ b/thirdparty/embree/common/sys/condition.cpp @@ -0,0 +1,81 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "condition.h" + +#if defined(__WIN32__) && !defined(PTHREADS_WIN32) + +#define WIN32_LEAN_AND_MEAN +#include + +namespace embree +{ + struct ConditionImplementation + { + __forceinline ConditionImplementation () { + InitializeConditionVariable(&cond); + } + + __forceinline ~ConditionImplementation () { + } + + __forceinline void wait(MutexSys& mutex_in) { + SleepConditionVariableCS(&cond, (LPCRITICAL_SECTION)mutex_in.mutex, INFINITE); + } + + __forceinline void notify_all() { + WakeAllConditionVariable(&cond); + } + + public: + CONDITION_VARIABLE cond; + }; +} +#endif + +#if defined(__UNIX__) || defined(PTHREADS_WIN32) +#include +namespace embree +{ + struct ConditionImplementation + { + __forceinline ConditionImplementation () { + pthread_cond_init(&cond,nullptr); + } + + __forceinline ~ConditionImplementation() { + pthread_cond_destroy(&cond); + } + + __forceinline void wait(MutexSys& mutex) { + pthread_cond_wait(&cond, (pthread_mutex_t*)mutex.mutex); + } + + __forceinline void notify_all() { + pthread_cond_broadcast(&cond); + } + + public: + pthread_cond_t cond; + }; +} +#endif + +namespace embree +{ + ConditionSys::ConditionSys () { + cond = new ConditionImplementation; + } + + ConditionSys::~ConditionSys() { + delete (ConditionImplementation*) cond; + } + + void ConditionSys::wait(MutexSys& mutex) { + ((ConditionImplementation*) cond)->wait(mutex); + } + + void ConditionSys::notify_all() { + ((ConditionImplementation*) cond)->notify_all(); + } +} diff --git a/thirdparty/embree/common/sys/condition.h b/thirdparty/embree/common/sys/condition.h new file mode 100644 index 000000000000..7a3a05aa81af --- /dev/null +++ b/thirdparty/embree/common/sys/condition.h @@ -0,0 +1,31 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "mutex.h" + +namespace embree +{ + 
class ConditionSys + { + public: + ConditionSys(); + ~ConditionSys(); + void wait( class MutexSys& mutex ); + void notify_all(); + + template + __forceinline void wait( class MutexSys& mutex, const Predicate& pred ) + { + while (!pred()) wait(mutex); + } + + private: + ConditionSys (const ConditionSys& other) DELETED; // do not implement + ConditionSys& operator= (const ConditionSys& other) DELETED; // do not implement + + protected: + void* cond; + }; +} diff --git a/thirdparty/embree/common/sys/filename.cpp b/thirdparty/embree/common/sys/filename.cpp new file mode 100644 index 000000000000..86182c1afbbc --- /dev/null +++ b/thirdparty/embree/common/sys/filename.cpp @@ -0,0 +1,138 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "filename.h" +#include "sysinfo.h" + +namespace embree +{ +#ifdef __WIN32__ + const char path_sep = '\\'; +#else + const char path_sep = '/'; +#endif + + /*! create an empty filename */ + FileName::FileName () {} + + /*! create a valid filename from a string */ + FileName::FileName (const char* in) { + filename = in; + for (size_t i=0; i +#endif + +#include + +// -- GODOT start -- +#if defined(__WIN32__) && defined(__MINGW32__) +#include +#endif +// -- GODOT end -- + +#if defined(__BMI__) && defined(__GNUC__) && !defined(__INTEL_COMPILER) + #if !defined(_tzcnt_u32) + #define _tzcnt_u32 __tzcnt_u32 + #endif + #if !defined(_tzcnt_u64) + #define _tzcnt_u64 __tzcnt_u64 + #endif +#endif + +#if defined(__LZCNT__) + #if !defined(_lzcnt_u32) + #define _lzcnt_u32 __lzcnt32 + #endif + #if !defined(_lzcnt_u64) + #define _lzcnt_u64 __lzcnt64 + #endif +#endif + +#if defined(__WIN32__) +// -- GODOT start -- +#if !defined(NOMINMAX) +// -- GODOT end -- +#define NOMINMAX +// -- GODOT start -- +#endif +#include "windows.h" +// -- GODOT end -- +#endif + +/* normally defined in pmmintrin.h, but we always need this */ +#if !defined(_MM_SET_DENORMALS_ZERO_MODE) +#define _MM_DENORMALS_ZERO_ON (0x0040) +#define _MM_DENORMALS_ZERO_OFF (0x0000) +#define _MM_DENORMALS_ZERO_MASK (0x0040) +#define _MM_SET_DENORMALS_ZERO_MODE(x) (_mm_setcsr((_mm_getcsr() & ~_MM_DENORMALS_ZERO_MASK) | (x))) +#endif + +namespace embree +{ + +//////////////////////////////////////////////////////////////////////////////// +/// Windows Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__WIN32__) + + __forceinline size_t read_tsc() + { + LARGE_INTEGER li; + QueryPerformanceCounter(&li); + return (size_t)li.QuadPart; + } + + __forceinline int bsf(int v) { +#if defined(__AVX2__) + return _tzcnt_u32(v); +#else + unsigned long r = 0; _BitScanForward(&r,v); return r; +#endif + } + + __forceinline unsigned bsf(unsigned v) { +#if defined(__AVX2__) + return _tzcnt_u32(v); +#else + unsigned long r = 0; _BitScanForward(&r,v); return r; +#endif + } + +#if defined(__X86_64__) + __forceinline size_t bsf(size_t v) { +#if defined(__AVX2__) + return _tzcnt_u64(v); +#else + unsigned long r = 0; _BitScanForward64(&r,v); return r; +#endif + } +#endif + + __forceinline int bscf(int& v) + { + int i = bsf(v); + v &= v-1; + return i; + } + + __forceinline unsigned bscf(unsigned& v) + { + unsigned i = bsf(v); + v &= v-1; + return i; + } + +#if defined(__X86_64__) + __forceinline size_t bscf(size_t& v) + { + size_t i = bsf(v); + v &= v-1; + return i; + } +#endif + + __forceinline int bsr(int v) { +#if defined(__AVX2__) + return 31 - _lzcnt_u32(v); +#else + unsigned long r = 0; _BitScanReverse(&r,v); return r; +#endif + } + + __forceinline 
unsigned bsr(unsigned v) { +#if defined(__AVX2__) + return 31 - _lzcnt_u32(v); +#else + unsigned long r = 0; _BitScanReverse(&r,v); return r; +#endif + } + +#if defined(__X86_64__) + __forceinline size_t bsr(size_t v) { +#if defined(__AVX2__) + return 63 -_lzcnt_u64(v); +#else + unsigned long r = 0; _BitScanReverse64(&r, v); return r; +#endif + } +#endif + + __forceinline int lzcnt(const int x) + { +#if defined(__AVX2__) + return _lzcnt_u32(x); +#else + if (unlikely(x == 0)) return 32; + return 31 - bsr(x); +#endif + } + + __forceinline int btc(int v, int i) { + long r = v; _bittestandcomplement(&r,i); return r; + } + + __forceinline int bts(int v, int i) { + long r = v; _bittestandset(&r,i); return r; + } + + __forceinline int btr(int v, int i) { + long r = v; _bittestandreset(&r,i); return r; + } + +#if defined(__X86_64__) + + __forceinline size_t btc(size_t v, size_t i) { + size_t r = v; _bittestandcomplement64((__int64*)&r,i); return r; + } + + __forceinline size_t bts(size_t v, size_t i) { + __int64 r = v; _bittestandset64(&r,i); return r; + } + + __forceinline size_t btr(size_t v, size_t i) { + __int64 r = v; _bittestandreset64(&r,i); return r; + } + +#endif + + __forceinline int32_t atomic_cmpxchg(volatile int32_t* p, const int32_t c, const int32_t v) { + return _InterlockedCompareExchange((volatile long*)p,v,c); + } + +//////////////////////////////////////////////////////////////////////////////// +/// Unix Platform +//////////////////////////////////////////////////////////////////////////////// + +#else + +#if defined(__i386__) && defined(__PIC__) + + __forceinline void __cpuid(int out[4], int op) + { + asm volatile ("xchg{l}\t{%%}ebx, %1\n\t" + "cpuid\n\t" + "xchg{l}\t{%%}ebx, %1\n\t" + : "=a"(out[0]), "=r"(out[1]), "=c"(out[2]), "=d"(out[3]) + : "0"(op)); + } + + __forceinline void __cpuid_count(int out[4], int op1, int op2) + { + asm volatile ("xchg{l}\t{%%}ebx, %1\n\t" + "cpuid\n\t" + "xchg{l}\t{%%}ebx, %1\n\t" + : "=a" (out[0]), "=r" (out[1]), "=c" (out[2]), "=d" (out[3]) + : "0" (op1), "2" (op2)); + } + +#else + + __forceinline void __cpuid(int out[4], int op) { + asm volatile ("cpuid" : "=a"(out[0]), "=b"(out[1]), "=c"(out[2]), "=d"(out[3]) : "a"(op)); + } + + __forceinline void __cpuid_count(int out[4], int op1, int op2) { + asm volatile ("cpuid" : "=a"(out[0]), "=b"(out[1]), "=c"(out[2]), "=d"(out[3]) : "a"(op1), "c"(op2)); + } + +#endif + + __forceinline uint64_t read_tsc() { + uint32_t high,low; + asm volatile ("rdtsc" : "=d"(high), "=a"(low)); + return (((uint64_t)high) << 32) + (uint64_t)low; + } + + __forceinline int bsf(int v) { +#if defined(__AVX2__) + return _tzcnt_u32(v); +#else + int r = 0; asm ("bsf %1,%0" : "=r"(r) : "r"(v)); return r; +#endif + } + +#if defined(__X86_64__) + __forceinline unsigned bsf(unsigned v) + { +#if defined(__AVX2__) + return _tzcnt_u32(v); +#else + unsigned r = 0; asm ("bsf %1,%0" : "=r"(r) : "r"(v)); return r; +#endif + } +#endif + + __forceinline size_t bsf(size_t v) { +#if defined(__AVX2__) +#if defined(__X86_64__) + return _tzcnt_u64(v); +#else + return _tzcnt_u32(v); +#endif +#else + size_t r = 0; asm ("bsf %1,%0" : "=r"(r) : "r"(v)); return r; +#endif + } + + __forceinline int bscf(int& v) + { + int i = bsf(v); + v &= v-1; + return i; + } + +#if defined(__X86_64__) + __forceinline unsigned int bscf(unsigned int& v) + { + unsigned int i = bsf(v); + v &= v-1; + return i; + } +#endif + + __forceinline size_t bscf(size_t& v) + { + size_t i = bsf(v); + v &= v-1; + return i; + } + + __forceinline int bsr(int v) { +#if 
defined(__AVX2__) + return 31 - _lzcnt_u32(v); +#else + int r = 0; asm ("bsr %1,%0" : "=r"(r) : "r"(v)); return r; +#endif + } + +#if defined(__X86_64__) + __forceinline unsigned bsr(unsigned v) { +#if defined(__AVX2__) + return 31 - _lzcnt_u32(v); +#else + unsigned r = 0; asm ("bsr %1,%0" : "=r"(r) : "r"(v)); return r; +#endif + } +#endif + + __forceinline size_t bsr(size_t v) { +#if defined(__AVX2__) +#if defined(__X86_64__) + return 63 - _lzcnt_u64(v); +#else + return 31 - _lzcnt_u32(v); +#endif +#else + size_t r = 0; asm ("bsr %1,%0" : "=r"(r) : "r"(v)); return r; +#endif + } + + __forceinline int lzcnt(const int x) + { +#if defined(__AVX2__) + return _lzcnt_u32(x); +#else + if (unlikely(x == 0)) return 32; + return 31 - bsr(x); +#endif + } + + __forceinline size_t blsr(size_t v) { +#if defined(__AVX2__) +#if defined(__INTEL_COMPILER) + return _blsr_u64(v); +#else +#if defined(__X86_64__) + return __blsr_u64(v); +#else + return __blsr_u32(v); +#endif +#endif +#else + return v & (v-1); +#endif + } + + __forceinline int btc(int v, int i) { + int r = 0; asm ("btc %1,%0" : "=r"(r) : "r"(i), "0"(v) : "flags" ); return r; + } + + __forceinline int bts(int v, int i) { + int r = 0; asm ("bts %1,%0" : "=r"(r) : "r"(i), "0"(v) : "flags"); return r; + } + + __forceinline int btr(int v, int i) { + int r = 0; asm ("btr %1,%0" : "=r"(r) : "r"(i), "0"(v) : "flags"); return r; + } + + __forceinline size_t btc(size_t v, size_t i) { + size_t r = 0; asm ("btc %1,%0" : "=r"(r) : "r"(i), "0"(v) : "flags" ); return r; + } + + __forceinline size_t bts(size_t v, size_t i) { + size_t r = 0; asm ("bts %1,%0" : "=r"(r) : "r"(i), "0"(v) : "flags"); return r; + } + + __forceinline size_t btr(size_t v, size_t i) { + size_t r = 0; asm ("btr %1,%0" : "=r"(r) : "r"(i), "0"(v) : "flags"); return r; + } + + __forceinline int32_t atomic_cmpxchg(int32_t volatile* value, int32_t comparand, const int32_t input) { + return __sync_val_compare_and_swap(value, comparand, input); + } + +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// All Platforms +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__clang__) || defined(__GNUC__) +#if !defined(_mm_undefined_ps) + __forceinline __m128 _mm_undefined_ps() { return _mm_setzero_ps(); } +#endif +#if !defined(_mm_undefined_si128) + __forceinline __m128i _mm_undefined_si128() { return _mm_setzero_si128(); } +#endif +#if !defined(_mm256_undefined_ps) && defined(__AVX__) + __forceinline __m256 _mm256_undefined_ps() { return _mm256_setzero_ps(); } +#endif +#if !defined(_mm256_undefined_si256) && defined(__AVX__) + __forceinline __m256i _mm256_undefined_si256() { return _mm256_setzero_si256(); } +#endif +#if !defined(_mm512_undefined_ps) && defined(__AVX512F__) + __forceinline __m512 _mm512_undefined_ps() { return _mm512_setzero_ps(); } +#endif +#if !defined(_mm512_undefined_epi32) && defined(__AVX512F__) + __forceinline __m512i _mm512_undefined_epi32() { return _mm512_setzero_si512(); } +#endif +#endif + +#if defined(__SSE4_2__) + + __forceinline int popcnt(int in) { + return _mm_popcnt_u32(in); + } + + __forceinline unsigned popcnt(unsigned in) { + return _mm_popcnt_u32(in); + } + +#if defined(__X86_64__) + __forceinline size_t popcnt(size_t in) { + return _mm_popcnt_u64(in); + } +#endif + +#endif + + __forceinline uint64_t rdtsc() + { + int dummy[4]; + __cpuid(dummy,0); + uint64_t clock = read_tsc(); + __cpuid(dummy,0); + return clock; + } + + __forceinline void pause_cpu(const size_t N = 8) + { 
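    // Spin-wait back-off: the loop below runs a short burst of pause/back-off
    // iterations (N of them) so busy-waiting callers such as the spin barriers
    // above ease pressure on the core while polling shared counters.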
+// -- GODOT start -- + for (size_t i=0; i + +namespace embree +{ + /* opens a shared library */ + lib_t openLibrary(const std::string& file) + { + std::string fullName = file+".dll"; + FileName executable = getExecutableFileName(); + HANDLE handle = LoadLibrary((executable.path() + fullName).c_str()); + return lib_t(handle); + } + + /* returns address of a symbol from the library */ + void* getSymbol(lib_t lib, const std::string& sym) { + // -- GODOT start -- + return (void*) GetProcAddress(HMODULE(lib),sym.c_str()); + // -- GODOT end -- + } + + /* closes the shared library */ + void closeLibrary(lib_t lib) { + FreeLibrary(HMODULE(lib)); + } +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Unix Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__UNIX__) + +#include + +namespace embree +{ + /* opens a shared library */ + lib_t openLibrary(const std::string& file) + { +#if defined(__MACOSX__) + std::string fullName = "lib"+file+".dylib"; +#else + std::string fullName = "lib"+file+".so"; +#endif + void* lib = dlopen(fullName.c_str(), RTLD_NOW); + if (lib) return lib_t(lib); + FileName executable = getExecutableFileName(); + lib = dlopen((executable.path() + fullName).c_str(),RTLD_NOW); + if (lib == nullptr) { + const char* error = dlerror(); + if (error) { + THROW_RUNTIME_ERROR(error); + } else { + THROW_RUNTIME_ERROR("could not load library "+executable.str()); + } + } + return lib_t(lib); + } + + /* returns address of a symbol from the library */ + void* getSymbol(lib_t lib, const std::string& sym) { + return dlsym(lib,sym.c_str()); + } + + /* closes the shared library */ + void closeLibrary(lib_t lib) { + dlclose(lib); + } +} +#endif diff --git a/thirdparty/embree/common/sys/library.h b/thirdparty/embree/common/sys/library.h new file mode 100644 index 000000000000..c2164e9fbe02 --- /dev/null +++ b/thirdparty/embree/common/sys/library.h @@ -0,0 +1,21 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" + +namespace embree +{ + /*! type for shared library */ + typedef struct opaque_lib_t* lib_t; + + /*! loads a shared library */ + lib_t openLibrary(const std::string& file); + + /*! returns address of a symbol from the library */ + void* getSymbol(lib_t lib, const std::string& sym); + + /*! unloads a shared library */ + void closeLibrary(lib_t lib); +} diff --git a/thirdparty/embree/common/sys/mutex.cpp b/thirdparty/embree/common/sys/mutex.cpp new file mode 100644 index 000000000000..57ef360981ac --- /dev/null +++ b/thirdparty/embree/common/sys/mutex.cpp @@ -0,0 +1,57 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "mutex.h" +#include "regression.h" + +#if defined(__WIN32__) && !defined(PTHREADS_WIN32) + +#define WIN32_LEAN_AND_MEAN +#include + +namespace embree +{ + MutexSys::MutexSys() { mutex = new CRITICAL_SECTION; InitializeCriticalSection((CRITICAL_SECTION*)mutex); } + MutexSys::~MutexSys() { DeleteCriticalSection((CRITICAL_SECTION*)mutex); delete (CRITICAL_SECTION*)mutex; } + void MutexSys::lock() { EnterCriticalSection((CRITICAL_SECTION*)mutex); } + bool MutexSys::try_lock() { return TryEnterCriticalSection((CRITICAL_SECTION*)mutex) != 0; } + void MutexSys::unlock() { LeaveCriticalSection((CRITICAL_SECTION*)mutex); } +} +#endif + +#if defined(__UNIX__) || defined(PTHREADS_WIN32) +#include +namespace embree +{ + /*! 
system mutex using pthreads */ + MutexSys::MutexSys() + { + mutex = new pthread_mutex_t; + if (pthread_mutex_init((pthread_mutex_t*)mutex, nullptr) != 0) + THROW_RUNTIME_ERROR("pthread_mutex_init failed"); + } + + MutexSys::~MutexSys() + { + MAYBE_UNUSED bool ok = pthread_mutex_destroy((pthread_mutex_t*)mutex) == 0; + assert(ok); + delete (pthread_mutex_t*)mutex; + } + + void MutexSys::lock() + { + if (pthread_mutex_lock((pthread_mutex_t*)mutex) != 0) + THROW_RUNTIME_ERROR("pthread_mutex_lock failed"); + } + + bool MutexSys::try_lock() { + return pthread_mutex_trylock((pthread_mutex_t*)mutex) == 0; + } + + void MutexSys::unlock() + { + if (pthread_mutex_unlock((pthread_mutex_t*)mutex) != 0) + THROW_RUNTIME_ERROR("pthread_mutex_unlock failed"); + } +}; +#endif diff --git a/thirdparty/embree/common/sys/mutex.h b/thirdparty/embree/common/sys/mutex.h new file mode 100644 index 000000000000..f0f55340a96a --- /dev/null +++ b/thirdparty/embree/common/sys/mutex.h @@ -0,0 +1,114 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" +#include "intrinsics.h" +#include "atomic.h" + +namespace embree +{ + /*! system mutex */ + class MutexSys { + friend struct ConditionImplementation; + public: + MutexSys(); + ~MutexSys(); + + private: + MutexSys (const MutexSys& other) DELETED; // do not implement + MutexSys& operator= (const MutexSys& other) DELETED; // do not implement + + public: + void lock(); + bool try_lock(); + void unlock(); + + protected: + void* mutex; + }; + + /*! spinning mutex */ + class SpinLock + { + public: + + SpinLock () + : flag(false) {} + + __forceinline bool isLocked() { + return flag.load(); + } + + __forceinline void lock() + { + while (true) + { + while (flag.load()) + { +// -- GODOT start -- +#if !(defined (__WIN32__) && defined (__MINGW32__)) +// -- GODOT end -- + _mm_pause(); + _mm_pause(); +// -- GODOT start -- +#else + usleep(1); +#endif +// -- GODOT end -- + } + + bool expected = false; + if (flag.compare_exchange_strong(expected,true,std::memory_order_acquire)) + break; + } + } + + __forceinline bool try_lock() + { + bool expected = false; + if (flag.load() != expected) { + return false; + } + return flag.compare_exchange_strong(expected,true,std::memory_order_acquire); + } + + __forceinline void unlock() { + flag.store(false,std::memory_order_release); + } + + __forceinline void wait_until_unlocked() + { + while(flag.load()) + { +// -- GODOT start -- +#if !(defined (__WIN32__) && defined(__MINGW32__)) +// -- GODOT end -- + _mm_pause(); + _mm_pause(); +// -- GODOT start -- +#else + usleep(1); +#endif +// -- GODOT end -- + } + } + + public: + atomic flag; + }; + + /*! 
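scope-bound (RAII)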
safe mutex lock and unlock helper */ + template class Lock { + public: + Lock (Mutex& mutex) : mutex(mutex), locked(true) { mutex.lock(); } + Lock (Mutex& mutex, bool locked) : mutex(mutex), locked(locked) {} + ~Lock() { if (locked) mutex.unlock(); } + __forceinline void lock() { assert(!locked); locked = true; mutex.lock(); } + __forceinline bool isLocked() const { return locked; } + protected: + Mutex& mutex; + bool locked; + }; +} diff --git a/thirdparty/embree/common/sys/platform.h b/thirdparty/embree/common/sys/platform.h new file mode 100644 index 000000000000..08617452f48a --- /dev/null +++ b/thirdparty/embree/common/sys/platform.h @@ -0,0 +1,374 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#define _CRT_SECURE_NO_WARNINGS + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +/// detect platform +//////////////////////////////////////////////////////////////////////////////// + +/* detect 32 or 64 platform */ +#if defined(__x86_64__) || defined(__ia64__) || defined(_M_X64) +#define __X86_64__ +#endif + +/* detect Linux platform */ +#if defined(linux) || defined(__linux__) || defined(__LINUX__) +# if !defined(__LINUX__) +# define __LINUX__ +# endif +# if !defined(__UNIX__) +# define __UNIX__ +# endif +#endif + +/* detect FreeBSD platform */ +#if defined(__FreeBSD__) || defined(__FREEBSD__) +# if !defined(__FREEBSD__) +# define __FREEBSD__ +# endif +# if !defined(__UNIX__) +# define __UNIX__ +# endif +#endif + +/* detect Windows 95/98/NT/2000/XP/Vista/7/8/10 platform */ +#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__)) && !defined(__CYGWIN__) +# if !defined(__WIN32__) +# define __WIN32__ +# endif +#endif + +/* detect Cygwin platform */ +#if defined(__CYGWIN__) +# if !defined(__UNIX__) +# define __UNIX__ +# endif +#endif + +/* detect MAC OS X platform */ +#if defined(__APPLE__) || defined(MACOSX) || defined(__MACOSX__) +# if !defined(__MACOSX__) +# define __MACOSX__ +# endif +# if !defined(__UNIX__) +# define __UNIX__ +# endif +#endif + +/* try to detect other Unix systems */ +#if defined(__unix__) || defined (unix) || defined(__unix) || defined(_unix) +# if !defined(__UNIX__) +# define __UNIX__ +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Macros +//////////////////////////////////////////////////////////////////////////////// + +#ifdef __WIN32__ +#define dll_export __declspec(dllexport) +#define dll_import __declspec(dllimport) +#else +#define dll_export __attribute__ ((visibility ("default"))) +#define dll_import +#endif + +#ifdef __WIN32__ +#if !defined(__noinline) +#define __noinline __declspec(noinline) +#endif +//#define __forceinline __forceinline +//#define __restrict __restrict +#if defined(__INTEL_COMPILER) +#define __restrict__ __restrict +#else +#define __restrict__ //__restrict // causes issues with MSVC +#endif +#if !defined(__thread) +#define __thread __declspec(thread) +#endif +#if !defined(__aligned) +#define __aligned(...) 
__declspec(align(__VA_ARGS__)) +#endif +//#define __FUNCTION__ __FUNCTION__ +#define debugbreak() __debugbreak() + +#else +#if !defined(__noinline) +#define __noinline __attribute__((noinline)) +#endif +#if !defined(__forceinline) +#define __forceinline inline __attribute__((always_inline)) +#endif +//#define __restrict __restrict +//#define __thread __thread +#if !defined(__aligned) +#define __aligned(...) __attribute__((aligned(__VA_ARGS__))) +#endif +#if !defined(__FUNCTION__) +#define __FUNCTION__ __PRETTY_FUNCTION__ +#endif +#define debugbreak() asm ("int $3") +#endif + +#if defined(__clang__) || defined(__GNUC__) + #define MAYBE_UNUSED __attribute__((unused)) +#else + #define MAYBE_UNUSED +#endif + +#if defined(_MSC_VER) && (_MSC_VER < 1900) // before VS2015 deleted functions are not supported properly + #define DELETED +#else + #define DELETED = delete +#endif + +// -- GODOT start -- +#if !defined(likely) +// -- GODOT end -- +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +#define likely(expr) (expr) +#define unlikely(expr) (expr) +#else +#define likely(expr) __builtin_expect((bool)(expr),true ) +#define unlikely(expr) __builtin_expect((bool)(expr),false) +#endif +// -- GODOT start -- +#endif +// -- GODOT end -- + +//////////////////////////////////////////////////////////////////////////////// +/// Error handling and debugging +//////////////////////////////////////////////////////////////////////////////// + +/* debug printing macros */ +#define STRING(x) #x +#define TOSTRING(x) STRING(x) +#define PING embree_cout << __FILE__ << " (" << __LINE__ << "): " << __FUNCTION__ << embree_endl +#define PRINT(x) embree_cout << STRING(x) << " = " << (x) << embree_endl +#define PRINT2(x,y) embree_cout << STRING(x) << " = " << (x) << ", " << STRING(y) << " = " << (y) << embree_endl +#define PRINT3(x,y,z) embree_cout << STRING(x) << " = " << (x) << ", " << STRING(y) << " = " << (y) << ", " << STRING(z) << " = " << (z) << embree_endl +#define PRINT4(x,y,z,w) embree_cout << STRING(x) << " = " << (x) << ", " << STRING(y) << " = " << (y) << ", " << STRING(z) << " = " << (z) << ", " << STRING(w) << " = " << (w) << embree_endl + +#if defined(DEBUG) // only report file and line in debug mode + #define THROW_RUNTIME_ERROR(str) \ + throw std::runtime_error(std::string(__FILE__) + " (" + toString(__LINE__) + "): " + std::string(str)); +#else + #define THROW_RUNTIME_ERROR(str) \ + throw std::runtime_error(str); +#endif + +#define FATAL(x) THROW_RUNTIME_ERROR(x) +#define WARNING(x) { std::cerr << "Warning: " << x << embree_endl << std::flush; } + +#define NOT_IMPLEMENTED FATAL(std::string(__FUNCTION__) + " not implemented") + +//////////////////////////////////////////////////////////////////////////////// +/// Basic types +//////////////////////////////////////////////////////////////////////////////// + +/* default floating-point type */ +namespace embree { + typedef float real; +} + +/* windows does not have ssize_t */ +#if defined(__WIN32__) +#if defined(__X86_64__) +typedef int64_t ssize_t; +#else +typedef int32_t ssize_t; +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Basic utility functions +//////////////////////////////////////////////////////////////////////////////// + +__forceinline std::string toString(long long value) { + return std::to_string(value); +} + +//////////////////////////////////////////////////////////////////////////////// +/// Disable some compiler warnings 
+//////////////////////////////////////////////////////////////////////////////// + +#if defined(__INTEL_COMPILER) +//#pragma warning(disable:265 ) // floating-point operation result is out of range +//#pragma warning(disable:383 ) // value copied to temporary, reference to temporary used +//#pragma warning(disable:869 ) // parameter was never referenced +//#pragma warning(disable:981 ) // operands are evaluated in unspecified order +//#pragma warning(disable:1418) // external function definition with no prior declaration +//#pragma warning(disable:1419) // external declaration in primary source file +//#pragma warning(disable:1572) // floating-point equality and inequality comparisons are unreliable +//#pragma warning(disable:94 ) // the size of an array must be greater than zero +//#pragma warning(disable:1599) // declaration hides parameter +//#pragma warning(disable:424 ) // extra ";" ignored +#pragma warning(disable:2196) // routine is both "inline" and "noinline" +//#pragma warning(disable:177 ) // label was declared but never referenced +//#pragma warning(disable:114 ) // function was referenced but not defined +//#pragma warning(disable:819 ) // template nesting depth does not match the previous declaration of function +#pragma warning(disable:15335) // was not vectorized: vectorization possible but seems inefficient +#endif + +#if defined(_MSC_VER) +//#pragma warning(disable:4200) // nonstandard extension used : zero-sized array in struct/union +#pragma warning(disable:4800) // forcing value to bool 'true' or 'false' (performance warning) +//#pragma warning(disable:4267) // '=' : conversion from 'size_t' to 'unsigned long', possible loss of data +#pragma warning(disable:4244) // 'argument' : conversion from 'ssize_t' to 'unsigned int', possible loss of data +//#pragma warning(disable:4355) // 'this' : used in base member initializer list +//#pragma warning(disable:391 ) // '<=' : signed / unsigned mismatch +//#pragma warning(disable:4018) // '<' : signed / unsigned mismatch +//#pragma warning(disable:4305) // 'initializing' : truncation from 'double' to 'float' +//#pragma warning(disable:4068) // unknown pragma +//#pragma warning(disable:4146) // unary minus operator applied to unsigned type, result still unsigned +//#pragma warning(disable:4838) // conversion from 'unsigned int' to 'const int' requires a narrowing conversion) +//#pragma warning(disable:4227) // anachronism used : qualifiers on reference are ignored +#pragma warning(disable:4503) // decorated name length exceeded, name was truncated +#pragma warning(disable:4180) // qualifier applied to function type has no meaning; ignored +#pragma warning(disable:4258) // definition from the for loop is ignored; the definition from the enclosing scope is used + +# if _MSC_VER < 1910 // prior to Visual studio 2017 (V141) +# pragma warning(disable:4101) // warning C4101: 'x': unreferenced local variable // a compiler bug issues wrong warnings +# pragma warning(disable:4789) // buffer '' of size 8 bytes will be overrun; 32 bytes will be written starting at offset 0 +# endif + +#endif + +#if defined(__clang__) && !defined(__INTEL_COMPILER) +//#pragma clang diagnostic ignored "-Wunknown-pragmas" +//#pragma clang diagnostic ignored "-Wunused-variable" +//#pragma clang diagnostic ignored "-Wreorder" +//#pragma clang diagnostic ignored "-Wmicrosoft" +//#pragma clang diagnostic ignored "-Wunused-private-field" +//#pragma clang diagnostic ignored "-Wunused-local-typedef" +//#pragma clang diagnostic ignored "-Wunused-function" +//#pragma 
clang diagnostic ignored "-Wnarrowing" +//#pragma clang diagnostic ignored "-Wc++11-narrowing" +//#pragma clang diagnostic ignored "-Wdeprecated-register" +//#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + +#if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__) +#pragma GCC diagnostic ignored "-Wpragmas" +//#pragma GCC diagnostic ignored "-Wnarrowing" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +//#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +//#pragma GCC diagnostic ignored "-Warray-bounds" +#pragma GCC diagnostic ignored "-Wattributes" +#pragma GCC diagnostic ignored "-Wmisleading-indentation" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wparentheses" +#endif + +#if defined(__clang__) && defined(__WIN32__) +#pragma clang diagnostic ignored "-Wunused-parameter" +#pragma clang diagnostic ignored "-Wmicrosoft-cast" +#pragma clang diagnostic ignored "-Wmicrosoft-enum-value" +#pragma clang diagnostic ignored "-Wmicrosoft-include" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunknown-pragmas" +#endif + +/* disabling deprecated warning, please use only where use of deprecated Embree API functions is desired */ +#if defined(__WIN32__) && defined(__INTEL_COMPILER) +#define DISABLE_DEPRECATED_WARNING __pragma(warning (disable: 1478)) // warning: function was declared deprecated +#define ENABLE_DEPRECATED_WARNING __pragma(warning (enable: 1478)) // warning: function was declared deprecated +#elif defined(__INTEL_COMPILER) +#define DISABLE_DEPRECATED_WARNING _Pragma("warning (disable: 1478)") // warning: function was declared deprecated +#define ENABLE_DEPRECATED_WARNING _Pragma("warning (enable : 1478)") // warning: function was declared deprecated +#elif defined(__clang__) +#define DISABLE_DEPRECATED_WARNING _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") // warning: xxx is deprecated +#define ENABLE_DEPRECATED_WARNING _Pragma("clang diagnostic warning \"-Wdeprecated-declarations\"") // warning: xxx is deprecated +#elif defined(__GNUC__) +#define DISABLE_DEPRECATED_WARNING _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") // warning: xxx is deprecated +#define ENABLE_DEPRECATED_WARNING _Pragma("GCC diagnostic warning \"-Wdeprecated-declarations\"") // warning: xxx is deprecated +#elif defined(_MSC_VER) +#define DISABLE_DEPRECATED_WARNING __pragma(warning (disable: 4996)) // warning: function was declared deprecated +#define ENABLE_DEPRECATED_WARNING __pragma(warning (enable : 4996)) // warning: function was declared deprecated +#endif + +/* embree output stream */ +#define embree_ostream std::ostream& +#define embree_cout std::cout +#define embree_cout_uniform std::cout +#define embree_endl std::endl + +//////////////////////////////////////////////////////////////////////////////// +/// Some macros for static profiling +//////////////////////////////////////////////////////////////////////////////// + +#if defined (__GNUC__) +#define IACA_SSC_MARK( MARK_ID ) \ +__asm__ __volatile__ ( \ + "\n\t movl $"#MARK_ID", %%ebx" \ + "\n\t .byte 0x64, 0x67, 0x90" \ + : : : "memory" ); + +#define IACA_UD_BYTES __asm__ __volatile__ ("\n\t .byte 0x0F, 0x0B"); + +#else +#define IACA_UD_BYTES {__asm _emit 0x0F \ + __asm _emit 0x0B} + +#define IACA_SSC_MARK(x) {__asm mov ebx, x\ + __asm _emit 0x64 \ + __asm _emit 0x67 \ + __asm _emit 0x90 } + +#define IACA_VC64_START __writegsbyte(111, 111); +#define IACA_VC64_END __writegsbyte(222, 
222); + +#endif + +#define IACA_START {IACA_UD_BYTES \ + IACA_SSC_MARK(111)} +#define IACA_END {IACA_SSC_MARK(222) \ + IACA_UD_BYTES} + +namespace embree +{ + template + struct OnScopeExitHelper + { + OnScopeExitHelper (const Closure f) : active(true), f(f) {} + ~OnScopeExitHelper() { if (active) f(); } + void deactivate() { active = false; } + bool active; + const Closure f; + }; + + template + OnScopeExitHelper OnScopeExit(const Closure f) { + return OnScopeExitHelper(f); + } + +#define STRING_JOIN2(arg1, arg2) DO_STRING_JOIN2(arg1, arg2) +#define DO_STRING_JOIN2(arg1, arg2) arg1 ## arg2 +#define ON_SCOPE_EXIT(code) \ + auto STRING_JOIN2(on_scope_exit_, __LINE__) = OnScopeExit([&](){code;}) + + template + std::unique_ptr make_unique(Ty* ptr) { + return std::unique_ptr(ptr); + } + +} diff --git a/thirdparty/embree/common/sys/ref.h b/thirdparty/embree/common/sys/ref.h new file mode 100644 index 000000000000..24648e623444 --- /dev/null +++ b/thirdparty/embree/common/sys/ref.h @@ -0,0 +1,122 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "atomic.h" + +namespace embree +{ + struct NullTy { + }; + + extern MAYBE_UNUSED NullTy null; + + class RefCount + { + public: + RefCount(int val = 0) : refCounter(val) {} + virtual ~RefCount() {}; + + virtual RefCount* refInc() { refCounter.fetch_add(1); return this; } + virtual void refDec() { if (refCounter.fetch_add(-1) == 1) delete this; } + private: + std::atomic refCounter; + }; + + //////////////////////////////////////////////////////////////////////////////// + /// Reference to single object + //////////////////////////////////////////////////////////////////////////////// + + template + class Ref + { + public: + Type* ptr; + + //////////////////////////////////////////////////////////////////////////////// + /// Constructors, Assignment & Cast Operators + //////////////////////////////////////////////////////////////////////////////// + + __forceinline Ref() : ptr(nullptr) {} + __forceinline Ref(NullTy) : ptr(nullptr) {} + __forceinline Ref(const Ref& input) : ptr(input.ptr) { if (ptr) ptr->refInc(); } + __forceinline Ref(Ref&& input) : ptr(input.ptr) { input.ptr = nullptr; } + + __forceinline Ref(Type* const input) : ptr(input) + { + if (ptr) + ptr->refInc(); + } + + __forceinline ~Ref() + { + if (ptr) + ptr->refDec(); + } + + __forceinline Ref& operator =(const Ref& input) + { + if (input.ptr) + input.ptr->refInc(); + if (ptr) + ptr->refDec(); + ptr = input.ptr; + return *this; + } + + __forceinline Ref& operator =(Ref&& input) + { + if (ptr) + ptr->refDec(); + ptr = input.ptr; + input.ptr = nullptr; + return *this; + } + + __forceinline Ref& operator =(Type* const input) + { + if (input) + input->refInc(); + if (ptr) + ptr->refDec(); + ptr = input; + return *this; + } + + __forceinline Ref& operator =(NullTy) + { + if (ptr) + ptr->refDec(); + ptr = nullptr; + return *this; + } + + __forceinline operator bool() const { return ptr != nullptr; } + + __forceinline const Type& operator *() const { return *ptr; } + __forceinline Type& operator *() { return *ptr; } + __forceinline const Type* operator ->() const { return ptr; } + __forceinline Type* operator ->() { return ptr; } + + template + __forceinline Ref cast() { return Ref(static_cast(ptr)); } + template + __forceinline const Ref cast() const { return Ref(static_cast(ptr)); } + + template + __forceinline Ref dynamicCast() { return Ref(dynamic_cast(ptr)); } + template + __forceinline const Ref dynamicCast() const { return 
Ref(dynamic_cast(ptr)); } + }; + + template __forceinline bool operator < (const Ref& a, const Ref& b) { return a.ptr < b.ptr; } + + template __forceinline bool operator ==(const Ref& a, NullTy ) { return a.ptr == nullptr; } + template __forceinline bool operator ==(NullTy , const Ref& b) { return nullptr == b.ptr; } + template __forceinline bool operator ==(const Ref& a, const Ref& b) { return a.ptr == b.ptr; } + + template __forceinline bool operator !=(const Ref& a, NullTy ) { return a.ptr != nullptr; } + template __forceinline bool operator !=(NullTy , const Ref& b) { return nullptr != b.ptr; } + template __forceinline bool operator !=(const Ref& a, const Ref& b) { return a.ptr != b.ptr; } +} diff --git a/thirdparty/embree/common/sys/regression.cpp b/thirdparty/embree/common/sys/regression.cpp new file mode 100644 index 000000000000..d95ff8dfe067 --- /dev/null +++ b/thirdparty/embree/common/sys/regression.cpp @@ -0,0 +1,30 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "regression.h" + +namespace embree +{ + /* registerRegressionTest is invoked from static initializers, thus + * we cannot have the regression_tests variable as global static + * variable due to issues with static variable initialization + * order. */ + std::vector& get_regression_tests() + { + static std::vector regression_tests; + return regression_tests; + } + + void registerRegressionTest(RegressionTest* test) + { + get_regression_tests().push_back(test); + } + + RegressionTest* getRegressionTest(size_t index) + { + if (index >= get_regression_tests().size()) + return nullptr; + + return get_regression_tests()[index]; + } +} diff --git a/thirdparty/embree/common/sys/regression.h b/thirdparty/embree/common/sys/regression.h new file mode 100644 index 000000000000..632f8d92cf57 --- /dev/null +++ b/thirdparty/embree/common/sys/regression.h @@ -0,0 +1,25 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" + +#include + +namespace embree +{ + /*! virtual interface for all regression tests */ + struct RegressionTest + { + RegressionTest (std::string name) : name(name) {} + virtual bool run() = 0; + std::string name; + }; + + /*! registers a regression test */ + void registerRegressionTest(RegressionTest* test); + + /*! 
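returns the registered regression test at the given index (nullptr when the index is out of range), so a test driver can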
run all regression tests */ + RegressionTest* getRegressionTest(size_t index); +} diff --git a/thirdparty/embree/common/sys/string.cpp b/thirdparty/embree/common/sys/string.cpp new file mode 100644 index 000000000000..931244383e05 --- /dev/null +++ b/thirdparty/embree/common/sys/string.cpp @@ -0,0 +1,42 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "string.h" + +#include +#include + +namespace embree +{ + char to_lower(char c) { return char(tolower(int(c))); } + char to_upper(char c) { return char(toupper(int(c))); } + std::string toLowerCase(const std::string& s) { std::string dst(s); std::transform(dst.begin(), dst.end(), dst.begin(), to_lower); return dst; } + std::string toUpperCase(const std::string& s) { std::string dst(s); std::transform(dst.begin(), dst.end(), dst.begin(), to_upper); return dst; } + + Vec2f string_to_Vec2f ( std::string str ) + { + size_t next = 0; + const float x = std::stof(str,&next); str = str.substr(next+1); + const float y = std::stof(str,&next); + return Vec2f(x,y); + } + + Vec3f string_to_Vec3f ( std::string str ) + { + size_t next = 0; + const float x = std::stof(str,&next); str = str.substr(next+1); + const float y = std::stof(str,&next); str = str.substr(next+1); + const float z = std::stof(str,&next); + return Vec3f(x,y,z); + } + + Vec4f string_to_Vec4f ( std::string str ) + { + size_t next = 0; + const float x = std::stof(str,&next); str = str.substr(next+1); + const float y = std::stof(str,&next); str = str.substr(next+1); + const float z = std::stof(str,&next); str = str.substr(next+1); + const float w = std::stof(str,&next); + return Vec4f(x,y,z,w); + } +} diff --git a/thirdparty/embree/common/sys/string.h b/thirdparty/embree/common/sys/string.h new file mode 100644 index 000000000000..2e9b0f88c325 --- /dev/null +++ b/thirdparty/embree/common/sys/string.h @@ -0,0 +1,37 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" +#include "../math/vec2.h" +#include "../math/vec3.h" +#include "../math/vec4.h" + +namespace embree +{ + class IOStreamStateRestorer + { + public: + IOStreamStateRestorer(std::ostream& iostream) + : iostream(iostream), flags(iostream.flags()), precision(iostream.precision()) { + } + + ~IOStreamStateRestorer() { + iostream.flags(flags); + iostream.precision(precision); + } + + private: + std::ostream& iostream; + std::ios::fmtflags flags; + std::streamsize precision; + }; + + std::string toLowerCase(const std::string& s); + std::string toUpperCase(const std::string& s); + + Vec2f string_to_Vec2f ( std::string str ); + Vec3f string_to_Vec3f ( std::string str ); + Vec4f string_to_Vec4f ( std::string str ); +} diff --git a/thirdparty/embree/common/sys/sysinfo.cpp b/thirdparty/embree/common/sys/sysinfo.cpp new file mode 100644 index 000000000000..74438260dbe4 --- /dev/null +++ b/thirdparty/embree/common/sys/sysinfo.cpp @@ -0,0 +1,629 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "sysinfo.h" +#include "intrinsics.h" +#include "string.h" +#include "ref.h" +#if defined(__FREEBSD__) +#include +#include +typedef cpuset_t cpu_set_t; +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// All Platforms +//////////////////////////////////////////////////////////////////////////////// + +namespace embree +{ + NullTy null; + + std::string getPlatformName() + { +#if defined(__LINUX__) && !defined(__X86_64__) + return "Linux (32bit)"; 
+#elif defined(__LINUX__) && defined(__X86_64__) + return "Linux (64bit)"; +#elif defined(__FREEBSD__) && !defined(__X86_64__) + return "FreeBSD (32bit)"; +#elif defined(__FREEBSD__) && defined(__X86_64__) + return "FreeBSD (64bit)"; +#elif defined(__CYGWIN__) && !defined(__X86_64__) + return "Cygwin (32bit)"; +#elif defined(__CYGWIN__) && defined(__X86_64__) + return "Cygwin (64bit)"; +#elif defined(__WIN32__) && !defined(__X86_64__) + return "Windows (32bit)"; +#elif defined(__WIN32__) && defined(__X86_64__) + return "Windows (64bit)"; +#elif defined(__MACOSX__) && !defined(__X86_64__) + return "Mac OS X (32bit)"; +#elif defined(__MACOSX__) && defined(__X86_64__) + return "Mac OS X (64bit)"; +#elif defined(__UNIX__) && !defined(__X86_64__) + return "Unix (32bit)"; +#elif defined(__UNIX__) && defined(__X86_64__) + return "Unix (64bit)"; +#else + return "Unknown"; +#endif + } + + std::string getCompilerName() + { +#if defined(__INTEL_COMPILER) + int icc_mayor = __INTEL_COMPILER / 100 % 100; + int icc_minor = __INTEL_COMPILER % 100; + std::string version = "Intel Compiler "; + version += toString(icc_mayor); + version += "." + toString(icc_minor); +#if defined(__INTEL_COMPILER_UPDATE) + version += "." + toString(__INTEL_COMPILER_UPDATE); +#endif + return version; +#elif defined(__clang__) + return "CLANG " __clang_version__; +#elif defined (__GNUC__) + return "GCC " __VERSION__; +#elif defined(_MSC_VER) + std::string version = toString(_MSC_FULL_VER); + version.insert(4,"."); + version.insert(9,"."); + version.insert(2,"."); + return "Visual C++ Compiler " + version; +#else + return "Unknown Compiler"; +#endif + } + + std::string getCPUVendor() + { + int cpuinfo[4]; + __cpuid (cpuinfo, 0); + int name[4]; + name[0] = cpuinfo[1]; + name[1] = cpuinfo[3]; + name[2] = cpuinfo[2]; + name[3] = 0; + return (char*)name; + } + + CPU getCPUModel() + { + if (getCPUVendor() != "GenuineIntel") + return CPU::UNKNOWN; + + int out[4]; + __cpuid(out, 0); + if (out[0] < 1) return CPU::UNKNOWN; + __cpuid(out, 1); + + /* please see CPUID documentation for these formulas */ + uint32_t family_ID = (out[0] >> 8) & 0x0F; + uint32_t extended_family_ID = (out[0] >> 20) & 0xFF; + + uint32_t model_ID = (out[0] >> 4) & 0x0F; + uint32_t extended_model_ID = (out[0] >> 16) & 0x0F; + + uint32_t DisplayFamily = family_ID; + if (family_ID == 0x0F) + DisplayFamily += extended_family_ID; + + uint32_t DisplayModel = model_ID; + if (family_ID == 0x06 || family_ID == 0x0F) + DisplayModel += extended_model_ID << 4; + + uint32_t DisplayFamily_DisplayModel = (DisplayFamily << 8) + (DisplayModel << 0); + + // Data from Intel® 64 and IA-32 Architectures, Volume 4, Chapter 2, Table 2-1 (CPUID Signature Values of DisplayFamily_DisplayModel) + if (DisplayFamily_DisplayModel == 0x067D) return CPU::CORE_ICE_LAKE; + if (DisplayFamily_DisplayModel == 0x067E) return CPU::CORE_ICE_LAKE; + if (DisplayFamily_DisplayModel == 0x068C) return CPU::CORE_TIGER_LAKE; + if (DisplayFamily_DisplayModel == 0x06A5) return CPU::CORE_COMET_LAKE; + if (DisplayFamily_DisplayModel == 0x06A6) return CPU::CORE_COMET_LAKE; + if (DisplayFamily_DisplayModel == 0x0666) return CPU::CORE_CANNON_LAKE; + if (DisplayFamily_DisplayModel == 0x068E) return CPU::CORE_KABY_LAKE; + if (DisplayFamily_DisplayModel == 0x069E) return CPU::CORE_KABY_LAKE; + if (DisplayFamily_DisplayModel == 0x066A) return CPU::XEON_ICE_LAKE; + if (DisplayFamily_DisplayModel == 0x066C) return CPU::XEON_ICE_LAKE; + if (DisplayFamily_DisplayModel == 0x0655) return CPU::XEON_SKY_LAKE; + if 
(DisplayFamily_DisplayModel == 0x064E) return CPU::CORE_SKY_LAKE; + if (DisplayFamily_DisplayModel == 0x065E) return CPU::CORE_SKY_LAKE; + if (DisplayFamily_DisplayModel == 0x0656) return CPU::XEON_BROADWELL; + if (DisplayFamily_DisplayModel == 0x064F) return CPU::XEON_BROADWELL; + if (DisplayFamily_DisplayModel == 0x0647) return CPU::CORE_BROADWELL; + if (DisplayFamily_DisplayModel == 0x063D) return CPU::CORE_BROADWELL; + if (DisplayFamily_DisplayModel == 0x063F) return CPU::XEON_HASWELL; + if (DisplayFamily_DisplayModel == 0x063C) return CPU::CORE_HASWELL; + if (DisplayFamily_DisplayModel == 0x0645) return CPU::CORE_HASWELL; + if (DisplayFamily_DisplayModel == 0x0646) return CPU::CORE_HASWELL; + if (DisplayFamily_DisplayModel == 0x063E) return CPU::XEON_IVY_BRIDGE; + if (DisplayFamily_DisplayModel == 0x063A) return CPU::CORE_IVY_BRIDGE; + if (DisplayFamily_DisplayModel == 0x062D) return CPU::SANDY_BRIDGE; + if (DisplayFamily_DisplayModel == 0x062F) return CPU::SANDY_BRIDGE; + if (DisplayFamily_DisplayModel == 0x062A) return CPU::SANDY_BRIDGE; + if (DisplayFamily_DisplayModel == 0x062E) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x0625) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x062C) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x061E) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x061F) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x061A) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x061D) return CPU::NEHALEM; + if (DisplayFamily_DisplayModel == 0x0617) return CPU::CORE2; + if (DisplayFamily_DisplayModel == 0x060F) return CPU::CORE2; + if (DisplayFamily_DisplayModel == 0x060E) return CPU::CORE1; + + if (DisplayFamily_DisplayModel == 0x0685) return CPU::XEON_PHI_KNIGHTS_MILL; + if (DisplayFamily_DisplayModel == 0x0657) return CPU::XEON_PHI_KNIGHTS_LANDING; + + return CPU::UNKNOWN; + } + + std::string stringOfCPUModel(CPU model) + { + switch (model) { + case CPU::XEON_ICE_LAKE : return "Xeon Ice Lake"; + case CPU::CORE_ICE_LAKE : return "Core Ice Lake"; + case CPU::CORE_TIGER_LAKE : return "Core Tiger Lake"; + case CPU::CORE_COMET_LAKE : return "Core Comet Lake"; + case CPU::CORE_CANNON_LAKE : return "Core Cannon Lake"; + case CPU::CORE_KABY_LAKE : return "Core Kaby Lake"; + case CPU::XEON_SKY_LAKE : return "Xeon Sky Lake"; + case CPU::CORE_SKY_LAKE : return "Core Sky Lake"; + case CPU::XEON_PHI_KNIGHTS_MILL : return "Xeon Phi Knights Mill"; + case CPU::XEON_PHI_KNIGHTS_LANDING: return "Xeon Phi Knights Landing"; + case CPU::XEON_BROADWELL : return "Xeon Broadwell"; + case CPU::CORE_BROADWELL : return "Core Broadwell"; + case CPU::XEON_HASWELL : return "Xeon Haswell"; + case CPU::CORE_HASWELL : return "Core Haswell"; + case CPU::XEON_IVY_BRIDGE : return "Xeon Ivy Bridge"; + case CPU::CORE_IVY_BRIDGE : return "Core Ivy Bridge"; + case CPU::SANDY_BRIDGE : return "Sandy Bridge"; + case CPU::NEHALEM : return "Nehalem"; + case CPU::CORE2 : return "Core2"; + case CPU::CORE1 : return "Core"; + case CPU::UNKNOWN : return "Unknown CPU"; + } + return "Unknown CPU (error)"; + } + + /* constants to access destination registers of CPUID instruction */ + static const int EAX = 0; + static const int EBX = 1; + static const int ECX = 2; + static const int EDX = 3; + + /* cpuid[eax=1].ecx */ + static const int CPU_FEATURE_BIT_SSE3 = 1 << 0; + static const int CPU_FEATURE_BIT_SSSE3 = 1 << 9; + static const int CPU_FEATURE_BIT_FMA3 = 1 << 12; + static const int CPU_FEATURE_BIT_SSE4_1 = 1 << 19; + static const int 
CPU_FEATURE_BIT_SSE4_2 = 1 << 20; + //static const int CPU_FEATURE_BIT_MOVBE = 1 << 22; + static const int CPU_FEATURE_BIT_POPCNT = 1 << 23; + //static const int CPU_FEATURE_BIT_XSAVE = 1 << 26; + static const int CPU_FEATURE_BIT_OXSAVE = 1 << 27; + static const int CPU_FEATURE_BIT_AVX = 1 << 28; + static const int CPU_FEATURE_BIT_F16C = 1 << 29; + static const int CPU_FEATURE_BIT_RDRAND = 1 << 30; + + /* cpuid[eax=1].edx */ + static const int CPU_FEATURE_BIT_SSE = 1 << 25; + static const int CPU_FEATURE_BIT_SSE2 = 1 << 26; + + /* cpuid[eax=0x80000001].ecx */ + static const int CPU_FEATURE_BIT_LZCNT = 1 << 5; + + /* cpuid[eax=7,ecx=0].ebx */ + static const int CPU_FEATURE_BIT_BMI1 = 1 << 3; + static const int CPU_FEATURE_BIT_AVX2 = 1 << 5; + static const int CPU_FEATURE_BIT_BMI2 = 1 << 8; + static const int CPU_FEATURE_BIT_AVX512F = 1 << 16; // AVX512F (foundation) + static const int CPU_FEATURE_BIT_AVX512DQ = 1 << 17; // AVX512DQ (doubleword and quadword instructions) + static const int CPU_FEATURE_BIT_AVX512PF = 1 << 26; // AVX512PF (prefetch gather/scatter instructions) + static const int CPU_FEATURE_BIT_AVX512ER = 1 << 27; // AVX512ER (exponential and reciprocal instructions) + static const int CPU_FEATURE_BIT_AVX512CD = 1 << 28; // AVX512CD (conflict detection instructions) + static const int CPU_FEATURE_BIT_AVX512BW = 1 << 30; // AVX512BW (byte and word instructions) + static const int CPU_FEATURE_BIT_AVX512VL = 1 << 31; // AVX512VL (vector length extensions) + static const int CPU_FEATURE_BIT_AVX512IFMA = 1 << 21; // AVX512IFMA (integer fused multiple-add instructions) + + /* cpuid[eax=7,ecx=0].ecx */ + static const int CPU_FEATURE_BIT_AVX512VBMI = 1 << 1; // AVX512VBMI (vector bit manipulation instructions) + + __noinline int64_t get_xcr0() + { +#if defined (__WIN32__) /* -- GODOT start -- */ && !defined (__MINGW32__) /* -- GODOT end -- */ + int64_t xcr0 = 0; // int64_t is workaround for compiler bug under VS2013, Win32 + xcr0 = _xgetbv(0); + return xcr0; +#else + int xcr0 = 0; + __asm__ ("xgetbv" : "=a" (xcr0) : "c" (0) : "%edx" ); + return xcr0; +#endif + } + + int getCPUFeatures() + { + /* cache CPU features access */ + static int cpu_features = 0; + if (cpu_features) + return cpu_features; + + /* get number of CPUID leaves */ + int cpuid_leaf0[4]; + __cpuid(cpuid_leaf0, 0x00000000); + unsigned nIds = cpuid_leaf0[EAX]; + + /* get number of extended CPUID leaves */ + int cpuid_leafe[4]; + __cpuid(cpuid_leafe, 0x80000000); + unsigned nExIds = cpuid_leafe[EAX]; + + /* get CPUID leaves for EAX = 1,7, and 0x80000001 */ + int cpuid_leaf_1[4] = { 0,0,0,0 }; + int cpuid_leaf_7[4] = { 0,0,0,0 }; + int cpuid_leaf_e1[4] = { 0,0,0,0 }; + if (nIds >= 1) __cpuid (cpuid_leaf_1,0x00000001); +#if _WIN32 +#if _MSC_VER && (_MSC_FULL_VER < 160040219) +#else + if (nIds >= 7) __cpuidex(cpuid_leaf_7,0x00000007,0); +#endif +#else + if (nIds >= 7) __cpuid_count(cpuid_leaf_7,0x00000007,0); +#endif + if (nExIds >= 0x80000001) __cpuid(cpuid_leaf_e1,0x80000001); + + /* detect if OS saves XMM, YMM, and ZMM states */ + bool xmm_enabled = true; + bool ymm_enabled = false; + bool zmm_enabled = false; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_OXSAVE) { + int64_t xcr0 = get_xcr0(); + xmm_enabled = ((xcr0 & 0x02) == 0x02); /* checks if xmm are enabled in XCR0 */ + ymm_enabled = xmm_enabled && ((xcr0 & 0x04) == 0x04); /* checks if ymm state are enabled in XCR0 */ + zmm_enabled = ymm_enabled && ((xcr0 & 0xE0) == 0xE0); /* checks if OPMASK state, upper 256-bit of ZMM0-ZMM15 and ZMM16-ZMM31 state are enabled in XCR0 
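(bits 5, 6 and 7 of XCR0, i.e. the 0xE0 mask)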
*/ + } + if (xmm_enabled) cpu_features |= CPU_FEATURE_XMM_ENABLED; + if (ymm_enabled) cpu_features |= CPU_FEATURE_YMM_ENABLED; + if (zmm_enabled) cpu_features |= CPU_FEATURE_ZMM_ENABLED; + + if (cpuid_leaf_1[EDX] & CPU_FEATURE_BIT_SSE ) cpu_features |= CPU_FEATURE_SSE; + if (cpuid_leaf_1[EDX] & CPU_FEATURE_BIT_SSE2 ) cpu_features |= CPU_FEATURE_SSE2; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_SSE3 ) cpu_features |= CPU_FEATURE_SSE3; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_SSSE3 ) cpu_features |= CPU_FEATURE_SSSE3; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_SSE4_1) cpu_features |= CPU_FEATURE_SSE41; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_SSE4_2) cpu_features |= CPU_FEATURE_SSE42; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_POPCNT) cpu_features |= CPU_FEATURE_POPCNT; + + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_AVX ) cpu_features |= CPU_FEATURE_AVX; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_F16C ) cpu_features |= CPU_FEATURE_F16C; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_RDRAND) cpu_features |= CPU_FEATURE_RDRAND; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX2 ) cpu_features |= CPU_FEATURE_AVX2; + if (cpuid_leaf_1[ECX] & CPU_FEATURE_BIT_FMA3 ) cpu_features |= CPU_FEATURE_FMA3; + if (cpuid_leaf_e1[ECX] & CPU_FEATURE_BIT_LZCNT) cpu_features |= CPU_FEATURE_LZCNT; + if (cpuid_leaf_7 [EBX] & CPU_FEATURE_BIT_BMI1 ) cpu_features |= CPU_FEATURE_BMI1; + if (cpuid_leaf_7 [EBX] & CPU_FEATURE_BIT_BMI2 ) cpu_features |= CPU_FEATURE_BMI2; + + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512F ) cpu_features |= CPU_FEATURE_AVX512F; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512DQ ) cpu_features |= CPU_FEATURE_AVX512DQ; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512PF ) cpu_features |= CPU_FEATURE_AVX512PF; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512ER ) cpu_features |= CPU_FEATURE_AVX512ER; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512CD ) cpu_features |= CPU_FEATURE_AVX512CD; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512BW ) cpu_features |= CPU_FEATURE_AVX512BW; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512IFMA) cpu_features |= CPU_FEATURE_AVX512IFMA; + if (cpuid_leaf_7[EBX] & CPU_FEATURE_BIT_AVX512VL ) cpu_features |= CPU_FEATURE_AVX512VL; + if (cpuid_leaf_7[ECX] & CPU_FEATURE_BIT_AVX512VBMI) cpu_features |= CPU_FEATURE_AVX512VBMI; + + return cpu_features; + } + + std::string stringOfCPUFeatures(int features) + { + std::string str; + if (features & CPU_FEATURE_XMM_ENABLED) str += "XMM "; + if (features & CPU_FEATURE_YMM_ENABLED) str += "YMM "; + if (features & CPU_FEATURE_ZMM_ENABLED) str += "ZMM "; + if (features & CPU_FEATURE_SSE ) str += "SSE "; + if (features & CPU_FEATURE_SSE2 ) str += "SSE2 "; + if (features & CPU_FEATURE_SSE3 ) str += "SSE3 "; + if (features & CPU_FEATURE_SSSE3 ) str += "SSSE3 "; + if (features & CPU_FEATURE_SSE41 ) str += "SSE4.1 "; + if (features & CPU_FEATURE_SSE42 ) str += "SSE4.2 "; + if (features & CPU_FEATURE_POPCNT) str += "POPCNT "; + if (features & CPU_FEATURE_AVX ) str += "AVX "; + if (features & CPU_FEATURE_F16C ) str += "F16C "; + if (features & CPU_FEATURE_RDRAND) str += "RDRAND "; + if (features & CPU_FEATURE_AVX2 ) str += "AVX2 "; + if (features & CPU_FEATURE_FMA3 ) str += "FMA3 "; + if (features & CPU_FEATURE_LZCNT ) str += "LZCNT "; + if (features & CPU_FEATURE_BMI1 ) str += "BMI1 "; + if (features & CPU_FEATURE_BMI2 ) str += "BMI2 "; + if (features & CPU_FEATURE_AVX512F) str += "AVX512F "; + if (features & CPU_FEATURE_AVX512DQ) str += "AVX512DQ "; + if (features & CPU_FEATURE_AVX512PF) str += "AVX512PF "; + if (features & CPU_FEATURE_AVX512ER) 
str += "AVX512ER "; + if (features & CPU_FEATURE_AVX512CD) str += "AVX512CD "; + if (features & CPU_FEATURE_AVX512BW) str += "AVX512BW "; + if (features & CPU_FEATURE_AVX512VL) str += "AVX512VL "; + if (features & CPU_FEATURE_AVX512IFMA) str += "AVX512IFMA "; + if (features & CPU_FEATURE_AVX512VBMI) str += "AVX512VBMI "; + return str; + } + + std::string stringOfISA (int isa) + { + if (isa == SSE) return "SSE"; + if (isa == SSE2) return "SSE2"; + if (isa == SSE3) return "SSE3"; + if (isa == SSSE3) return "SSSE3"; + if (isa == SSE41) return "SSE4.1"; + if (isa == SSE42) return "SSE4.2"; + if (isa == AVX) return "AVX"; + if (isa == AVX2) return "AVX2"; + if (isa == AVX512KNL) return "AVX512KNL"; + if (isa == AVX512SKX) return "AVX512SKX"; + return "UNKNOWN"; + } + + bool hasISA(int features, int isa) { + return (features & isa) == isa; + } + + std::string supportedTargetList (int features) + { + std::string v; + if (hasISA(features,SSE)) v += "SSE "; + if (hasISA(features,SSE2)) v += "SSE2 "; + if (hasISA(features,SSE3)) v += "SSE3 "; + if (hasISA(features,SSSE3)) v += "SSSE3 "; + if (hasISA(features,SSE41)) v += "SSE4.1 "; + if (hasISA(features,SSE42)) v += "SSE4.2 "; + if (hasISA(features,AVX)) v += "AVX "; + if (hasISA(features,AVXI)) v += "AVXI "; + if (hasISA(features,AVX2)) v += "AVX2 "; + if (hasISA(features,AVX512KNL)) v += "AVX512KNL "; + if (hasISA(features,AVX512SKX)) v += "AVX512SKX "; + return v; + } +} + +//////////////////////////////////////////////////////////////////////////////// +/// Windows Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__WIN32__) + +#define WIN32_LEAN_AND_MEAN +#include +#include + +namespace embree +{ + std::string getExecutableFileName() { + char filename[1024]; + if (!GetModuleFileName(nullptr, filename, sizeof(filename))) + return std::string(); + return std::string(filename); + } + + unsigned int getNumberOfLogicalThreads() + { + static int nThreads = -1; + if (nThreads != -1) return nThreads; + + typedef WORD (WINAPI *GetActiveProcessorGroupCountFunc)(); + typedef DWORD (WINAPI *GetActiveProcessorCountFunc)(WORD); + HMODULE hlib = LoadLibrary("Kernel32"); + GetActiveProcessorGroupCountFunc pGetActiveProcessorGroupCount = (GetActiveProcessorGroupCountFunc)GetProcAddress(hlib, "GetActiveProcessorGroupCount"); + GetActiveProcessorCountFunc pGetActiveProcessorCount = (GetActiveProcessorCountFunc) GetProcAddress(hlib, "GetActiveProcessorCount"); + + if (pGetActiveProcessorGroupCount && pGetActiveProcessorCount) + { + int groups = pGetActiveProcessorGroupCount(); + int totalProcessors = 0; + for (int i = 0; i < groups; i++) + totalProcessors += pGetActiveProcessorCount(i); + nThreads = totalProcessors; + } + else + { + SYSTEM_INFO sysinfo; + GetSystemInfo(&sysinfo); + nThreads = sysinfo.dwNumberOfProcessors; + } + assert(nThreads); + return nThreads; + } + + int getTerminalWidth() + { + HANDLE handle = GetStdHandle(STD_OUTPUT_HANDLE); + if (handle == INVALID_HANDLE_VALUE) return 80; + CONSOLE_SCREEN_BUFFER_INFO info; + memset(&info,0,sizeof(info)); + GetConsoleScreenBufferInfo(handle, &info); + return info.dwSize.X; + } + + double getSeconds() + { + LARGE_INTEGER freq, val; + QueryPerformanceFrequency(&freq); + QueryPerformanceCounter(&val); + return (double)val.QuadPart / (double)freq.QuadPart; + } + + void sleepSeconds(double t) { + Sleep(DWORD(1000.0*t)); + } + + size_t getVirtualMemoryBytes() + { + PROCESS_MEMORY_COUNTERS info; + GetProcessMemoryInfo( GetCurrentProcess( ), &info, 
sizeof(info) ); + return (size_t)info.QuotaPeakPagedPoolUsage; + } + + size_t getResidentMemoryBytes() + { + PROCESS_MEMORY_COUNTERS info; + GetProcessMemoryInfo( GetCurrentProcess( ), &info, sizeof(info) ); + return (size_t)info.WorkingSetSize; + } +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Linux Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__LINUX__) + +#include +#include + +namespace embree +{ + std::string getExecutableFileName() + { + std::string pid = "/proc/" + toString(getpid()) + "/exe"; + char buf[4096]; + memset(buf,0,sizeof(buf)); + if (readlink(pid.c_str(), buf, sizeof(buf)-1) == -1) + return std::string(); + return std::string(buf); + } + + size_t getVirtualMemoryBytes() + { + size_t virt, resident, shared; + std::ifstream buffer("/proc/self/statm"); + buffer >> virt >> resident >> shared; + return virt*sysconf(_SC_PAGE_SIZE); + } + + size_t getResidentMemoryBytes() + { + size_t virt, resident, shared; + std::ifstream buffer("/proc/self/statm"); + buffer >> virt >> resident >> shared; + return resident*sysconf(_SC_PAGE_SIZE); + } +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// FreeBSD Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined (__FreeBSD__) + +#include + +namespace embree +{ + std::string getExecutableFileName() + { + const int mib[4] = { CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, -1 }; + char buf[4096]; + memset(buf,0,sizeof(buf)); + size_t len = sizeof(buf)-1; + if (sysctl(mib, 4, buf, &len, 0x0, 0) == -1) + return std::string(); + return std::string(buf); + } + + size_t getVirtualMemoryBytes() { + return 0; + } + + size_t getResidentMemoryBytes() { + return 0; + } +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Mac OS X Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__MACOSX__) + +#include + +namespace embree +{ + std::string getExecutableFileName() + { + char buf[4096]; + uint32_t size = sizeof(buf); + if (_NSGetExecutablePath(buf, &size) != 0) + return std::string(); + return std::string(buf); + } + + size_t getVirtualMemoryBytes() { + return 0; + } + + size_t getResidentMemoryBytes() { + return 0; + } +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Unix Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__UNIX__) + +#include +#include +#include +#include + +namespace embree +{ + unsigned int getNumberOfLogicalThreads() + { + static int nThreads = -1; + if (nThreads != -1) return nThreads; + +#if defined(__MACOSX__) + nThreads = sysconf(_SC_NPROCESSORS_ONLN); // does not work in Linux LXC container + assert(nThreads); +#else + cpu_set_t set; + if (pthread_getaffinity_np(pthread_self(), sizeof(set), &set) == 0) + nThreads = CPU_COUNT(&set); +#endif + + assert(nThreads); + return nThreads; + } + + int getTerminalWidth() + { + struct winsize info; + if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &info) < 0) return 80; + return info.ws_col; + } + + double getSeconds() { + struct timeval tp; gettimeofday(&tp,nullptr); + return double(tp.tv_sec) + double(tp.tv_usec)/1E6; + } + + void sleepSeconds(double t) { + usleep(1000000.0*t); + } +} +#endif + diff --git a/thirdparty/embree/common/sys/sysinfo.h b/thirdparty/embree/common/sys/sysinfo.h 
new file mode 100644 index 000000000000..cee1017ddecf --- /dev/null +++ b/thirdparty/embree/common/sys/sysinfo.h @@ -0,0 +1,182 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#define CACHELINE_SIZE 64 + +#if !defined(PAGE_SIZE) + #define PAGE_SIZE 4096 +#endif + +#define PAGE_SIZE_2M (2*1024*1024) +#define PAGE_SIZE_4K (4*1024) + +#include "platform.h" + +/* define isa namespace and ISA bitvector */ +#if defined (__AVX512VL__) +# define isa avx512skx +# define ISA AVX512SKX +# define ISA_STR "AVX512SKX" +#elif defined (__AVX512F__) +# define isa avx512knl +# define ISA AVX512KNL +# define ISA_STR "AVX512KNL" +#elif defined (__AVX2__) +# define isa avx2 +# define ISA AVX2 +# define ISA_STR "AVX2" +#elif defined(__AVXI__) +# define isa avxi +# define ISA AVXI +# define ISA_STR "AVXI" +#elif defined(__AVX__) +# define isa avx +# define ISA AVX +# define ISA_STR "AVX" +#elif defined (__SSE4_2__) +# define isa sse42 +# define ISA SSE42 +# define ISA_STR "SSE4.2" +//#elif defined (__SSE4_1__) // we demote this to SSE2, MacOSX code compiles with SSE41 by default with XCode 11 +//# define isa sse41 +//# define ISA SSE41 +//# define ISA_STR "SSE4.1" +//#elif defined(__SSSE3__) // we demote this to SSE2, MacOSX code compiles with SSSE3 by default with ICC +//# define isa ssse3 +//# define ISA SSSE3 +//# define ISA_STR "SSSE3" +//#elif defined(__SSE3__) // we demote this to SSE2, MacOSX code compiles with SSE3 by default with clang +//# define isa sse3 +//# define ISA SSE3 +//# define ISA_STR "SSE3" +#elif defined(__SSE2__) || defined(__SSE3__) || defined(__SSSE3__) +# define isa sse2 +# define ISA SSE2 +# define ISA_STR "SSE2" +#elif defined(__SSE__) +# define isa sse +# define ISA SSE +# define ISA_STR "SSE" +#else +#error Unknown ISA +#endif + +namespace embree +{ + enum class CPU + { + XEON_ICE_LAKE, + CORE_ICE_LAKE, + CORE_TIGER_LAKE, + CORE_COMET_LAKE, + CORE_CANNON_LAKE, + CORE_KABY_LAKE, + XEON_SKY_LAKE, + CORE_SKY_LAKE, + XEON_PHI_KNIGHTS_MILL, + XEON_PHI_KNIGHTS_LANDING, + XEON_BROADWELL, + CORE_BROADWELL, + XEON_HASWELL, + CORE_HASWELL, + XEON_IVY_BRIDGE, + CORE_IVY_BRIDGE, + SANDY_BRIDGE, + NEHALEM, + CORE2, + CORE1, + UNKNOWN, + }; + + /*! get the full path to the running executable */ + std::string getExecutableFileName(); + + /*! return platform name */ + std::string getPlatformName(); + + /*! get the full name of the compiler */ + std::string getCompilerName(); + + /*! return the name of the CPU */ + std::string getCPUVendor(); + + /*! get microprocessor model */ + CPU getCPUModel(); + + /*! converts CPU model into string */ + std::string stringOfCPUModel(CPU model); + + /*! 
CPU features */ + static const int CPU_FEATURE_SSE = 1 << 0; + static const int CPU_FEATURE_SSE2 = 1 << 1; + static const int CPU_FEATURE_SSE3 = 1 << 2; + static const int CPU_FEATURE_SSSE3 = 1 << 3; + static const int CPU_FEATURE_SSE41 = 1 << 4; + static const int CPU_FEATURE_SSE42 = 1 << 5; + static const int CPU_FEATURE_POPCNT = 1 << 6; + static const int CPU_FEATURE_AVX = 1 << 7; + static const int CPU_FEATURE_F16C = 1 << 8; + static const int CPU_FEATURE_RDRAND = 1 << 9; + static const int CPU_FEATURE_AVX2 = 1 << 10; + static const int CPU_FEATURE_FMA3 = 1 << 11; + static const int CPU_FEATURE_LZCNT = 1 << 12; + static const int CPU_FEATURE_BMI1 = 1 << 13; + static const int CPU_FEATURE_BMI2 = 1 << 14; + static const int CPU_FEATURE_AVX512F = 1 << 16; + static const int CPU_FEATURE_AVX512DQ = 1 << 17; + static const int CPU_FEATURE_AVX512PF = 1 << 18; + static const int CPU_FEATURE_AVX512ER = 1 << 19; + static const int CPU_FEATURE_AVX512CD = 1 << 20; + static const int CPU_FEATURE_AVX512BW = 1 << 21; + static const int CPU_FEATURE_AVX512VL = 1 << 22; + static const int CPU_FEATURE_AVX512IFMA = 1 << 23; + static const int CPU_FEATURE_AVX512VBMI = 1 << 24; + static const int CPU_FEATURE_XMM_ENABLED = 1 << 25; + static const int CPU_FEATURE_YMM_ENABLED = 1 << 26; + static const int CPU_FEATURE_ZMM_ENABLED = 1 << 27; + + /*! get CPU features */ + int getCPUFeatures(); + + /*! convert CPU features into a string */ + std::string stringOfCPUFeatures(int features); + + /*! creates a string of all supported targets that are supported */ + std::string supportedTargetList (int isa); + + /*! ISAs */ + static const int SSE = CPU_FEATURE_SSE | CPU_FEATURE_XMM_ENABLED; + static const int SSE2 = SSE | CPU_FEATURE_SSE2; + static const int SSE3 = SSE2 | CPU_FEATURE_SSE3; + static const int SSSE3 = SSE3 | CPU_FEATURE_SSSE3; + static const int SSE41 = SSSE3 | CPU_FEATURE_SSE41; + static const int SSE42 = SSE41 | CPU_FEATURE_SSE42 | CPU_FEATURE_POPCNT; + static const int AVX = SSE42 | CPU_FEATURE_AVX | CPU_FEATURE_YMM_ENABLED; + static const int AVXI = AVX | CPU_FEATURE_F16C | CPU_FEATURE_RDRAND; + static const int AVX2 = AVXI | CPU_FEATURE_AVX2 | CPU_FEATURE_FMA3 | CPU_FEATURE_BMI1 | CPU_FEATURE_BMI2 | CPU_FEATURE_LZCNT; + static const int AVX512KNL = AVX2 | CPU_FEATURE_AVX512F | CPU_FEATURE_AVX512PF | CPU_FEATURE_AVX512ER | CPU_FEATURE_AVX512CD | CPU_FEATURE_ZMM_ENABLED; + static const int AVX512SKX = AVX2 | CPU_FEATURE_AVX512F | CPU_FEATURE_AVX512DQ | CPU_FEATURE_AVX512CD | CPU_FEATURE_AVX512BW | CPU_FEATURE_AVX512VL | CPU_FEATURE_ZMM_ENABLED; + + /*! converts ISA bitvector into a string */ + std::string stringOfISA(int features); + + /*! return the number of logical threads of the system */ + unsigned int getNumberOfLogicalThreads(); + + /*! returns the size of the terminal window in characters */ + int getTerminalWidth(); + + /*! returns performance counter in seconds */ + double getSeconds(); + + /*! sleeps the specified number of seconds */ + void sleepSeconds(double t); + + /*! returns virtual address space occupied by process */ + size_t getVirtualMemoryBytes(); + + /*! 
returns resident memory required by process */ + size_t getResidentMemoryBytes(); +} diff --git a/thirdparty/embree/common/sys/thread.cpp b/thirdparty/embree/common/sys/thread.cpp new file mode 100644 index 000000000000..4d86853c47c2 --- /dev/null +++ b/thirdparty/embree/common/sys/thread.cpp @@ -0,0 +1,425 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "thread.h" +#include "sysinfo.h" +#include "string.h" + +#include +#include + +#if defined(PTHREADS_WIN32) +#pragma comment (lib, "pthreadVC.lib") +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Windows Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__WIN32__) + +#define WIN32_LEAN_AND_MEAN +#include + +namespace embree +{ + /*! set the affinity of a given thread */ + void setAffinity(HANDLE thread, ssize_t affinity) + { + typedef WORD (WINAPI *GetActiveProcessorGroupCountFunc)(); + typedef DWORD (WINAPI *GetActiveProcessorCountFunc)(WORD); + typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE, const GROUP_AFFINITY *, PGROUP_AFFINITY); + typedef BOOL (WINAPI *SetThreadIdealProcessorExFunc)(HANDLE, PPROCESSOR_NUMBER, PPROCESSOR_NUMBER); + HMODULE hlib = LoadLibrary("Kernel32"); + GetActiveProcessorGroupCountFunc pGetActiveProcessorGroupCount = (GetActiveProcessorGroupCountFunc)GetProcAddress(hlib, "GetActiveProcessorGroupCount"); + GetActiveProcessorCountFunc pGetActiveProcessorCount = (GetActiveProcessorCountFunc)GetProcAddress(hlib, "GetActiveProcessorCount"); + SetThreadGroupAffinityFunc pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hlib, "SetThreadGroupAffinity"); + SetThreadIdealProcessorExFunc pSetThreadIdealProcessorEx = (SetThreadIdealProcessorExFunc)GetProcAddress(hlib, "SetThreadIdealProcessorEx"); + if (pGetActiveProcessorGroupCount && pGetActiveProcessorCount && pSetThreadGroupAffinity && pSetThreadIdealProcessorEx) + { + int groups = pGetActiveProcessorGroupCount(); + int totalProcessors = 0, group = 0, number = 0; + for (int i = 0; i affinity) { + group = i; + number = (int)affinity - totalProcessors; + break; + } + totalProcessors += processors; + } + + GROUP_AFFINITY groupAffinity; + groupAffinity.Group = (WORD)group; + groupAffinity.Mask = (KAFFINITY)(uint64_t(1) << number); + groupAffinity.Reserved[0] = 0; + groupAffinity.Reserved[1] = 0; + groupAffinity.Reserved[2] = 0; + if (!pSetThreadGroupAffinity(thread, &groupAffinity, nullptr)) + WARNING("SetThreadGroupAffinity failed"); // on purpose only a warning + + PROCESSOR_NUMBER processorNumber; + processorNumber.Group = group; + processorNumber.Number = number; + processorNumber.Reserved = 0; + if (!pSetThreadIdealProcessorEx(thread, &processorNumber, nullptr)) + WARNING("SetThreadIdealProcessorEx failed"); // on purpose only a warning + } + else + { + if (!SetThreadAffinityMask(thread, DWORD_PTR(uint64_t(1) << affinity))) + WARNING("SetThreadAffinityMask failed"); // on purpose only a warning + if (SetThreadIdealProcessor(thread, (DWORD)affinity) == (DWORD)-1) + WARNING("SetThreadIdealProcessor failed"); // on purpose only a warning + } + } + + /*! 
set affinity of the calling thread */ + void setAffinity(ssize_t affinity) { + setAffinity(GetCurrentThread(), affinity); + } + + struct ThreadStartupData + { + public: + ThreadStartupData (thread_func f, void* arg) + : f(f), arg(arg) {} + public: + thread_func f; + void* arg; + }; + + DWORD WINAPI threadStartup(LPVOID ptr) + { + ThreadStartupData* parg = (ThreadStartupData*) ptr; + _mm_setcsr(_mm_getcsr() | /*FTZ:*/ (1<<15) | /*DAZ:*/ (1<<6)); + parg->f(parg->arg); + delete parg; + return 0; + } + +#if !defined(PTHREADS_WIN32) + + /*! creates a hardware thread running on specific core */ + thread_t createThread(thread_func f, void* arg, size_t stack_size, ssize_t threadID) + { + HANDLE thread = CreateThread(nullptr, stack_size, threadStartup, new ThreadStartupData(f,arg), 0, nullptr); + if (thread == nullptr) FATAL("CreateThread failed"); + if (threadID >= 0) setAffinity(thread, threadID); + return thread_t(thread); + } + + /*! the thread calling this function gets yielded */ + void yield() { + SwitchToThread(); + } + + /*! waits until the given thread has terminated */ + void join(thread_t tid) { + WaitForSingleObject(HANDLE(tid), INFINITE); + CloseHandle(HANDLE(tid)); + } + + /*! destroy a hardware thread by its handle */ + void destroyThread(thread_t tid) { + TerminateThread(HANDLE(tid),0); + CloseHandle(HANDLE(tid)); + } + + /*! creates thread local storage */ + tls_t createTls() { + return tls_t(size_t(TlsAlloc())); + } + + /*! set the thread local storage pointer */ + void setTls(tls_t tls, void* const ptr) { + TlsSetValue(DWORD(size_t(tls)), ptr); + } + + /*! return the thread local storage pointer */ + void* getTls(tls_t tls) { + return TlsGetValue(DWORD(size_t(tls))); + } + + /*! destroys thread local storage identifier */ + void destroyTls(tls_t tls) { + TlsFree(DWORD(size_t(tls))); + } +#endif +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Linux Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__LINUX__) + +#include +#include +#include + +namespace embree +{ + static MutexSys mutex; + static std::vector threadIDs; + + /* changes thread ID mapping such that we first fill up all thread on one core */ + size_t mapThreadID(size_t threadID) + { + Lock lock(mutex); + + if (threadIDs.size() == 0) + { + /* parse thread/CPU topology */ + for (size_t cpuID=0;;cpuID++) + { + std::fstream fs; + std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string((long long)cpuID) + std::string("/topology/thread_siblings_list"); + fs.open (cpu.c_str(), std::fstream::in); + if (fs.fail()) break; + + int i; + while (fs >> i) + { + if (std::none_of(threadIDs.begin(),threadIDs.end(),[&] (int id) { return id == i; })) + threadIDs.push_back(i); + if (fs.peek() == ',') + fs.ignore(); + } + fs.close(); + } + +#if 0 + for (size_t i=0;i " << threadIDs[i] << std::endl; +#endif + + /* verify the mapping and do not use it if the mapping has errors */ + for (size_t i=0;i + +namespace embree +{ + /*! 
set affinity of the calling thread */ + void setAffinity(ssize_t affinity) + { + cpuset_t cset; + CPU_ZERO(&cset); + CPU_SET(affinity, &cset); + + pthread_setaffinity_np(pthread_self(), sizeof(cset), &cset); + } +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// MacOSX Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__MACOSX__) + +#include +#include +#include + +namespace embree +{ + /*! set affinity of the calling thread */ + void setAffinity(ssize_t affinity) + { + thread_affinity_policy ap; + ap.affinity_tag = affinity; + if (thread_policy_set(mach_thread_self(),THREAD_AFFINITY_POLICY,(thread_policy_t)&ap,THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + WARNING("setting thread affinity failed"); // on purpose only a warning + } +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +/// Unix Platform +//////////////////////////////////////////////////////////////////////////////// + +#if defined(__UNIX__) || defined(PTHREADS_WIN32) + +#include +#include + +#if defined(__USE_NUMA__) +#include +#endif + +namespace embree +{ + struct ThreadStartupData + { + public: + ThreadStartupData (thread_func f, void* arg, int affinity) + : f(f), arg(arg), affinity(affinity) {} + public: + thread_func f; + void* arg; + ssize_t affinity; + }; + + static void* threadStartup(ThreadStartupData* parg) + { + _mm_setcsr(_mm_getcsr() | /*FTZ:*/ (1<<15) | /*DAZ:*/ (1<<6)); + + /*! Mac OS X does not support setting affinity at thread creation time */ +#if defined(__MACOSX__) + if (parg->affinity >= 0) + setAffinity(parg->affinity); +#endif + + parg->f(parg->arg); + delete parg; + return nullptr; + } + + /*! creates a hardware thread running on specific core */ + thread_t createThread(thread_func f, void* arg, size_t stack_size, ssize_t threadID) + { + /* set stack size */ + pthread_attr_t attr; + pthread_attr_init(&attr); + if (stack_size > 0) pthread_attr_setstacksize (&attr, stack_size); + + /* create thread */ + pthread_t* tid = new pthread_t; + if (pthread_create(tid,&attr,(void*(*)(void*))threadStartup,new ThreadStartupData(f,arg,threadID)) != 0) { + pthread_attr_destroy(&attr); + delete tid; + FATAL("pthread_create failed"); + } + pthread_attr_destroy(&attr); + + /* set affinity */ +#if defined(__LINUX__) + if (threadID >= 0) { + cpu_set_t cset; + CPU_ZERO(&cset); + threadID = mapThreadID(threadID); + CPU_SET(threadID, &cset); + pthread_setaffinity_np(*tid, sizeof(cset), &cset); + } +#elif defined(__FreeBSD__) + if (threadID >= 0) { + cpuset_t cset; + CPU_ZERO(&cset); + CPU_SET(threadID, &cset); + pthread_setaffinity_np(*tid, sizeof(cset), &cset); + } +#endif + + return thread_t(tid); + } + + /*! the thread calling this function gets yielded */ + void yield() { + sched_yield(); + } + + /*! waits until the given thread has terminated */ + void join(thread_t tid) { + if (pthread_join(*(pthread_t*)tid, nullptr) != 0) + FATAL("pthread_join failed"); + delete (pthread_t*)tid; + } + + /*! destroy a hardware thread by its handle */ + void destroyThread(thread_t tid) { + pthread_cancel(*(pthread_t*)tid); + delete (pthread_t*)tid; + } + + /*! creates thread local storage */ + tls_t createTls() + { + pthread_key_t* key = new pthread_key_t; + if (pthread_key_create(key,nullptr) != 0) { + delete key; + FATAL("pthread_key_create failed"); + } + + return tls_t(key); + } + + /*! 
return the thread local storage pointer */ + void* getTls(tls_t tls) + { + assert(tls); + return pthread_getspecific(*(pthread_key_t*)tls); + } + + /*! set the thread local storage pointer */ + void setTls(tls_t tls, void* const ptr) + { + assert(tls); + if (pthread_setspecific(*(pthread_key_t*)tls, ptr) != 0) + FATAL("pthread_setspecific failed"); + } + + /*! destroys thread local storage identifier */ + void destroyTls(tls_t tls) + { + assert(tls); + if (pthread_key_delete(*(pthread_key_t*)tls) != 0) + FATAL("pthread_key_delete failed"); + delete (pthread_key_t*)tls; + } +} + +#endif diff --git a/thirdparty/embree/common/sys/thread.h b/thirdparty/embree/common/sys/thread.h new file mode 100644 index 000000000000..5261a985eefa --- /dev/null +++ b/thirdparty/embree/common/sys/thread.h @@ -0,0 +1,49 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "platform.h" +#include "mutex.h" +#include "alloc.h" +#include "vector.h" +#include + +namespace embree +{ + /*! type for thread */ + typedef struct opaque_thread_t* thread_t; + + /*! signature of thread start function */ + typedef void (*thread_func)(void*); + + /*! creates a hardware thread running on specific logical thread */ + thread_t createThread(thread_func f, void* arg, size_t stack_size = 0, ssize_t threadID = -1); + + /*! set affinity of the calling thread */ + void setAffinity(ssize_t affinity); + + /*! the thread calling this function gets yielded */ + void yield(); + + /*! waits until the given thread has terminated */ + void join(thread_t tid); + + /*! destroy handle of a thread */ + void destroyThread(thread_t tid); + + /*! type for handle to thread local storage */ + typedef struct opaque_tls_t* tls_t; + + /*! creates thread local storage */ + tls_t createTls(); + + /*! set the thread local storage pointer */ + void setTls(tls_t tls, void* const ptr); + + /*! return the thread local storage pointer */ + void* getTls(tls_t tls); + + /*! 
destroys thread local storage identifier */ + void destroyTls(tls_t tls); +} diff --git a/thirdparty/embree/common/sys/vector.h b/thirdparty/embree/common/sys/vector.h new file mode 100644 index 000000000000..e41794de7c53 --- /dev/null +++ b/thirdparty/embree/common/sys/vector.h @@ -0,0 +1,242 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "alloc.h" +#include + +namespace embree +{ + template + class vector_t + { + public: + typedef T value_type; + typedef T* iterator; + typedef const T* const_iterator; + + __forceinline vector_t () + : size_active(0), size_alloced(0), items(nullptr) {} + + __forceinline explicit vector_t (size_t sz) + : size_active(0), size_alloced(0), items(nullptr) { internal_resize_init(sz); } + + template + __forceinline explicit vector_t (M alloc, size_t sz) + : alloc(alloc), size_active(0), size_alloced(0), items(nullptr) { internal_resize_init(sz); } + + __forceinline ~vector_t() { + clear(); + } + + __forceinline vector_t (const vector_t& other) + { + size_active = other.size_active; + size_alloced = other.size_alloced; + items = alloc.allocate(size_alloced); + for (size_t i=0; i 0); return items[0]; }; + __forceinline T& back () const { assert(size_active > 0); return items[size_active-1]; }; + + __forceinline T* data() { return items; }; + __forceinline const T* data() const { return items; }; + + + /******************** Modifiers **************************/ + + __forceinline void push_back(const T& nt) + { + const T v = nt; // need local copy as input reference could point to this vector + internal_resize(size_active,internal_grow_size(size_active+1)); + ::new (&items[size_active++]) T(v); + } + + __forceinline void pop_back() + { + assert(!empty()); + size_active--; + alloc.destroy(&items[size_active]); + } + + __forceinline void clear() + { + /* destroy elements */ + for (size_t i=0; i + using vector = vector_t>; + + /*! vector class that performs aligned allocations */ + template + using avector = vector_t::value> >; + + /*! 
vector class that performs OS allocations */ + template + using ovector = vector_t >; +} diff --git a/thirdparty/embree/common/tasking/taskscheduler.h b/thirdparty/embree/common/tasking/taskscheduler.h new file mode 100644 index 000000000000..298d09255b2c --- /dev/null +++ b/thirdparty/embree/common/tasking/taskscheduler.h @@ -0,0 +1,15 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#if defined(TASKING_INTERNAL) +# include "taskschedulerinternal.h" +#elif defined(TASKING_TBB) +# include "taskschedulertbb.h" +#elif defined(TASKING_PPL) +# include "taskschedulerppl.h" +#else +# error "no tasking system enabled" +#endif + diff --git a/thirdparty/embree/common/tasking/taskschedulerinternal.cpp b/thirdparty/embree/common/tasking/taskschedulerinternal.cpp new file mode 100644 index 000000000000..923d62f83c86 --- /dev/null +++ b/thirdparty/embree/common/tasking/taskschedulerinternal.cpp @@ -0,0 +1,416 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "taskschedulerinternal.h" +#include "../math/math.h" +#include "../sys/sysinfo.h" +#include + +namespace embree +{ + RTC_NAMESPACE_BEGIN + + static MutexSys g_mutex; + size_t TaskScheduler::g_numThreads = 0; + __thread TaskScheduler* TaskScheduler::g_instance = nullptr; + std::vector> g_instance_vector; + __thread TaskScheduler::Thread* TaskScheduler::thread_local_thread = nullptr; + TaskScheduler::ThreadPool* TaskScheduler::threadPool = nullptr; + + template + __forceinline void TaskScheduler::steal_loop(Thread& thread, const Predicate& pred, const Body& body) + { + while (true) + { + /*! some rounds that yield */ + for (size_t i=0; i<32; i++) + { + /*! some spinning rounds */ + const size_t threadCount = thread.threadCount(); + for (size_t j=0; j<1024; j+=threadCount) + { + if (!pred()) return; + if (thread.scheduler->steal_from_other_threads(thread)) { + i=j=0; + body(); + } + } + yield(); + } + } + } + + /*! run this task */ + void TaskScheduler::Task::run_internal (Thread& thread) // FIXME: avoid as many dll_exports as possible + { + /* try to run if not already stolen */ + if (try_switch_state(INITIALIZED,DONE)) + { + Task* prevTask = thread.task; + thread.task = this; + try { + if (thread.scheduler->cancellingException == nullptr) + closure->execute(); + } catch (...) { + if (thread.scheduler->cancellingException == nullptr) + thread.scheduler->cancellingException = std::current_exception(); + } + thread.task = prevTask; + add_dependencies(-1); + } + + /* steal until all dependencies have completed */ + steal_loop(thread, + [&] () { return dependencies>0; }, + [&] () { while (thread.tasks.execute_local_internal(thread,this)); }); + + /* now signal our parent task that we are finished */ + if (parent) + parent->add_dependencies(-1); + } + + /*! 
run this task */ + dll_export void TaskScheduler::Task::run (Thread& thread) { + run_internal(thread); + } + + bool TaskScheduler::TaskQueue::execute_local_internal(Thread& thread, Task* parent) + { + /* stop if we run out of local tasks or reach the waiting task */ + if (right == 0 || &tasks[right-1] == parent) + return false; + + /* execute task */ + size_t oldRight = right; + tasks[right-1].run_internal(thread); + if (right != oldRight) { + THROW_RUNTIME_ERROR("you have to wait for spawned subtasks"); + } + + /* pop task and closure from stack */ + right--; + if (tasks[right].stackPtr != size_t(-1)) + stackPtr = tasks[right].stackPtr; + + /* also move left pointer */ + if (left >= right) left.store(right.load()); + + return right != 0; + } + + dll_export bool TaskScheduler::TaskQueue::execute_local(Thread& thread, Task* parent) { + return execute_local_internal(thread,parent); + } + + bool TaskScheduler::TaskQueue::steal(Thread& thread) + { + size_t l = left; + size_t r = right; + if (l < r) + { + l = left++; + if (l >= r) + return false; + } + else + return false; + + if (!tasks[l].try_steal(thread.tasks.tasks[thread.tasks.right])) + return false; + + thread.tasks.right++; + return true; + } + + /* we steal from the left */ + size_t TaskScheduler::TaskQueue::getTaskSizeAtLeft() + { + if (left >= right) return 0; + return tasks[left].N; + } + + void threadPoolFunction(std::pair* pair) + { + TaskScheduler::ThreadPool* pool = pair->first; + size_t threadIndex = pair->second; + delete pair; + pool->thread_loop(threadIndex); + } + + TaskScheduler::ThreadPool::ThreadPool(bool set_affinity) + : numThreads(0), numThreadsRunning(0), set_affinity(set_affinity), running(false) {} + + dll_export void TaskScheduler::ThreadPool::startThreads() + { + if (running) return; + setNumThreads(numThreads,true); + } + + void TaskScheduler::ThreadPool::setNumThreads(size_t newNumThreads, bool startThreads) + { + Lock lock(g_mutex); + assert(newNumThreads); + newNumThreads = min(newNumThreads, (size_t) getNumberOfLogicalThreads()); + + numThreads = newNumThreads; + if (!startThreads && !running) return; + running = true; + size_t numThreadsActive = numThreadsRunning; + + mutex.lock(); + numThreadsRunning = newNumThreads; + mutex.unlock(); + condition.notify_all(); + + /* start new threads */ + for (size_t t=numThreadsActive; t(this,t); + threads.push_back(createThread((thread_func)threadPoolFunction,pair,4*1024*1024,set_affinity ? 
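+                                                             /* a non-negative thread index pins the new worker to that logical thread; -1 leaves it unpinned */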
t : -1)); + } + + /* stop some threads if we reduce the number of threads */ + for (ssize_t t=numThreadsActive-1; t>=ssize_t(numThreadsRunning); t--) { + if (t == 0) continue; + embree::join(threads.back()); + threads.pop_back(); + } + } + + TaskScheduler::ThreadPool::~ThreadPool() + { + /* leave all taskschedulers */ + mutex.lock(); + numThreadsRunning = 0; + mutex.unlock(); + condition.notify_all(); + + /* wait for threads to terminate */ + for (size_t i=0; i& scheduler) + { + mutex.lock(); + schedulers.push_back(scheduler); + mutex.unlock(); + condition.notify_all(); + } + + dll_export void TaskScheduler::ThreadPool::remove(const Ref& scheduler) + { + Lock lock(mutex); + for (std::list >::iterator it = schedulers.begin(); it != schedulers.end(); it++) { + if (scheduler == *it) { + schedulers.erase(it); + return; + } + } + } + + void TaskScheduler::ThreadPool::thread_loop(size_t globalThreadIndex) + { + while (globalThreadIndex < numThreadsRunning) + { + Ref scheduler = NULL; + ssize_t threadIndex = -1; + { + Lock lock(mutex); + condition.wait(mutex, [&] () { return globalThreadIndex >= numThreadsRunning || !schedulers.empty(); }); + if (globalThreadIndex >= numThreadsRunning) break; + scheduler = schedulers.front(); + threadIndex = scheduler->allocThreadIndex(); + } + scheduler->thread_loop(threadIndex); + } + } + + TaskScheduler::TaskScheduler() + : threadCounter(0), anyTasksRunning(0), hasRootTask(false) + { + threadLocal.resize(2*getNumberOfLogicalThreads()); // FIXME: this has to be 2x as in the compatibility join mode with rtcCommitScene the worker threads also join. When disallowing rtcCommitScene to join a build we can remove the 2x. + for (size_t i=0; ithreadIndex; + else return 0; + } + + dll_export size_t TaskScheduler::threadIndex() + { + Thread* thread = TaskScheduler::thread(); + if (thread) return thread->threadIndex; + else return 0; + } + + dll_export size_t TaskScheduler::threadCount() { + return threadPool->size(); + } + + dll_export TaskScheduler* TaskScheduler::instance() + { + if (g_instance == NULL) { + Lock lock(g_mutex); + g_instance = new TaskScheduler; + g_instance_vector.push_back(g_instance); + } + return g_instance; + } + + void TaskScheduler::create(size_t numThreads, bool set_affinity, bool start_threads) + { + if (!threadPool) threadPool = new TaskScheduler::ThreadPool(set_affinity); + threadPool->setNumThreads(numThreads,start_threads); + } + + void TaskScheduler::destroy() { + delete threadPool; threadPool = nullptr; + } + + dll_export ssize_t TaskScheduler::allocThreadIndex() + { + size_t threadIndex = threadCounter++; + assert(threadIndex < threadLocal.size()); + return threadIndex; + } + + void TaskScheduler::join() + { + mutex.lock(); + size_t threadIndex = allocThreadIndex(); + condition.wait(mutex, [&] () { return hasRootTask.load(); }); + mutex.unlock(); + std::exception_ptr except = thread_loop(threadIndex); + if (except != nullptr) std::rethrow_exception(except); + } + + void TaskScheduler::reset() { + hasRootTask = false; + } + + void TaskScheduler::wait_for_threads(size_t threadCount) + { + while (threadCounter < threadCount-1) + pause_cpu(); + } + + dll_export TaskScheduler::Thread* TaskScheduler::thread() { + return thread_local_thread; + } + + dll_export TaskScheduler::Thread* TaskScheduler::swapThread(Thread* thread) + { + Thread* old = thread_local_thread; + thread_local_thread = thread; + return old; + } + + dll_export bool TaskScheduler::wait() + { + Thread* thread = TaskScheduler::thread(); + if (thread == nullptr) return true; + 
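+    /* help with task execution: drain this thread's local queue until the task being waited on is reached or the queue runs empty */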
while (thread->tasks.execute_local_internal(*thread,thread->task)) {}; + return thread->scheduler->cancellingException == nullptr; + } + + std::exception_ptr TaskScheduler::thread_loop(size_t threadIndex) + { + /* allocate thread structure */ + std::unique_ptr mthread(new Thread(threadIndex,this)); // too large for stack allocation + Thread& thread = *mthread; + threadLocal[threadIndex].store(&thread); + Thread* oldThread = swapThread(&thread); + + /* main thread loop */ + while (anyTasksRunning) + { + steal_loop(thread, + [&] () { return anyTasksRunning > 0; }, + [&] () { + anyTasksRunning++; + while (thread.tasks.execute_local_internal(thread,nullptr)); + anyTasksRunning--; + }); + } + threadLocal[threadIndex].store(nullptr); + swapThread(oldThread); + + /* remember exception to throw */ + std::exception_ptr except = nullptr; + if (cancellingException != nullptr) except = cancellingException; + + /* wait for all threads to terminate */ + threadCounter--; +#if defined(__WIN32__) + size_t loopIndex = 1; +#endif +#define LOOP_YIELD_THRESHOLD (4096) + while (threadCounter > 0) { +#if defined(__WIN32__) + if ((loopIndex % LOOP_YIELD_THRESHOLD) == 0) + yield(); + else +// -- GODOT start -- +#if !defined(__MINGW32__) +// -- GODOT end -- + _mm_pause(); +// -- GODOT start -- +#else + usleep(1); +#endif +// -- GODOT end -- + loopIndex++; +#else + yield(); +#endif + } + return except; + } + + bool TaskScheduler::steal_from_other_threads(Thread& thread) + { + const size_t threadIndex = thread.threadIndex; + const size_t threadCount = this->threadCounter; + + for (size_t i=1; i= threadCount) otherThreadIndex -= threadCount; + + Thread* othread = threadLocal[otherThreadIndex].load(); + if (!othread) + continue; + + if (othread->tasks.steal(thread)) + return true; + } + + return false; + } + + dll_export void TaskScheduler::startThreads() { + threadPool->startThreads(); + } + + dll_export void TaskScheduler::addScheduler(const Ref& scheduler) { + threadPool->add(scheduler); + } + + dll_export void TaskScheduler::removeScheduler(const Ref& scheduler) { + threadPool->remove(scheduler); + } + + RTC_NAMESPACE_END +} diff --git a/thirdparty/embree/common/tasking/taskschedulerinternal.h b/thirdparty/embree/common/tasking/taskschedulerinternal.h new file mode 100644 index 000000000000..ef4d65f6fd3f --- /dev/null +++ b/thirdparty/embree/common/tasking/taskschedulerinternal.h @@ -0,0 +1,376 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../sys/alloc.h" +#include "../sys/barrier.h" +#include "../sys/thread.h" +#include "../sys/mutex.h" +#include "../sys/condition.h" +#include "../sys/ref.h" +#include "../sys/atomic.h" +#include "../math/range.h" +#include "../../include/embree3/rtcore.h" + +#include + +namespace embree +{ + + /* The tasking system exports some symbols to be used by the tutorials. Thus we + hide is also in the API namespace when requested. */ + RTC_NAMESPACE_BEGIN + + struct TaskScheduler : public RefCount + { + ALIGNED_STRUCT_(64); + friend class Device; + + static const size_t TASK_STACK_SIZE = 4*1024; //!< task structure stack + static const size_t CLOSURE_STACK_SIZE = 512*1024; //!< stack for task closures + + struct Thread; + + /*! virtual interface for all tasks */ + struct TaskFunction { + virtual void execute() = 0; + }; + + /*! 
builds a task interface from a closure */ + template + struct ClosureTaskFunction : public TaskFunction + { + Closure closure; + __forceinline ClosureTaskFunction (const Closure& closure) : closure(closure) {} + void execute() { closure(); }; + }; + + struct __aligned(64) Task + { + /*! states a task can be in */ + enum { DONE, INITIALIZED }; + + /*! switch from one state to another */ + __forceinline void switch_state(int from, int to) + { + __memory_barrier(); + MAYBE_UNUSED bool success = state.compare_exchange_strong(from,to); + assert(success); + } + + /*! try to switch from one state to another */ + __forceinline bool try_switch_state(int from, int to) { + __memory_barrier(); + return state.compare_exchange_strong(from,to); + } + + /*! increment/decrement dependency counter */ + void add_dependencies(int n) { + dependencies+=n; + } + + /*! initialize all tasks to DONE state by default */ + __forceinline Task() + : state(DONE) {} + + /*! construction of new task */ + __forceinline Task (TaskFunction* closure, Task* parent, size_t stackPtr, size_t N) + : dependencies(1), stealable(true), closure(closure), parent(parent), stackPtr(stackPtr), N(N) + { + if (parent) parent->add_dependencies(+1); + switch_state(DONE,INITIALIZED); + } + + /*! construction of stolen task, stealing thread will decrement initial dependency */ + __forceinline Task (TaskFunction* closure, Task* parent) + : dependencies(1), stealable(false), closure(closure), parent(parent), stackPtr(-1), N(1) + { + switch_state(DONE,INITIALIZED); + } + + /*! try to steal this task */ + bool try_steal(Task& child) + { + if (!stealable) return false; + if (!try_switch_state(INITIALIZED,DONE)) return false; + new (&child) Task(closure, this); + return true; + } + + /*! run this task */ + dll_export void run(Thread& thread); + + void run_internal(Thread& thread); + + public: + std::atomic state; //!< state this task is in + std::atomic dependencies; //!< dependencies to wait for + std::atomic stealable; //!< true if task can be stolen + TaskFunction* closure; //!< the closure to execute + Task* parent; //!< parent task to signal when we are finished + size_t stackPtr; //!< stack location where closure is stored + size_t N; //!< approximative size of task + }; + + struct TaskQueue + { + TaskQueue () + : left(0), right(0), stackPtr(0) {} + + __forceinline void* alloc(size_t bytes, size_t align = 64) + { + size_t ofs = bytes + ((align - stackPtr) & (align-1)); + if (stackPtr + ofs > CLOSURE_STACK_SIZE) + throw std::runtime_error("closure stack overflow"); + stackPtr += ofs; + return &stack[stackPtr-bytes]; + } + + template + __forceinline void push_right(Thread& thread, const size_t size, const Closure& closure) + { + if (right >= TASK_STACK_SIZE) + throw std::runtime_error("task stack overflow"); + + /* allocate new task on right side of stack */ + size_t oldStackPtr = stackPtr; + TaskFunction* func = new (alloc(sizeof(ClosureTaskFunction))) ClosureTaskFunction(closure); + new (&tasks[right]) Task(func,thread.task,oldStackPtr,size); + right++; + + /* also move left pointer */ + if (left >= right-1) left = right-1; + } + + dll_export bool execute_local(Thread& thread, Task* parent); + bool execute_local_internal(Thread& thread, Task* parent); + bool steal(Thread& thread); + size_t getTaskSizeAtLeft(); + + bool empty() { return right == 0; } + + public: + + /* task stack */ + Task tasks[TASK_STACK_SIZE]; + __aligned(64) std::atomic left; //!< threads steal from left + __aligned(64) std::atomic right; //!< new tasks are added to the 
right + + /* closure stack */ + __aligned(64) char stack[CLOSURE_STACK_SIZE]; + size_t stackPtr; + }; + + /*! thread local structure for each thread */ + struct Thread + { + ALIGNED_STRUCT_(64); + + Thread (size_t threadIndex, const Ref& scheduler) + : threadIndex(threadIndex), task(nullptr), scheduler(scheduler) {} + + __forceinline size_t threadCount() { + return scheduler->threadCounter; + } + + size_t threadIndex; //!< ID of this thread + TaskQueue tasks; //!< local task queue + Task* task; //!< current active task + Ref scheduler; //!< pointer to task scheduler + }; + + /*! pool of worker threads */ + struct ThreadPool + { + ThreadPool (bool set_affinity); + ~ThreadPool (); + + /*! starts the threads */ + dll_export void startThreads(); + + /*! sets number of threads to use */ + void setNumThreads(size_t numThreads, bool startThreads = false); + + /*! adds a task scheduler object for scheduling */ + dll_export void add(const Ref& scheduler); + + /*! remove the task scheduler object again */ + dll_export void remove(const Ref& scheduler); + + /*! returns number of threads of the thread pool */ + size_t size() const { return numThreads; } + + /*! main loop for all threads */ + void thread_loop(size_t threadIndex); + + private: + std::atomic numThreads; + std::atomic numThreadsRunning; + bool set_affinity; + std::atomic running; + std::vector threads; + + private: + MutexSys mutex; + ConditionSys condition; + std::list > schedulers; + }; + + TaskScheduler (); + ~TaskScheduler (); + + /*! initializes the task scheduler */ + static void create(size_t numThreads, bool set_affinity, bool start_threads); + + /*! destroys the task scheduler again */ + static void destroy(); + + /*! lets new worker threads join the tasking system */ + void join(); + void reset(); + + /*! let a worker thread allocate a thread index */ + dll_export ssize_t allocThreadIndex(); + + /*! wait for some number of threads available (threadCount includes main thread) */ + void wait_for_threads(size_t threadCount); + + /*! thread loop for all worker threads */ + std::exception_ptr thread_loop(size_t threadIndex); + + /*! 
steals a task from a different thread */ + bool steal_from_other_threads(Thread& thread); + + template + static void steal_loop(Thread& thread, const Predicate& pred, const Body& body); + + /* spawn a new task at the top of the threads task stack */ + template + void spawn_root(const Closure& closure, size_t size = 1, bool useThreadPool = true) + { + if (useThreadPool) startThreads(); + + size_t threadIndex = allocThreadIndex(); + std::unique_ptr mthread(new Thread(threadIndex,this)); // too large for stack allocation + Thread& thread = *mthread; + assert(threadLocal[threadIndex].load() == nullptr); + threadLocal[threadIndex] = &thread; + Thread* oldThread = swapThread(&thread); + thread.tasks.push_right(thread,size,closure); + { + Lock lock(mutex); + anyTasksRunning++; + hasRootTask = true; + condition.notify_all(); + } + + if (useThreadPool) addScheduler(this); + + while (thread.tasks.execute_local(thread,nullptr)); + anyTasksRunning--; + if (useThreadPool) removeScheduler(this); + + threadLocal[threadIndex] = nullptr; + swapThread(oldThread); + + /* remember exception to throw */ + std::exception_ptr except = nullptr; + if (cancellingException != nullptr) except = cancellingException; + + /* wait for all threads to terminate */ + threadCounter--; + while (threadCounter > 0) yield(); + cancellingException = nullptr; + + /* re-throw proper exception */ + if (except != nullptr) + std::rethrow_exception(except); + } + + /* spawn a new task at the top of the threads task stack */ + template + static __forceinline void spawn(size_t size, const Closure& closure) + { + Thread* thread = TaskScheduler::thread(); + if (likely(thread != nullptr)) thread->tasks.push_right(*thread,size,closure); + else instance()->spawn_root(closure,size); + } + + /* spawn a new task at the top of the threads task stack */ + template + static __forceinline void spawn(const Closure& closure) { + spawn(1,closure); + } + + /* spawn a new task set */ + template + static void spawn(const Index begin, const Index end, const Index blockSize, const Closure& closure) + { + spawn(end-begin, [=]() + { + if (end-begin <= blockSize) { + return closure(range(begin,end)); + } + const Index center = (begin+end)/2; + spawn(begin,center,blockSize,closure); + spawn(center,end ,blockSize,closure); + wait(); + }); + } + + /* work on spawned subtasks and wait until all have finished */ + dll_export static bool wait(); + + /* returns the ID of the current thread */ + dll_export static size_t threadID(); + + /* returns the index (0..threadCount-1) of the current thread */ + dll_export static size_t threadIndex(); + + /* returns the total number of threads */ + dll_export static size_t threadCount(); + + private: + + /* returns the thread local task list of this worker thread */ + dll_export static Thread* thread(); + + /* sets the thread local task list of this worker thread */ + dll_export static Thread* swapThread(Thread* thread); + + /*! returns the taskscheduler object to be used by the master thread */ + dll_export static TaskScheduler* instance(); + + /*! starts the threads */ + dll_export static void startThreads(); + + /*! adds a task scheduler object for scheduling */ + dll_export static void addScheduler(const Ref& scheduler); + + /*! 
remove the task scheduler object again */ + dll_export static void removeScheduler(const Ref& scheduler); + + private: + std::vector> threadLocal; + std::atomic threadCounter; + std::atomic anyTasksRunning; + std::atomic hasRootTask; + std::exception_ptr cancellingException; + MutexSys mutex; + ConditionSys condition; + + private: + static size_t g_numThreads; + static __thread TaskScheduler* g_instance; + static __thread Thread* thread_local_thread; + static ThreadPool* threadPool; + }; + + RTC_NAMESPACE_END + +#if defined(RTC_NAMESPACE) + using RTC_NAMESPACE::TaskScheduler; +#endif +} diff --git a/thirdparty/embree/common/tasking/taskschedulerppl.h b/thirdparty/embree/common/tasking/taskschedulerppl.h new file mode 100644 index 000000000000..776f98cdacd8 --- /dev/null +++ b/thirdparty/embree/common/tasking/taskschedulerppl.h @@ -0,0 +1,46 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../sys/alloc.h" +#include "../sys/barrier.h" +#include "../sys/thread.h" +#include "../sys/mutex.h" +#include "../sys/condition.h" +#include "../sys/ref.h" + +#if !defined(__WIN32__) +#error PPL tasking system only available under windows +#endif + +#include + +namespace embree +{ + struct TaskScheduler + { + /*! initializes the task scheduler */ + static void create(size_t numThreads, bool set_affinity, bool start_threads); + + /*! destroys the task scheduler again */ + static void destroy(); + + /* returns the ID of the current thread */ + static __forceinline size_t threadID() { + return GetCurrentThreadId(); + } + + /* returns the index (0..threadCount-1) of the current thread */ + /* FIXME: threadIndex is NOT supported by PPL! */ + static __forceinline size_t threadIndex() { + return 0; + } + + /* returns the total number of threads */ + static __forceinline size_t threadCount() { + return GetMaximumProcessorCount(ALL_PROCESSOR_GROUPS) + 1; + } + }; +}; diff --git a/thirdparty/embree/common/tasking/taskschedulertbb.h b/thirdparty/embree/common/tasking/taskschedulertbb.h new file mode 100644 index 000000000000..369e5edf076b --- /dev/null +++ b/thirdparty/embree/common/tasking/taskschedulertbb.h @@ -0,0 +1,73 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../sys/platform.h" +#include "../sys/alloc.h" +#include "../sys/barrier.h" +#include "../sys/thread.h" +#include "../sys/mutex.h" +#include "../sys/condition.h" +#include "../sys/ref.h" + +#if defined(__WIN32__) +// -- GODOT start -- +#if !defined(NOMINMAX) +// -- GODOT end -- +# define NOMINMAX +// -- GODOT start -- +#endif +// -- GODOT end -- +#endif + +// We need to define these to avoid implicit linkage against +// tbb_debug.lib under Windows. When removing these lines debug build +// under Windows fails. +#define __TBB_NO_IMPLICIT_LINKAGE 1 +#define __TBBMALLOC_NO_IMPLICIT_LINKAGE 1 +#define TBB_SUPPRESS_DEPRECATED_MESSAGES 1 +#define TBB_PREVIEW_ISOLATED_TASK_GROUP 1 +#include "tbb/tbb.h" +#include "tbb/parallel_sort.h" + +namespace embree +{ + struct TaskScheduler + { + /*! initializes the task scheduler */ + static void create(size_t numThreads, bool set_affinity, bool start_threads); + + /*! 
destroys the task scheduler again */ + static void destroy(); + + /* returns the ID of the current thread */ + static __forceinline size_t threadID() + { + return threadIndex(); + } + + /* returns the index (0..threadCount-1) of the current thread */ + static __forceinline size_t threadIndex() + { +#if TBB_INTERFACE_VERSION >= 9100 + return tbb::this_task_arena::current_thread_index(); +#elif TBB_INTERFACE_VERSION >= 9000 + return tbb::task_arena::current_thread_index(); +#else + return 0; +#endif + } + + /* returns the total number of threads */ + static __forceinline size_t threadCount() { +#if TBB_INTERFACE_VERSION >= 9100 + return tbb::this_task_arena::max_concurrency(); +#else + return tbb::task_scheduler_init::default_num_threads(); +#endif + } + + }; + +}; diff --git a/thirdparty/embree/include/embree3/rtcore.h b/thirdparty/embree/include/embree3/rtcore.h new file mode 100644 index 000000000000..5830bb58801b --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore.h @@ -0,0 +1,14 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_config.h" +#include "rtcore_common.h" +#include "rtcore_device.h" +#include "rtcore_buffer.h" +#include "rtcore_ray.h" +#include "rtcore_geometry.h" +#include "rtcore_scene.h" +#include "rtcore_builder.h" +#include "rtcore_quaternion.h" diff --git a/thirdparty/embree/include/embree3/rtcore_buffer.h b/thirdparty/embree/include/embree3/rtcore_buffer.h new file mode 100644 index 000000000000..400b604aa53f --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_buffer.h @@ -0,0 +1,51 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_device.h" + +RTC_NAMESPACE_BEGIN + +/* Types of buffers */ +enum RTCBufferType +{ + RTC_BUFFER_TYPE_INDEX = 0, + RTC_BUFFER_TYPE_VERTEX = 1, + RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE = 2, + RTC_BUFFER_TYPE_NORMAL = 3, + RTC_BUFFER_TYPE_TANGENT = 4, + RTC_BUFFER_TYPE_NORMAL_DERIVATIVE = 5, + + RTC_BUFFER_TYPE_GRID = 8, + + RTC_BUFFER_TYPE_FACE = 16, + RTC_BUFFER_TYPE_LEVEL = 17, + RTC_BUFFER_TYPE_EDGE_CREASE_INDEX = 18, + RTC_BUFFER_TYPE_EDGE_CREASE_WEIGHT = 19, + RTC_BUFFER_TYPE_VERTEX_CREASE_INDEX = 20, + RTC_BUFFER_TYPE_VERTEX_CREASE_WEIGHT = 21, + RTC_BUFFER_TYPE_HOLE = 22, + + RTC_BUFFER_TYPE_FLAGS = 32 +}; + +/* Opaque buffer type */ +typedef struct RTCBufferTy* RTCBuffer; + +/* Creates a new buffer. */ +RTC_API RTCBuffer rtcNewBuffer(RTCDevice device, size_t byteSize); + +/* Creates a new shared buffer. */ +RTC_API RTCBuffer rtcNewSharedBuffer(RTCDevice device, void* ptr, size_t byteSize); + +/* Returns a pointer to the buffer data. */ +RTC_API void* rtcGetBufferData(RTCBuffer buffer); + +/* Retains the buffer (increments the reference count). */ +RTC_API void rtcRetainBuffer(RTCBuffer buffer); + +/* Releases the buffer (decrements the reference count). 
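+   A minimal usage sketch of this buffer API (illustrative only; assumes an already created RTCDevice named device):
+     RTCBuffer buf = rtcNewBuffer(device, 1024 * sizeof(float));
+     float* data = (float*) rtcGetBufferData(buf);
+     ... fill data and attach the buffer to a geometry ...
+     rtcReleaseBuffer(buf);
+   For buffers created with rtcNewBuffer, the memory is freed once the reference count drops to zero.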
*/ +RTC_API void rtcReleaseBuffer(RTCBuffer buffer); + +RTC_NAMESPACE_END diff --git a/thirdparty/embree/include/embree3/rtcore_builder.h b/thirdparty/embree/include/embree3/rtcore_builder.h new file mode 100644 index 000000000000..d62a7f72cc26 --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_builder.h @@ -0,0 +1,125 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_scene.h" + +RTC_NAMESPACE_BEGIN + +/* Opaque BVH type */ +typedef struct RTCBVHTy* RTCBVH; + +/* Input build primitives for the builder */ +struct RTC_ALIGN(32) RTCBuildPrimitive +{ + float lower_x, lower_y, lower_z; + unsigned int geomID; + float upper_x, upper_y, upper_z; + unsigned int primID; +}; + +/* Opaque thread local allocator type */ +typedef struct RTCThreadLocalAllocatorTy* RTCThreadLocalAllocator; + +/* Callback to create a node */ +typedef void* (*RTCCreateNodeFunction) (RTCThreadLocalAllocator allocator, unsigned int childCount, void* userPtr); + +/* Callback to set the pointer to all children */ +typedef void (*RTCSetNodeChildrenFunction) (void* nodePtr, void** children, unsigned int childCount, void* userPtr); + +/* Callback to set the bounds of all children */ +typedef void (*RTCSetNodeBoundsFunction) (void* nodePtr, const struct RTCBounds** bounds, unsigned int childCount, void* userPtr); + +/* Callback to create a leaf node */ +typedef void* (*RTCCreateLeafFunction) (RTCThreadLocalAllocator allocator, const struct RTCBuildPrimitive* primitives, size_t primitiveCount, void* userPtr); + +/* Callback to split a build primitive */ +typedef void (*RTCSplitPrimitiveFunction) (const struct RTCBuildPrimitive* primitive, unsigned int dimension, float position, struct RTCBounds* leftBounds, struct RTCBounds* rightBounds, void* userPtr); + +/* Build flags */ +enum RTCBuildFlags +{ + RTC_BUILD_FLAG_NONE = 0, + RTC_BUILD_FLAG_DYNAMIC = (1 << 0), +}; + +enum RTCBuildConstants +{ + RTC_BUILD_MAX_PRIMITIVES_PER_LEAF = 32 +}; + +/* Input for builders */ +struct RTCBuildArguments +{ + size_t byteSize; + + enum RTCBuildQuality buildQuality; + enum RTCBuildFlags buildFlags; + unsigned int maxBranchingFactor; + unsigned int maxDepth; + unsigned int sahBlockSize; + unsigned int minLeafSize; + unsigned int maxLeafSize; + float traversalCost; + float intersectionCost; + + RTCBVH bvh; + struct RTCBuildPrimitive* primitives; + size_t primitiveCount; + size_t primitiveArrayCapacity; + + RTCCreateNodeFunction createNode; + RTCSetNodeChildrenFunction setNodeChildren; + RTCSetNodeBoundsFunction setNodeBounds; + RTCCreateLeafFunction createLeaf; + RTCSplitPrimitiveFunction splitPrimitive; + RTCProgressMonitorFunction buildProgress; + void* userPtr; +}; + +/* Returns the default build settings. 
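+   A typical setup sketch (illustrative only; device, prims, numPrims and the my* callbacks are placeholders, not part of this header):
+     struct RTCBuildArguments args = rtcDefaultBuildArguments();
+     args.bvh             = rtcNewBVH(device);
+     args.primitives      = prims;
+     args.primitiveCount  = numPrims;
+     args.createNode      = myCreateNode;
+     args.setNodeChildren = mySetNodeChildren;
+     args.setNodeBounds   = mySetNodeBounds;
+     args.createLeaf      = myCreateLeaf;
+     void* root = rtcBuildBVH(&args);
+     ... traverse the returned root, then rtcReleaseBVH(args.bvh) ...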
*/ +RTC_FORCEINLINE struct RTCBuildArguments rtcDefaultBuildArguments() +{ + struct RTCBuildArguments args; + args.byteSize = sizeof(args); + args.buildQuality = RTC_BUILD_QUALITY_MEDIUM; + args.buildFlags = RTC_BUILD_FLAG_NONE; + args.maxBranchingFactor = 2; + args.maxDepth = 32; + args.sahBlockSize = 1; + args.minLeafSize = 1; + args.maxLeafSize = RTC_BUILD_MAX_PRIMITIVES_PER_LEAF; + args.traversalCost = 1.0f; + args.intersectionCost = 1.0f; + args.bvh = NULL; + args.primitives = NULL; + args.primitiveCount = 0; + args.primitiveArrayCapacity = 0; + args.createNode = NULL; + args.setNodeChildren = NULL; + args.setNodeBounds = NULL; + args.createLeaf = NULL; + args.splitPrimitive = NULL; + args.buildProgress = NULL; + args.userPtr = NULL; + return args; +} + +/* Creates a new BVH. */ +RTC_API RTCBVH rtcNewBVH(RTCDevice device); + +/* Builds a BVH. */ +RTC_API void* rtcBuildBVH(const struct RTCBuildArguments* args); + +/* Allocates memory using the thread local allocator. */ +RTC_API void* rtcThreadLocalAlloc(RTCThreadLocalAllocator allocator, size_t bytes, size_t align); + +/* Retains the BVH (increments reference count). */ +RTC_API void rtcRetainBVH(RTCBVH bvh); + +/* Releases the BVH (decrements reference count). */ +RTC_API void rtcReleaseBVH(RTCBVH bvh); + +RTC_NAMESPACE_END + diff --git a/thirdparty/embree/include/embree3/rtcore_common.h b/thirdparty/embree/include/embree3/rtcore_common.h new file mode 100644 index 000000000000..a516f6bdf1f8 --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_common.h @@ -0,0 +1,326 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "rtcore_config.h" + +RTC_NAMESPACE_BEGIN + +#if defined(_WIN32) +#if defined(_M_X64) +typedef long long ssize_t; +#else +typedef int ssize_t; +#endif +#endif + +#ifdef _WIN32 +# define RTC_ALIGN(...) __declspec(align(__VA_ARGS__)) +#else +# define RTC_ALIGN(...) 
__attribute__((aligned(__VA_ARGS__))) +#endif + +#if !defined (RTC_DEPRECATED) +#ifdef __GNUC__ + #define RTC_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) + #define RTC_DEPRECATED __declspec(deprecated) +#else + #define RTC_DEPRECATED +#endif +#endif + +#if defined(_WIN32) +# define RTC_FORCEINLINE __forceinline +#else +# define RTC_FORCEINLINE inline __attribute__((always_inline)) +#endif + +/* Invalid geometry ID */ +#define RTC_INVALID_GEOMETRY_ID ((unsigned int)-1) + +/* Maximum number of time steps */ +#define RTC_MAX_TIME_STEP_COUNT 129 + +/* Formats of buffers and other data structures */ +enum RTCFormat +{ + RTC_FORMAT_UNDEFINED = 0, + + /* 8-bit unsigned integer */ + RTC_FORMAT_UCHAR = 0x1001, + RTC_FORMAT_UCHAR2, + RTC_FORMAT_UCHAR3, + RTC_FORMAT_UCHAR4, + + /* 8-bit signed integer */ + RTC_FORMAT_CHAR = 0x2001, + RTC_FORMAT_CHAR2, + RTC_FORMAT_CHAR3, + RTC_FORMAT_CHAR4, + + /* 16-bit unsigned integer */ + RTC_FORMAT_USHORT = 0x3001, + RTC_FORMAT_USHORT2, + RTC_FORMAT_USHORT3, + RTC_FORMAT_USHORT4, + + /* 16-bit signed integer */ + RTC_FORMAT_SHORT = 0x4001, + RTC_FORMAT_SHORT2, + RTC_FORMAT_SHORT3, + RTC_FORMAT_SHORT4, + + /* 32-bit unsigned integer */ + RTC_FORMAT_UINT = 0x5001, + RTC_FORMAT_UINT2, + RTC_FORMAT_UINT3, + RTC_FORMAT_UINT4, + + /* 32-bit signed integer */ + RTC_FORMAT_INT = 0x6001, + RTC_FORMAT_INT2, + RTC_FORMAT_INT3, + RTC_FORMAT_INT4, + + /* 64-bit unsigned integer */ + RTC_FORMAT_ULLONG = 0x7001, + RTC_FORMAT_ULLONG2, + RTC_FORMAT_ULLONG3, + RTC_FORMAT_ULLONG4, + + /* 64-bit signed integer */ + RTC_FORMAT_LLONG = 0x8001, + RTC_FORMAT_LLONG2, + RTC_FORMAT_LLONG3, + RTC_FORMAT_LLONG4, + + /* 32-bit float */ + RTC_FORMAT_FLOAT = 0x9001, + RTC_FORMAT_FLOAT2, + RTC_FORMAT_FLOAT3, + RTC_FORMAT_FLOAT4, + RTC_FORMAT_FLOAT5, + RTC_FORMAT_FLOAT6, + RTC_FORMAT_FLOAT7, + RTC_FORMAT_FLOAT8, + RTC_FORMAT_FLOAT9, + RTC_FORMAT_FLOAT10, + RTC_FORMAT_FLOAT11, + RTC_FORMAT_FLOAT12, + RTC_FORMAT_FLOAT13, + RTC_FORMAT_FLOAT14, + RTC_FORMAT_FLOAT15, + RTC_FORMAT_FLOAT16, + + /* 32-bit float matrix (row-major order) */ + RTC_FORMAT_FLOAT2X2_ROW_MAJOR = 0x9122, + RTC_FORMAT_FLOAT2X3_ROW_MAJOR = 0x9123, + RTC_FORMAT_FLOAT2X4_ROW_MAJOR = 0x9124, + RTC_FORMAT_FLOAT3X2_ROW_MAJOR = 0x9132, + RTC_FORMAT_FLOAT3X3_ROW_MAJOR = 0x9133, + RTC_FORMAT_FLOAT3X4_ROW_MAJOR = 0x9134, + RTC_FORMAT_FLOAT4X2_ROW_MAJOR = 0x9142, + RTC_FORMAT_FLOAT4X3_ROW_MAJOR = 0x9143, + RTC_FORMAT_FLOAT4X4_ROW_MAJOR = 0x9144, + + /* 32-bit float matrix (column-major order) */ + RTC_FORMAT_FLOAT2X2_COLUMN_MAJOR = 0x9222, + RTC_FORMAT_FLOAT2X3_COLUMN_MAJOR = 0x9223, + RTC_FORMAT_FLOAT2X4_COLUMN_MAJOR = 0x9224, + RTC_FORMAT_FLOAT3X2_COLUMN_MAJOR = 0x9232, + RTC_FORMAT_FLOAT3X3_COLUMN_MAJOR = 0x9233, + RTC_FORMAT_FLOAT3X4_COLUMN_MAJOR = 0x9234, + RTC_FORMAT_FLOAT4X2_COLUMN_MAJOR = 0x9242, + RTC_FORMAT_FLOAT4X3_COLUMN_MAJOR = 0x9243, + RTC_FORMAT_FLOAT4X4_COLUMN_MAJOR = 0x9244, + + /* special 12-byte format for grids */ + RTC_FORMAT_GRID = 0xA001 +}; + +/* Build quality levels */ +enum RTCBuildQuality +{ + RTC_BUILD_QUALITY_LOW = 0, + RTC_BUILD_QUALITY_MEDIUM = 1, + RTC_BUILD_QUALITY_HIGH = 2, + RTC_BUILD_QUALITY_REFIT = 3, +}; + +/* Axis-aligned bounding box representation */ +struct RTC_ALIGN(16) RTCBounds +{ + float lower_x, lower_y, lower_z, align0; + float upper_x, upper_y, upper_z, align1; +}; + +/* Linear axis-aligned bounding box representation */ +struct RTC_ALIGN(16) RTCLinearBounds +{ + struct RTCBounds bounds0; + struct RTCBounds bounds1; +}; + +/* Intersection context flags */ +enum 
RTCIntersectContextFlags +{ + RTC_INTERSECT_CONTEXT_FLAG_NONE = 0, + RTC_INTERSECT_CONTEXT_FLAG_INCOHERENT = (0 << 0), // optimize for incoherent rays + RTC_INTERSECT_CONTEXT_FLAG_COHERENT = (1 << 0) // optimize for coherent rays +}; + +/* Arguments for RTCFilterFunctionN */ +struct RTCFilterFunctionNArguments +{ + int* valid; + void* geometryUserPtr; + struct RTCIntersectContext* context; + struct RTCRayN* ray; + struct RTCHitN* hit; + unsigned int N; +}; + +/* Filter callback function */ +typedef void (*RTCFilterFunctionN)(const struct RTCFilterFunctionNArguments* args); + +/* Intersection context passed to intersect/occluded calls */ +struct RTCIntersectContext +{ + enum RTCIntersectContextFlags flags; // intersection flags + RTCFilterFunctionN filter; // filter function to execute + +#if RTC_MAX_INSTANCE_LEVEL_COUNT > 1 + unsigned int instStackSize; // Number of instances currently on the stack. +#endif + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // The current stack of instance ids. + +#if RTC_MIN_WIDTH + float minWidthDistanceFactor; // curve radius is set to this factor times distance to ray origin +#endif +}; + +/* Initializes an intersection context. */ +RTC_FORCEINLINE void rtcInitIntersectContext(struct RTCIntersectContext* context) +{ + unsigned l = 0; + context->flags = RTC_INTERSECT_CONTEXT_FLAG_INCOHERENT; + context->filter = NULL; + +#if RTC_MAX_INSTANCE_LEVEL_COUNT > 1 + context->instStackSize = 0; +#endif + for (; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + context->instID[l] = RTC_INVALID_GEOMETRY_ID; + +#if RTC_MIN_WIDTH + context->minWidthDistanceFactor = 0.0f; +#endif +} + +/* Point query structure for closest point query */ +struct RTC_ALIGN(16) RTCPointQuery +{ + float x; // x coordinate of the query point + float y; // y coordinate of the query point + float z; // z coordinate of the query point + float time; // time of the point query + float radius; // radius of the point query +}; + +/* Structure of a packet of 4 query points */ +struct RTC_ALIGN(16) RTCPointQuery4 +{ + float x[4]; // x coordinate of the query point + float y[4]; // y coordinate of the query point + float z[4]; // z coordinate of the query point + float time[4]; // time of the point query + float radius[4]; // radius of the point query +}; + +/* Structure of a packet of 8 query points */ +struct RTC_ALIGN(32) RTCPointQuery8 +{ + float x[8]; // x coordinate of the query point + float y[8]; // y coordinate of the query point + float z[8]; // z coordinate of the query point + float time[8]; // time of the point query + float radius[8]; // radius of the point query +}; + +/* Structure of a packet of 16 query points */ +struct RTC_ALIGN(64) RTCPointQuery16 +{ + float x[16]; // x coordinate of the query point + float y[16]; // y coordinate of the query point + float z[16]; // z coordinate of the query point + float time[16]; // time of the point query + float radius[16]; // radius of the point query +}; + +struct RTCPointQueryN; + +struct RTC_ALIGN(16) RTCPointQueryContext +{ + // accumulated 4x4 column major matrices from world space to instance space. + // undefined if size == 0. + float world2inst[RTC_MAX_INSTANCE_LEVEL_COUNT][16]; + + // accumulated 4x4 column major matrices from instance space to world space. + // undefined if size == 0. + float inst2world[RTC_MAX_INSTANCE_LEVEL_COUNT][16]; + + // instance ids. + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; + + // number of instances currently on the stack. + unsigned int instStackSize; +}; + +/* Initializes a point query context. 
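As a minimal usage sketch (the callback name rejectAllHits is hypothetical and not part of this header), a callback matching the RTCFilterFunctionN signature above rejects a hit by clearing the corresponding valid lane; the same callback type can be installed per geometry via rtcSetGeometryIntersectFilterFunction (declared in rtcore_geometry.h) or per call through the RTCIntersectContext filter field:

    static void rejectAllHits(const struct RTCFilterFunctionNArguments* args)
    {
      /* valid[i] == -1 marks an active lane; writing 0 rejects the hit in that lane */
      for (unsigned int i = 0; i < args->N; i++) {
        if (args->valid[i] == -1)
          args->valid[i] = 0;
      }
    }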
*/ +RTC_FORCEINLINE void rtcInitPointQueryContext(struct RTCPointQueryContext* context) +{ + context->instStackSize = 0; + context->instID[0] = RTC_INVALID_GEOMETRY_ID; +} + +struct RTC_ALIGN(16) RTCPointQueryFunctionArguments +{ + // The (world space) query object that was passed as an argument of rtcPointQuery. The + // radius of the query can be decreased inside the callback to shrink the + // search domain. Increasing the radius or modifying the time or position of + // the query results in undefined behaviour. + struct RTCPointQuery* query; + + // Used for user input/output data. Will not be read or modified internally. + void* userPtr; + + // primitive and geometry ID of primitive + unsigned int primID; + unsigned int geomID; + + // the context with transformation and instance ID stack + struct RTCPointQueryContext* context; + + // If the current instance transform M (= context->world2inst[context->instStackSize]) + // is a similarity matrix, i.e there is a constant factor similarityScale such that, + // for all x,y: dist(Mx, My) = similarityScale * dist(x, y), + // The similarity scale is 0, if the current instance transform is not a + // similarity transform and vice versa. The similarity scale allows to compute + // distance information in instance space and scale the distances into world + // space by dividing with the similarity scale, for example, to update the + // query radius. If the current instance transform is not a similarity + // transform (similarityScale = 0), the distance computation has to be + // performed in world space to ensure correctness. if there is no instance + // transform (context->instStackSize == 0), the similarity scale is 1. + float similarityScale; +}; + +typedef bool (*RTCPointQueryFunction)(struct RTCPointQueryFunctionArguments* args); + +RTC_NAMESPACE_END diff --git a/thirdparty/embree/include/embree3/rtcore_config.h b/thirdparty/embree/include/embree3/rtcore_config.h new file mode 100644 index 000000000000..337d4e948721 --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_config.h @@ -0,0 +1,57 @@ + +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#define RTC_VERSION_MAJOR 3 +#define RTC_VERSION_MINOR 12 +#define RTC_VERSION_PATCH 1 +#define RTC_VERSION 31201 +#define RTC_VERSION_STRING "3.12.1" + +#define RTC_MAX_INSTANCE_LEVEL_COUNT 1 + +#define EMBREE_MIN_WIDTH 0 +#define RTC_MIN_WIDTH EMBREE_MIN_WIDTH + +#define EMBREE_STATIC_LIB +/* #undef EMBREE_API_NAMESPACE */ + +#if defined(EMBREE_API_NAMESPACE) +# define RTC_NAMESPACE +# define RTC_NAMESPACE_BEGIN namespace { +# define RTC_NAMESPACE_END } +# define RTC_NAMESPACE_USE using namespace ; +# define RTC_API_EXTERN_C +# undef EMBREE_API_NAMESPACE +#else +# define RTC_NAMESPACE_BEGIN +# define RTC_NAMESPACE_END +# define RTC_NAMESPACE_USE +# if defined(__cplusplus) +# define RTC_API_EXTERN_C extern "C" +# else +# define RTC_API_EXTERN_C +# endif +#endif + +#if defined(ISPC) +# define RTC_API_IMPORT extern "C" unmasked +# define RTC_API_EXPORT extern "C" unmasked +#elif defined(EMBREE_STATIC_LIB) +# define RTC_API_IMPORT RTC_API_EXTERN_C +# define RTC_API_EXPORT RTC_API_EXTERN_C +#elif defined(_WIN32) +# define RTC_API_IMPORT RTC_API_EXTERN_C __declspec(dllimport) +# define RTC_API_EXPORT RTC_API_EXTERN_C __declspec(dllexport) +#else +# define RTC_API_IMPORT RTC_API_EXTERN_C +# define RTC_API_EXPORT RTC_API_EXTERN_C __attribute__ ((visibility ("default"))) +#endif + +#if defined(RTC_EXPORT_API) +# define RTC_API RTC_API_EXPORT +#else +# 
define RTC_API RTC_API_IMPORT +#endif diff --git a/thirdparty/embree/include/embree3/rtcore_device.h b/thirdparty/embree/include/embree3/rtcore_device.h new file mode 100644 index 000000000000..594e2b755dcd --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_device.h @@ -0,0 +1,87 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_common.h" + +RTC_NAMESPACE_BEGIN + +/* Opaque device type */ +typedef struct RTCDeviceTy* RTCDevice; + +/* Creates a new Embree device. */ +RTC_API RTCDevice rtcNewDevice(const char* config); + +/* Retains the Embree device (increments the reference count). */ +RTC_API void rtcRetainDevice(RTCDevice device); + +/* Releases an Embree device (decrements the reference count). */ +RTC_API void rtcReleaseDevice(RTCDevice device); + +/* Device properties */ +enum RTCDeviceProperty +{ + RTC_DEVICE_PROPERTY_VERSION = 0, + RTC_DEVICE_PROPERTY_VERSION_MAJOR = 1, + RTC_DEVICE_PROPERTY_VERSION_MINOR = 2, + RTC_DEVICE_PROPERTY_VERSION_PATCH = 3, + + RTC_DEVICE_PROPERTY_NATIVE_RAY4_SUPPORTED = 32, + RTC_DEVICE_PROPERTY_NATIVE_RAY8_SUPPORTED = 33, + RTC_DEVICE_PROPERTY_NATIVE_RAY16_SUPPORTED = 34, + RTC_DEVICE_PROPERTY_RAY_STREAM_SUPPORTED = 35, + + RTC_DEVICE_PROPERTY_BACKFACE_CULLING_CURVES_ENABLED = 63, + RTC_DEVICE_PROPERTY_RAY_MASK_SUPPORTED = 64, + RTC_DEVICE_PROPERTY_BACKFACE_CULLING_ENABLED = 65, + RTC_DEVICE_PROPERTY_FILTER_FUNCTION_SUPPORTED = 66, + RTC_DEVICE_PROPERTY_IGNORE_INVALID_RAYS_ENABLED = 67, + RTC_DEVICE_PROPERTY_COMPACT_POLYS_ENABLED = 68, + + RTC_DEVICE_PROPERTY_TRIANGLE_GEOMETRY_SUPPORTED = 96, + RTC_DEVICE_PROPERTY_QUAD_GEOMETRY_SUPPORTED = 97, + RTC_DEVICE_PROPERTY_SUBDIVISION_GEOMETRY_SUPPORTED = 98, + RTC_DEVICE_PROPERTY_CURVE_GEOMETRY_SUPPORTED = 99, + RTC_DEVICE_PROPERTY_USER_GEOMETRY_SUPPORTED = 100, + RTC_DEVICE_PROPERTY_POINT_GEOMETRY_SUPPORTED = 101, + + RTC_DEVICE_PROPERTY_TASKING_SYSTEM = 128, + RTC_DEVICE_PROPERTY_JOIN_COMMIT_SUPPORTED = 129, + RTC_DEVICE_PROPERTY_PARALLEL_COMMIT_SUPPORTED = 130 +}; + +/* Gets a device property. */ +RTC_API ssize_t rtcGetDeviceProperty(RTCDevice device, enum RTCDeviceProperty prop); + +/* Sets a device property. */ +RTC_API void rtcSetDeviceProperty(RTCDevice device, const enum RTCDeviceProperty prop, ssize_t value); + +/* Error codes */ +enum RTCError +{ + RTC_ERROR_NONE = 0, + RTC_ERROR_UNKNOWN = 1, + RTC_ERROR_INVALID_ARGUMENT = 2, + RTC_ERROR_INVALID_OPERATION = 3, + RTC_ERROR_OUT_OF_MEMORY = 4, + RTC_ERROR_UNSUPPORTED_CPU = 5, + RTC_ERROR_CANCELLED = 6 +}; + +/* Returns the error code. */ +RTC_API enum RTCError rtcGetDeviceError(RTCDevice device); + +/* Error callback function */ +typedef void (*RTCErrorFunction)(void* userPtr, enum RTCError code, const char* str); + +/* Sets the error callback function. */ +RTC_API void rtcSetDeviceErrorFunction(RTCDevice device, RTCErrorFunction error, void* userPtr); + +/* Memory monitor callback function */ +typedef bool (*RTCMemoryMonitorFunction)(void* ptr, ssize_t bytes, bool post); + +/* Sets the memory monitor callback function. 
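A minimal sketch of how the device API above is typically used; the handler name onEmbreeError and the surrounding variables are hypothetical, and passing NULL to rtcNewDevice selects the default configuration:

    static void onEmbreeError(void* userPtr, enum RTCError code, const char* str)
    {
      (void)userPtr;
      fprintf(stderr, "Embree error %d: %s\n", (int)code, str); /* fprintf needs <stdio.h> */
    }

    RTCDevice device = rtcNewDevice(NULL);
    rtcSetDeviceErrorFunction(device, onEmbreeError, NULL);
    ssize_t version = rtcGetDeviceProperty(device, RTC_DEVICE_PROPERTY_VERSION); /* e.g. 31201 */
    /* ... create scenes and geometries here ... */
    rtcReleaseDevice(device);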
*/ +RTC_API void rtcSetDeviceMemoryMonitorFunction(RTCDevice device, RTCMemoryMonitorFunction memoryMonitor, void* userPtr); + +RTC_NAMESPACE_END diff --git a/thirdparty/embree/include/embree3/rtcore_geometry.h b/thirdparty/embree/include/embree3/rtcore_geometry.h new file mode 100644 index 000000000000..c70f1b0e5cb3 --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_geometry.h @@ -0,0 +1,383 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_buffer.h" +#include "rtcore_quaternion.h" + +RTC_NAMESPACE_BEGIN + +/* Opaque scene type */ +typedef struct RTCSceneTy* RTCScene; + +/* Opaque geometry type */ +typedef struct RTCGeometryTy* RTCGeometry; + +/* Types of geometries */ +enum RTCGeometryType +{ + RTC_GEOMETRY_TYPE_TRIANGLE = 0, // triangle mesh + RTC_GEOMETRY_TYPE_QUAD = 1, // quad (triangle pair) mesh + RTC_GEOMETRY_TYPE_GRID = 2, // grid mesh + + RTC_GEOMETRY_TYPE_SUBDIVISION = 8, // Catmull-Clark subdivision surface + + RTC_GEOMETRY_TYPE_CONE_LINEAR_CURVE = 15, // Cone linear curves - discontinuous at edge boundaries + RTC_GEOMETRY_TYPE_ROUND_LINEAR_CURVE = 16, // Round (rounded cone like) linear curves + RTC_GEOMETRY_TYPE_FLAT_LINEAR_CURVE = 17, // flat (ribbon-like) linear curves + + RTC_GEOMETRY_TYPE_ROUND_BEZIER_CURVE = 24, // round (tube-like) Bezier curves + RTC_GEOMETRY_TYPE_FLAT_BEZIER_CURVE = 25, // flat (ribbon-like) Bezier curves + RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_BEZIER_CURVE = 26, // flat normal-oriented Bezier curves + + RTC_GEOMETRY_TYPE_ROUND_BSPLINE_CURVE = 32, // round (tube-like) B-spline curves + RTC_GEOMETRY_TYPE_FLAT_BSPLINE_CURVE = 33, // flat (ribbon-like) B-spline curves + RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_BSPLINE_CURVE = 34, // flat normal-oriented B-spline curves + + RTC_GEOMETRY_TYPE_ROUND_HERMITE_CURVE = 40, // round (tube-like) Hermite curves + RTC_GEOMETRY_TYPE_FLAT_HERMITE_CURVE = 41, // flat (ribbon-like) Hermite curves + RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_HERMITE_CURVE = 42, // flat normal-oriented Hermite curves + + RTC_GEOMETRY_TYPE_SPHERE_POINT = 50, + RTC_GEOMETRY_TYPE_DISC_POINT = 51, + RTC_GEOMETRY_TYPE_ORIENTED_DISC_POINT = 52, + + RTC_GEOMETRY_TYPE_ROUND_CATMULL_ROM_CURVE = 58, // round (tube-like) Catmull-Rom curves + RTC_GEOMETRY_TYPE_FLAT_CATMULL_ROM_CURVE = 59, // flat (ribbon-like) Catmull-Rom curves + RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_CATMULL_ROM_CURVE = 60, // flat normal-oriented Catmull-Rom curves + + RTC_GEOMETRY_TYPE_USER = 120, // user-defined geometry + RTC_GEOMETRY_TYPE_INSTANCE = 121 // scene instance +}; + +/* Interpolation modes for subdivision surfaces */ +enum RTCSubdivisionMode +{ + RTC_SUBDIVISION_MODE_NO_BOUNDARY = 0, + RTC_SUBDIVISION_MODE_SMOOTH_BOUNDARY = 1, + RTC_SUBDIVISION_MODE_PIN_CORNERS = 2, + RTC_SUBDIVISION_MODE_PIN_BOUNDARY = 3, + RTC_SUBDIVISION_MODE_PIN_ALL = 4, +}; + +/* Curve segment flags */ +enum RTCCurveFlags +{ + RTC_CURVE_FLAG_NEIGHBOR_LEFT = (1 << 0), // left segments exists + RTC_CURVE_FLAG_NEIGHBOR_RIGHT = (1 << 1) // right segment exists +}; + +/* Arguments for RTCBoundsFunction */ +struct RTCBoundsFunctionArguments +{ + void* geometryUserPtr; + unsigned int primID; + unsigned int timeStep; + struct RTCBounds* bounds_o; +}; + +/* Bounding callback function */ +typedef void (*RTCBoundsFunction)(const struct RTCBoundsFunctionArguments* args); + +/* Arguments for RTCIntersectFunctionN */ +struct RTCIntersectFunctionNArguments +{ + int* valid; + void* geometryUserPtr; + unsigned int primID; + struct RTCIntersectContext* 
context; + struct RTCRayHitN* rayhit; + unsigned int N; + unsigned int geomID; +}; + +/* Intersection callback function */ +typedef void (*RTCIntersectFunctionN)(const struct RTCIntersectFunctionNArguments* args); + +/* Arguments for RTCOccludedFunctionN */ +struct RTCOccludedFunctionNArguments +{ + int* valid; + void* geometryUserPtr; + unsigned int primID; + struct RTCIntersectContext* context; + struct RTCRayN* ray; + unsigned int N; + unsigned int geomID; +}; + +/* Occlusion callback function */ +typedef void (*RTCOccludedFunctionN)(const struct RTCOccludedFunctionNArguments* args); + +/* Arguments for RTCDisplacementFunctionN */ +struct RTCDisplacementFunctionNArguments +{ + void* geometryUserPtr; + RTCGeometry geometry; + unsigned int primID; + unsigned int timeStep; + const float* u; + const float* v; + const float* Ng_x; + const float* Ng_y; + const float* Ng_z; + float* P_x; + float* P_y; + float* P_z; + unsigned int N; +}; + +/* Displacement mapping callback function */ +typedef void (*RTCDisplacementFunctionN)(const struct RTCDisplacementFunctionNArguments* args); + +/* Creates a new geometry of specified type. */ +RTC_API RTCGeometry rtcNewGeometry(RTCDevice device, enum RTCGeometryType type); + +/* Retains the geometry (increments the reference count). */ +RTC_API void rtcRetainGeometry(RTCGeometry geometry); + +/* Releases the geometry (decrements the reference count) */ +RTC_API void rtcReleaseGeometry(RTCGeometry geometry); + +/* Commits the geometry. */ +RTC_API void rtcCommitGeometry(RTCGeometry geometry); + + +/* Enables the geometry. */ +RTC_API void rtcEnableGeometry(RTCGeometry geometry); + +/* Disables the geometry. */ +RTC_API void rtcDisableGeometry(RTCGeometry geometry); + + +/* Sets the number of motion blur time steps of the geometry. */ +RTC_API void rtcSetGeometryTimeStepCount(RTCGeometry geometry, unsigned int timeStepCount); + +/* Sets the motion blur time range of the geometry. */ +RTC_API void rtcSetGeometryTimeRange(RTCGeometry geometry, float startTime, float endTime); + +/* Sets the number of vertex attributes of the geometry. */ +RTC_API void rtcSetGeometryVertexAttributeCount(RTCGeometry geometry, unsigned int vertexAttributeCount); + +/* Sets the ray mask of the geometry. */ +RTC_API void rtcSetGeometryMask(RTCGeometry geometry, unsigned int mask); + +/* Sets the build quality of the geometry. */ +RTC_API void rtcSetGeometryBuildQuality(RTCGeometry geometry, enum RTCBuildQuality quality); + +/* Sets the maximal curve or point radius scale allowed by min-width feature. */ +RTC_API void rtcSetGeometryMaxRadiusScale(RTCGeometry geometry, float maxRadiusScale); + + +/* Sets a geometry buffer. */ +RTC_API void rtcSetGeometryBuffer(RTCGeometry geometry, enum RTCBufferType type, unsigned int slot, enum RTCFormat format, RTCBuffer buffer, size_t byteOffset, size_t byteStride, size_t itemCount); + +/* Sets a shared geometry buffer. */ +RTC_API void rtcSetSharedGeometryBuffer(RTCGeometry geometry, enum RTCBufferType type, unsigned int slot, enum RTCFormat format, const void* ptr, size_t byteOffset, size_t byteStride, size_t itemCount); + +/* Creates and sets a new geometry buffer. */ +RTC_API void* rtcSetNewGeometryBuffer(RTCGeometry geometry, enum RTCBufferType type, unsigned int slot, enum RTCFormat format, size_t byteStride, size_t itemCount); + +/* Returns the pointer to the data of a buffer. */ +RTC_API void* rtcGetGeometryBufferData(RTCGeometry geometry, enum RTCBufferType type, unsigned int slot); + +/* Updates a geometry buffer. 
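A sketch of the buffer setup above for a single triangle, assuming an existing RTCDevice handle named device; the buffer type enums (RTC_BUFFER_TYPE_VERTEX, RTC_BUFFER_TYPE_INDEX) come from rtcore_buffer.h, which this header includes:

    RTCGeometry geom = rtcNewGeometry(device, RTC_GEOMETRY_TYPE_TRIANGLE);
    /* one triangle: 3 vertices of 3 floats each, plus 1 index triple */
    float* v = (float*)rtcSetNewGeometryBuffer(geom, RTC_BUFFER_TYPE_VERTEX, 0,
                                               RTC_FORMAT_FLOAT3, 3 * sizeof(float), 3);
    unsigned* tri = (unsigned*)rtcSetNewGeometryBuffer(geom, RTC_BUFFER_TYPE_INDEX, 0,
                                                       RTC_FORMAT_UINT3, 3 * sizeof(unsigned), 1);
    v[0] = 0.f; v[1] = 0.f; v[2] = 0.f;
    v[3] = 1.f; v[4] = 0.f; v[5] = 0.f;
    v[6] = 0.f; v[7] = 1.f; v[8] = 0.f;
    tri[0] = 0; tri[1] = 1; tri[2] = 2;
    rtcCommitGeometry(geom);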
*/ +RTC_API void rtcUpdateGeometryBuffer(RTCGeometry geometry, enum RTCBufferType type, unsigned int slot); + + +/* Sets the intersection filter callback function of the geometry. */ +RTC_API void rtcSetGeometryIntersectFilterFunction(RTCGeometry geometry, RTCFilterFunctionN filter); + +/* Sets the occlusion filter callback function of the geometry. */ +RTC_API void rtcSetGeometryOccludedFilterFunction(RTCGeometry geometry, RTCFilterFunctionN filter); + +/* Sets the user-defined data pointer of the geometry. */ +RTC_API void rtcSetGeometryUserData(RTCGeometry geometry, void* ptr); + +/* Gets the user-defined data pointer of the geometry. */ +RTC_API void* rtcGetGeometryUserData(RTCGeometry geometry); + +/* Set the point query callback function of a geometry. */ +RTC_API void rtcSetGeometryPointQueryFunction(RTCGeometry geometry, RTCPointQueryFunction pointQuery); + +/* Sets the number of primitives of a user geometry. */ +RTC_API void rtcSetGeometryUserPrimitiveCount(RTCGeometry geometry, unsigned int userPrimitiveCount); + +/* Sets the bounding callback function to calculate bounding boxes for user primitives. */ +RTC_API void rtcSetGeometryBoundsFunction(RTCGeometry geometry, RTCBoundsFunction bounds, void* userPtr); + +/* Set the intersect callback function of a user geometry. */ +RTC_API void rtcSetGeometryIntersectFunction(RTCGeometry geometry, RTCIntersectFunctionN intersect); + +/* Set the occlusion callback function of a user geometry. */ +RTC_API void rtcSetGeometryOccludedFunction(RTCGeometry geometry, RTCOccludedFunctionN occluded); + +/* Invokes the intersection filter from the intersection callback function. */ +RTC_API void rtcFilterIntersection(const struct RTCIntersectFunctionNArguments* args, const struct RTCFilterFunctionNArguments* filterArgs); + +/* Invokes the occlusion filter from the occlusion callback function. */ +RTC_API void rtcFilterOcclusion(const struct RTCOccludedFunctionNArguments* args, const struct RTCFilterFunctionNArguments* filterArgs); + + +/* Sets the instanced scene of an instance geometry. */ +RTC_API void rtcSetGeometryInstancedScene(RTCGeometry geometry, RTCScene scene); + +/* Sets the transformation of an instance for the specified time step. */ +RTC_API void rtcSetGeometryTransform(RTCGeometry geometry, unsigned int timeStep, enum RTCFormat format, const void* xfm); + +/* Sets the transformation quaternion of an instance for the specified time step. */ +RTC_API void rtcSetGeometryTransformQuaternion(RTCGeometry geometry, unsigned int timeStep, const struct RTCQuaternionDecomposition* qd); + +/* Returns the interpolated transformation of an instance for the specified time. */ +RTC_API void rtcGetGeometryTransform(RTCGeometry geometry, float time, enum RTCFormat format, void* xfm); + + +/* Sets the uniform tessellation rate of the geometry. */ +RTC_API void rtcSetGeometryTessellationRate(RTCGeometry geometry, float tessellationRate); + +/* Sets the number of topologies of a subdivision surface. */ +RTC_API void rtcSetGeometryTopologyCount(RTCGeometry geometry, unsigned int topologyCount); + +/* Sets the subdivision interpolation mode. */ +RTC_API void rtcSetGeometrySubdivisionMode(RTCGeometry geometry, unsigned int topologyID, enum RTCSubdivisionMode mode); + +/* Binds a vertex attribute to a topology of the geometry. */ +RTC_API void rtcSetGeometryVertexAttributeTopology(RTCGeometry geometry, unsigned int vertexAttributeID, unsigned int topologyID); + +/* Sets the displacement callback function of a subdivision surface. 
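A sketch of the user-geometry hooks declared above, assuming a hypothetical array of spheres handed to the geometry through its user data pointer; a matching RTCIntersectFunctionN (here called sphereIntersect, also hypothetical) would be registered the same way to perform the actual ray tests:

    struct Sphere { float x, y, z, r; };

    static void sphereBounds(const struct RTCBoundsFunctionArguments* args)
    {
      const struct Sphere* s = &((const struct Sphere*)args->geometryUserPtr)[args->primID];
      args->bounds_o->lower_x = s->x - s->r;  args->bounds_o->upper_x = s->x + s->r;
      args->bounds_o->lower_y = s->y - s->r;  args->bounds_o->upper_y = s->y + s->r;
      args->bounds_o->lower_z = s->z - s->r;  args->bounds_o->upper_z = s->z + s->r;
    }

    RTCGeometry ug = rtcNewGeometry(device, RTC_GEOMETRY_TYPE_USER);
    rtcSetGeometryUserPrimitiveCount(ug, numSpheres);      /* numSpheres, spheres assumed */
    rtcSetGeometryUserData(ug, spheres);
    rtcSetGeometryBoundsFunction(ug, sphereBounds, NULL);
    rtcSetGeometryIntersectFunction(ug, sphereIntersect);  /* hypothetical RTCIntersectFunctionN */
    rtcCommitGeometry(ug);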
*/ +RTC_API void rtcSetGeometryDisplacementFunction(RTCGeometry geometry, RTCDisplacementFunctionN displacement); + +/* Returns the first half edge of a face. */ +RTC_API unsigned int rtcGetGeometryFirstHalfEdge(RTCGeometry geometry, unsigned int faceID); + +/* Returns the face the half edge belongs to. */ +RTC_API unsigned int rtcGetGeometryFace(RTCGeometry geometry, unsigned int edgeID); + +/* Returns next half edge. */ +RTC_API unsigned int rtcGetGeometryNextHalfEdge(RTCGeometry geometry, unsigned int edgeID); + +/* Returns previous half edge. */ +RTC_API unsigned int rtcGetGeometryPreviousHalfEdge(RTCGeometry geometry, unsigned int edgeID); + +/* Returns opposite half edge. */ +RTC_API unsigned int rtcGetGeometryOppositeHalfEdge(RTCGeometry geometry, unsigned int topologyID, unsigned int edgeID); + + +/* Arguments for rtcInterpolate */ +struct RTCInterpolateArguments +{ + RTCGeometry geometry; + unsigned int primID; + float u; + float v; + enum RTCBufferType bufferType; + unsigned int bufferSlot; + float* P; + float* dPdu; + float* dPdv; + float* ddPdudu; + float* ddPdvdv; + float* ddPdudv; + unsigned int valueCount; +}; + +/* Interpolates vertex data to some u/v location and optionally calculates all derivatives. */ +RTC_API void rtcInterpolate(const struct RTCInterpolateArguments* args); + +/* Interpolates vertex data to some u/v location. */ +RTC_FORCEINLINE void rtcInterpolate0(RTCGeometry geometry, unsigned int primID, float u, float v, enum RTCBufferType bufferType, unsigned int bufferSlot, float* P, unsigned int valueCount) +{ + struct RTCInterpolateArguments args; + args.geometry = geometry; + args.primID = primID; + args.u = u; + args.v = v; + args.bufferType = bufferType; + args.bufferSlot = bufferSlot; + args.P = P; + args.dPdu = NULL; + args.dPdv = NULL; + args.ddPdudu = NULL; + args.ddPdvdv = NULL; + args.ddPdudv = NULL; + args.valueCount = valueCount; + rtcInterpolate(&args); +} + +/* Interpolates vertex data to some u/v location and calculates first order derivatives. */ +RTC_FORCEINLINE void rtcInterpolate1(RTCGeometry geometry, unsigned int primID, float u, float v, enum RTCBufferType bufferType, unsigned int bufferSlot, + float* P, float* dPdu, float* dPdv, unsigned int valueCount) +{ + struct RTCInterpolateArguments args; + args.geometry = geometry; + args.primID = primID; + args.u = u; + args.v = v; + args.bufferType = bufferType; + args.bufferSlot = bufferSlot; + args.P = P; + args.dPdu = dPdu; + args.dPdv = dPdv; + args.ddPdudu = NULL; + args.ddPdvdv = NULL; + args.ddPdudv = NULL; + args.valueCount = valueCount; + rtcInterpolate(&args); +} + +/* Interpolates vertex data to some u/v location and calculates first and second order derivatives. 
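A sketch of the interpolation helpers above, assuming a hit has already been found with rtcIntersect1 and that scene and rayhit are in scope (rtcGetGeometry is declared in rtcore_scene.h):

    float P[3]; /* interpolated vertex position at the hit's (u,v) */
    rtcInterpolate0(rtcGetGeometry(scene, rayhit.hit.geomID),
                    rayhit.hit.primID, rayhit.hit.u, rayhit.hit.v,
                    RTC_BUFFER_TYPE_VERTEX, 0, P, 3);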
*/ +RTC_FORCEINLINE void rtcInterpolate2(RTCGeometry geometry, unsigned int primID, float u, float v, enum RTCBufferType bufferType, unsigned int bufferSlot, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, unsigned int valueCount) +{ + struct RTCInterpolateArguments args; + args.geometry = geometry; + args.primID = primID; + args.u = u; + args.v = v; + args.bufferType = bufferType; + args.bufferSlot = bufferSlot; + args.P = P; + args.dPdu = dPdu; + args.dPdv = dPdv; + args.ddPdudu = ddPdudu; + args.ddPdvdv = ddPdvdv; + args.ddPdudv = ddPdudv; + args.valueCount = valueCount; + rtcInterpolate(&args); +} + +/* Arguments for rtcInterpolateN */ +struct RTCInterpolateNArguments +{ + RTCGeometry geometry; + const void* valid; + const unsigned int* primIDs; + const float* u; + const float* v; + unsigned int N; + enum RTCBufferType bufferType; + unsigned int bufferSlot; + float* P; + float* dPdu; + float* dPdv; + float* ddPdudu; + float* ddPdvdv; + float* ddPdudv; + unsigned int valueCount; +}; + +/* Interpolates vertex data to an array of u/v locations. */ +RTC_API void rtcInterpolateN(const struct RTCInterpolateNArguments* args); + +/* RTCGrid primitive for grid mesh */ +struct RTCGrid +{ + unsigned int startVertexID; + unsigned int stride; + unsigned short width,height; // max is a 32k x 32k grid +}; + +RTC_NAMESPACE_END + + diff --git a/thirdparty/embree/include/embree3/rtcore_quaternion.h b/thirdparty/embree/include/embree3/rtcore_quaternion.h new file mode 100644 index 000000000000..449cdedfdcf6 --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_quaternion.h @@ -0,0 +1,101 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_common.h" + +RTC_NAMESPACE_BEGIN + +/* + * Structure for transformation respresentation as a matrix decomposition using + * a quaternion + */ +struct RTC_ALIGN(16) RTCQuaternionDecomposition +{ + float scale_x; + float scale_y; + float scale_z; + float skew_xy; + float skew_xz; + float skew_yz; + float shift_x; + float shift_y; + float shift_z; + float quaternion_r; + float quaternion_i; + float quaternion_j; + float quaternion_k; + float translation_x; + float translation_y; + float translation_z; +}; + +RTC_FORCEINLINE void rtcInitQuaternionDecomposition(struct RTCQuaternionDecomposition* qdecomp) +{ + qdecomp->scale_x = 1.f; + qdecomp->scale_y = 1.f; + qdecomp->scale_z = 1.f; + qdecomp->skew_xy = 0.f; + qdecomp->skew_xz = 0.f; + qdecomp->skew_yz = 0.f; + qdecomp->shift_x = 0.f; + qdecomp->shift_y = 0.f; + qdecomp->shift_z = 0.f; + qdecomp->quaternion_r = 1.f; + qdecomp->quaternion_i = 0.f; + qdecomp->quaternion_j = 0.f; + qdecomp->quaternion_k = 0.f; + qdecomp->translation_x = 0.f; + qdecomp->translation_y = 0.f; + qdecomp->translation_z = 0.f; +} + +RTC_FORCEINLINE void rtcQuaternionDecompositionSetQuaternion( + struct RTCQuaternionDecomposition* qdecomp, + float r, float i, float j, float k) +{ + qdecomp->quaternion_r = r; + qdecomp->quaternion_i = i; + qdecomp->quaternion_j = j; + qdecomp->quaternion_k = k; +} + +RTC_FORCEINLINE void rtcQuaternionDecompositionSetScale( + struct RTCQuaternionDecomposition* qdecomp, + float scale_x, float scale_y, float scale_z) +{ + qdecomp->scale_x = scale_x; + qdecomp->scale_y = scale_y; + qdecomp->scale_z = scale_z; +} + +RTC_FORCEINLINE void rtcQuaternionDecompositionSetSkew( + struct RTCQuaternionDecomposition* qdecomp, + float skew_xy, float skew_xz, float skew_yz) +{ + qdecomp->skew_xy = skew_xy; + 
qdecomp->skew_xz = skew_xz; + qdecomp->skew_yz = skew_yz; +} + +RTC_FORCEINLINE void rtcQuaternionDecompositionSetShift( + struct RTCQuaternionDecomposition* qdecomp, + float shift_x, float shift_y, float shift_z) +{ + qdecomp->shift_x = shift_x; + qdecomp->shift_y = shift_y; + qdecomp->shift_z = shift_z; +} + +RTC_FORCEINLINE void rtcQuaternionDecompositionSetTranslation( + struct RTCQuaternionDecomposition* qdecomp, + float translation_x, float translation_y, float translation_z) +{ + qdecomp->translation_x = translation_x; + qdecomp->translation_y = translation_y; + qdecomp->translation_z = translation_z; +} + +RTC_NAMESPACE_END + diff --git a/thirdparty/embree/include/embree3/rtcore_ray.h b/thirdparty/embree/include/embree3/rtcore_ray.h new file mode 100644 index 000000000000..1ae3309ef1f8 --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_ray.h @@ -0,0 +1,378 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_common.h" + +RTC_NAMESPACE_BEGIN + +/* Ray structure for a single ray */ +struct RTC_ALIGN(16) RTCRay +{ + float org_x; // x coordinate of ray origin + float org_y; // y coordinate of ray origin + float org_z; // z coordinate of ray origin + float tnear; // start of ray segment + + float dir_x; // x coordinate of ray direction + float dir_y; // y coordinate of ray direction + float dir_z; // z coordinate of ray direction + float time; // time of this ray for motion blur + + float tfar; // end of ray segment (set to hit distance) + unsigned int mask; // ray mask + unsigned int id; // ray ID + unsigned int flags; // ray flags +}; + +/* Hit structure for a single ray */ +struct RTC_ALIGN(16) RTCHit +{ + float Ng_x; // x coordinate of geometry normal + float Ng_y; // y coordinate of geometry normal + float Ng_z; // z coordinate of geometry normal + + float u; // barycentric u coordinate of hit + float v; // barycentric v coordinate of hit + + unsigned int primID; // primitive ID + unsigned int geomID; // geometry ID + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // instance ID +}; + +/* Combined ray/hit structure for a single ray */ +struct RTCRayHit +{ + struct RTCRay ray; + struct RTCHit hit; +}; + +/* Ray structure for a packet of 4 rays */ +struct RTC_ALIGN(16) RTCRay4 +{ + float org_x[4]; + float org_y[4]; + float org_z[4]; + float tnear[4]; + + float dir_x[4]; + float dir_y[4]; + float dir_z[4]; + float time[4]; + + float tfar[4]; + unsigned int mask[4]; + unsigned int id[4]; + unsigned int flags[4]; +}; + +/* Hit structure for a packet of 4 rays */ +struct RTC_ALIGN(16) RTCHit4 +{ + float Ng_x[4]; + float Ng_y[4]; + float Ng_z[4]; + + float u[4]; + float v[4]; + + unsigned int primID[4]; + unsigned int geomID[4]; + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT][4]; +}; + +/* Combined ray/hit structure for a packet of 4 rays */ +struct RTCRayHit4 +{ + struct RTCRay4 ray; + struct RTCHit4 hit; +}; + +/* Ray structure for a packet of 8 rays */ +struct RTC_ALIGN(32) RTCRay8 +{ + float org_x[8]; + float org_y[8]; + float org_z[8]; + float tnear[8]; + + float dir_x[8]; + float dir_y[8]; + float dir_z[8]; + float time[8]; + + float tfar[8]; + unsigned int mask[8]; + unsigned int id[8]; + unsigned int flags[8]; +}; + +/* Hit structure for a packet of 8 rays */ +struct RTC_ALIGN(32) RTCHit8 +{ + float Ng_x[8]; + float Ng_y[8]; + float Ng_z[8]; + + float u[8]; + float v[8]; + + unsigned int primID[8]; + unsigned int geomID[8]; + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT][8]; +}; + +/* 
Combined ray/hit structure for a packet of 8 rays */ +struct RTCRayHit8 +{ + struct RTCRay8 ray; + struct RTCHit8 hit; +}; + +/* Ray structure for a packet of 16 rays */ +struct RTC_ALIGN(64) RTCRay16 +{ + float org_x[16]; + float org_y[16]; + float org_z[16]; + float tnear[16]; + + float dir_x[16]; + float dir_y[16]; + float dir_z[16]; + float time[16]; + + float tfar[16]; + unsigned int mask[16]; + unsigned int id[16]; + unsigned int flags[16]; +}; + +/* Hit structure for a packet of 16 rays */ +struct RTC_ALIGN(64) RTCHit16 +{ + float Ng_x[16]; + float Ng_y[16]; + float Ng_z[16]; + + float u[16]; + float v[16]; + + unsigned int primID[16]; + unsigned int geomID[16]; + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT][16]; +}; + +/* Combined ray/hit structure for a packet of 16 rays */ +struct RTCRayHit16 +{ + struct RTCRay16 ray; + struct RTCHit16 hit; +}; + +/* Ray structure for a packet/stream of N rays in pointer SOA layout */ +struct RTCRayNp +{ + float* org_x; + float* org_y; + float* org_z; + float* tnear; + + float* dir_x; + float* dir_y; + float* dir_z; + float* time; + + float* tfar; + unsigned int* mask; + unsigned int* id; + unsigned int* flags; +}; + +/* Hit structure for a packet/stream of N rays in pointer SOA layout */ +struct RTCHitNp +{ + float* Ng_x; + float* Ng_y; + float* Ng_z; + + float* u; + float* v; + + unsigned int* primID; + unsigned int* geomID; + unsigned int* instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; +}; + +/* Combined ray/hit structure for a packet/stream of N rays in pointer SOA layout */ +struct RTCRayHitNp +{ + struct RTCRayNp ray; + struct RTCHitNp hit; +}; + +struct RTCRayN; +struct RTCHitN; +struct RTCRayHitN; + +#if defined(__cplusplus) + +/* Helper functions to access ray packets of runtime size N */ +RTC_FORCEINLINE float& RTCRayN_org_x(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[0*N+i]; } +RTC_FORCEINLINE float& RTCRayN_org_y(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[1*N+i]; } +RTC_FORCEINLINE float& RTCRayN_org_z(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[2*N+i]; } +RTC_FORCEINLINE float& RTCRayN_tnear(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[3*N+i]; } + +RTC_FORCEINLINE float& RTCRayN_dir_x(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[4*N+i]; } +RTC_FORCEINLINE float& RTCRayN_dir_y(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[5*N+i]; } +RTC_FORCEINLINE float& RTCRayN_dir_z(RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[6*N+i]; } +RTC_FORCEINLINE float& RTCRayN_time (RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[7*N+i]; } + +RTC_FORCEINLINE float& RTCRayN_tfar (RTCRayN* ray, unsigned int N, unsigned int i) { return ((float*)ray)[8*N+i]; } +RTC_FORCEINLINE unsigned int& RTCRayN_mask (RTCRayN* ray, unsigned int N, unsigned int i) { return ((unsigned*)ray)[9*N+i]; } +RTC_FORCEINLINE unsigned int& RTCRayN_id (RTCRayN* ray, unsigned int N, unsigned int i) { return ((unsigned*)ray)[10*N+i]; } +RTC_FORCEINLINE unsigned int& RTCRayN_flags(RTCRayN* ray, unsigned int N, unsigned int i) { return ((unsigned*)ray)[11*N+i]; } + +/* Helper functions to access hit packets of runtime size N */ +RTC_FORCEINLINE float& RTCHitN_Ng_x(RTCHitN* hit, unsigned int N, unsigned int i) { return ((float*)hit)[0*N+i]; } +RTC_FORCEINLINE float& RTCHitN_Ng_y(RTCHitN* hit, unsigned int N, unsigned int i) { return ((float*)hit)[1*N+i]; } +RTC_FORCEINLINE float& RTCHitN_Ng_z(RTCHitN* hit, 
unsigned int N, unsigned int i) { return ((float*)hit)[2*N+i]; } + +RTC_FORCEINLINE float& RTCHitN_u(RTCHitN* hit, unsigned int N, unsigned int i) { return ((float*)hit)[3*N+i]; } +RTC_FORCEINLINE float& RTCHitN_v(RTCHitN* hit, unsigned int N, unsigned int i) { return ((float*)hit)[4*N+i]; } + +RTC_FORCEINLINE unsigned int& RTCHitN_primID(RTCHitN* hit, unsigned int N, unsigned int i) { return ((unsigned*)hit)[5*N+i]; } +RTC_FORCEINLINE unsigned int& RTCHitN_geomID(RTCHitN* hit, unsigned int N, unsigned int i) { return ((unsigned*)hit)[6*N+i]; } +RTC_FORCEINLINE unsigned int& RTCHitN_instID(RTCHitN* hit, unsigned int N, unsigned int i, unsigned int l) { return ((unsigned*)hit)[7*N+i+N*l]; } + +/* Helper functions to extract RTCRayN and RTCHitN from RTCRayHitN */ +RTC_FORCEINLINE RTCRayN* RTCRayHitN_RayN(RTCRayHitN* rayhit, unsigned int N) { return (RTCRayN*)&((float*)rayhit)[0*N]; } +RTC_FORCEINLINE RTCHitN* RTCRayHitN_HitN(RTCRayHitN* rayhit, unsigned int N) { return (RTCHitN*)&((float*)rayhit)[12*N]; } + +/* Helper structure for a ray packet of compile-time size N */ +template +struct RTCRayNt +{ + float org_x[N]; + float org_y[N]; + float org_z[N]; + float tnear[N]; + + float dir_x[N]; + float dir_y[N]; + float dir_z[N]; + float time[N]; + + float tfar[N]; + unsigned int mask[N]; + unsigned int id[N]; + unsigned int flags[N]; +}; + +/* Helper structure for a hit packet of compile-time size N */ +template +struct RTCHitNt +{ + float Ng_x[N]; + float Ng_y[N]; + float Ng_z[N]; + + float u[N]; + float v[N]; + + unsigned int primID[N]; + unsigned int geomID[N]; + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT][N]; +}; + +/* Helper structure for a combined ray/hit packet of compile-time size N */ +template +struct RTCRayHitNt +{ + RTCRayNt ray; + RTCHitNt hit; +}; + +RTC_FORCEINLINE RTCRay rtcGetRayFromRayN(RTCRayN* rayN, unsigned int N, unsigned int i) +{ + RTCRay ray; + ray.org_x = RTCRayN_org_x(rayN,N,i); + ray.org_y = RTCRayN_org_y(rayN,N,i); + ray.org_z = RTCRayN_org_z(rayN,N,i); + ray.tnear = RTCRayN_tnear(rayN,N,i); + ray.dir_x = RTCRayN_dir_x(rayN,N,i); + ray.dir_y = RTCRayN_dir_y(rayN,N,i); + ray.dir_z = RTCRayN_dir_z(rayN,N,i); + ray.time = RTCRayN_time(rayN,N,i); + ray.tfar = RTCRayN_tfar(rayN,N,i); + ray.mask = RTCRayN_mask(rayN,N,i); + ray.id = RTCRayN_id(rayN,N,i); + ray.flags = RTCRayN_flags(rayN,N,i); + return ray; +} + +RTC_FORCEINLINE RTCHit rtcGetHitFromHitN(RTCHitN* hitN, unsigned int N, unsigned int i) +{ + RTCHit hit; + hit.Ng_x = RTCHitN_Ng_x(hitN,N,i); + hit.Ng_y = RTCHitN_Ng_y(hitN,N,i); + hit.Ng_z = RTCHitN_Ng_z(hitN,N,i); + hit.u = RTCHitN_u(hitN,N,i); + hit.v = RTCHitN_v(hitN,N,i); + hit.primID = RTCHitN_primID(hitN,N,i); + hit.geomID = RTCHitN_geomID(hitN,N,i); + for (unsigned int l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; l++) + hit.instID[l] = RTCHitN_instID(hitN,N,i,l); + return hit; +} + +RTC_FORCEINLINE void rtcCopyHitToHitN(RTCHitN* hitN, const RTCHit* hit, unsigned int N, unsigned int i) +{ + RTCHitN_Ng_x(hitN,N,i) = hit->Ng_x; + RTCHitN_Ng_y(hitN,N,i) = hit->Ng_y; + RTCHitN_Ng_z(hitN,N,i) = hit->Ng_z; + RTCHitN_u(hitN,N,i) = hit->u; + RTCHitN_v(hitN,N,i) = hit->v; + RTCHitN_primID(hitN,N,i) = hit->primID; + RTCHitN_geomID(hitN,N,i) = hit->geomID; + for (unsigned int l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; l++) + RTCHitN_instID(hitN,N,i,l) = hit->instID[l]; +} + +RTC_FORCEINLINE RTCRayHit rtcGetRayHitFromRayHitN(RTCRayHitN* rayhitN, unsigned int N, unsigned int i) +{ + RTCRayHit rh; + + RTCRayN* ray = RTCRayHitN_RayN(rayhitN,N); + rh.ray.org_x = 
RTCRayN_org_x(ray,N,i); + rh.ray.org_y = RTCRayN_org_y(ray,N,i); + rh.ray.org_z = RTCRayN_org_z(ray,N,i); + rh.ray.tnear = RTCRayN_tnear(ray,N,i); + rh.ray.dir_x = RTCRayN_dir_x(ray,N,i); + rh.ray.dir_y = RTCRayN_dir_y(ray,N,i); + rh.ray.dir_z = RTCRayN_dir_z(ray,N,i); + rh.ray.time = RTCRayN_time(ray,N,i); + rh.ray.tfar = RTCRayN_tfar(ray,N,i); + rh.ray.mask = RTCRayN_mask(ray,N,i); + rh.ray.id = RTCRayN_id(ray,N,i); + rh.ray.flags = RTCRayN_flags(ray,N,i); + + RTCHitN* hit = RTCRayHitN_HitN(rayhitN,N); + rh.hit.Ng_x = RTCHitN_Ng_x(hit,N,i); + rh.hit.Ng_y = RTCHitN_Ng_y(hit,N,i); + rh.hit.Ng_z = RTCHitN_Ng_z(hit,N,i); + rh.hit.u = RTCHitN_u(hit,N,i); + rh.hit.v = RTCHitN_v(hit,N,i); + rh.hit.primID = RTCHitN_primID(hit,N,i); + rh.hit.geomID = RTCHitN_geomID(hit,N,i); + for (unsigned int l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; l++) + rh.hit.instID[l] = RTCHitN_instID(hit,N,i,l); + + return rh; +} + +#endif + +RTC_NAMESPACE_END + diff --git a/thirdparty/embree/include/embree3/rtcore_scene.h b/thirdparty/embree/include/embree3/rtcore_scene.h new file mode 100644 index 000000000000..0cd64015937e --- /dev/null +++ b/thirdparty/embree/include/embree3/rtcore_scene.h @@ -0,0 +1,160 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore_device.h" + +RTC_NAMESPACE_BEGIN + +/* Forward declarations for ray structures */ +struct RTCRayHit; +struct RTCRayHit4; +struct RTCRayHit8; +struct RTCRayHit16; +struct RTCRayHitNp; + +/* Scene flags */ +enum RTCSceneFlags +{ + RTC_SCENE_FLAG_NONE = 0, + RTC_SCENE_FLAG_DYNAMIC = (1 << 0), + RTC_SCENE_FLAG_COMPACT = (1 << 1), + RTC_SCENE_FLAG_ROBUST = (1 << 2), + RTC_SCENE_FLAG_CONTEXT_FILTER_FUNCTION = (1 << 3) +}; + +/* Creates a new scene. */ +RTC_API RTCScene rtcNewScene(RTCDevice device); + +/* Returns the device the scene got created in. The reference count of + * the device is incremented by this function. */ +RTC_API RTCDevice rtcGetSceneDevice(RTCScene hscene); + +/* Retains the scene (increments the reference count). */ +RTC_API void rtcRetainScene(RTCScene scene); + +/* Releases the scene (decrements the reference count). */ +RTC_API void rtcReleaseScene(RTCScene scene); + + +/* Attaches the geometry to a scene. */ +RTC_API unsigned int rtcAttachGeometry(RTCScene scene, RTCGeometry geometry); + +/* Attaches the geometry to a scene using the specified geometry ID. */ +RTC_API void rtcAttachGeometryByID(RTCScene scene, RTCGeometry geometry, unsigned int geomID); + +/* Detaches the geometry from the scene. */ +RTC_API void rtcDetachGeometry(RTCScene scene, unsigned int geomID); + +/* Gets a geometry handle from the scene. */ +RTC_API RTCGeometry rtcGetGeometry(RTCScene scene, unsigned int geomID); + + +/* Commits the scene. */ +RTC_API void rtcCommitScene(RTCScene scene); + +/* Commits the scene from multiple threads. */ +RTC_API void rtcJoinCommitScene(RTCScene scene); + + +/* Progress monitor callback function */ +typedef bool (*RTCProgressMonitorFunction)(void* ptr, double n); + +/* Sets the progress monitor callback function of the scene. */ +RTC_API void rtcSetSceneProgressMonitorFunction(RTCScene scene, RTCProgressMonitorFunction progress, void* ptr); + +/* Sets the build quality of the scene. */ +RTC_API void rtcSetSceneBuildQuality(RTCScene scene, enum RTCBuildQuality quality); + +/* Sets the scene flags. */ +RTC_API void rtcSetSceneFlags(RTCScene scene, enum RTCSceneFlags flags); + +/* Returns the scene flags. 
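A sketch of typical scene assembly with the functions above, assuming device and a committed geometry geom already exist:

    RTCScene scene = rtcNewScene(device);
    rtcSetSceneFlags(scene, RTC_SCENE_FLAG_ROBUST);
    rtcSetSceneBuildQuality(scene, RTC_BUILD_QUALITY_HIGH);
    unsigned int geomID = rtcAttachGeometry(scene, geom); /* geomID identifies the geometry in hit records */
    rtcReleaseGeometry(geom);                             /* the scene keeps its own reference */
    rtcCommitScene(scene);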
*/ +RTC_API enum RTCSceneFlags rtcGetSceneFlags(RTCScene scene); + +/* Returns the axis-aligned bounds of the scene. */ +RTC_API void rtcGetSceneBounds(RTCScene scene, struct RTCBounds* bounds_o); + +/* Returns the linear axis-aligned bounds of the scene. */ +RTC_API void rtcGetSceneLinearBounds(RTCScene scene, struct RTCLinearBounds* bounds_o); + + +/* Perform a closest point query of the scene. */ +RTC_API bool rtcPointQuery(RTCScene scene, struct RTCPointQuery* query, struct RTCPointQueryContext* context, RTCPointQueryFunction queryFunc, void* userPtr); + +/* Perform a closest point query with a packet of 4 points with the scene. */ +RTC_API bool rtcPointQuery4(const int* valid, RTCScene scene, struct RTCPointQuery4* query, struct RTCPointQueryContext* context, RTCPointQueryFunction queryFunc, void** userPtr); + +/* Perform a closest point query with a packet of 8 points with the scene. */ +RTC_API bool rtcPointQuery8(const int* valid, RTCScene scene, struct RTCPointQuery8* query, struct RTCPointQueryContext* context, RTCPointQueryFunction queryFunc, void** userPtr); + +/* Perform a closest point query with a packet of 16 points with the scene. */ +RTC_API bool rtcPointQuery16(const int* valid, RTCScene scene, struct RTCPointQuery16* query, struct RTCPointQueryContext* context, RTCPointQueryFunction queryFunc, void** userPtr); + +/* Intersects a single ray with the scene. */ +RTC_API void rtcIntersect1(RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHit* rayhit); + +/* Intersects a packet of 4 rays with the scene. */ +RTC_API void rtcIntersect4(const int* valid, RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHit4* rayhit); + +/* Intersects a packet of 8 rays with the scene. */ +RTC_API void rtcIntersect8(const int* valid, RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHit8* rayhit); + +/* Intersects a packet of 16 rays with the scene. */ +RTC_API void rtcIntersect16(const int* valid, RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHit16* rayhit); + +/* Intersects a stream of M rays with the scene. */ +RTC_API void rtcIntersect1M(RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHit* rayhit, unsigned int M, size_t byteStride); + +/* Intersects a stream of pointers to M rays with the scene. */ +RTC_API void rtcIntersect1Mp(RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHit** rayhit, unsigned int M); + +/* Intersects a stream of M ray packets of size N in SOA format with the scene. */ +RTC_API void rtcIntersectNM(RTCScene scene, struct RTCIntersectContext* context, struct RTCRayHitN* rayhit, unsigned int N, unsigned int M, size_t byteStride); + +/* Intersects a stream of M ray packets of size N in SOA format with the scene. */ +RTC_API void rtcIntersectNp(RTCScene scene, struct RTCIntersectContext* context, const struct RTCRayHitNp* rayhit, unsigned int N); + +/* Tests a single ray for occlusion with the scene. */ +RTC_API void rtcOccluded1(RTCScene scene, struct RTCIntersectContext* context, struct RTCRay* ray); + +/* Tests a packet of 4 rays for occlusion with the scene. */ +RTC_API void rtcOccluded4(const int* valid, RTCScene scene, struct RTCIntersectContext* context, struct RTCRay4* ray); + +/* Tests a packet of 8 rays for occlusion with the scene. */ +RTC_API void rtcOccluded8(const int* valid, RTCScene scene, struct RTCIntersectContext* context, struct RTCRay8* ray); + +/* Tests a packet of 16 rays for occlusion with the scene. 
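A sketch of a single-ray query with rtcIntersect1 against a committed scene; the scene variable is assumed, and INFINITY comes from <math.h>:

    struct RTCIntersectContext ctx;
    rtcInitIntersectContext(&ctx);

    struct RTCRayHit rayhit;
    rayhit.ray.org_x = 0.f; rayhit.ray.org_y = 0.f; rayhit.ray.org_z = -1.f;
    rayhit.ray.dir_x = 0.f; rayhit.ray.dir_y = 0.f; rayhit.ray.dir_z = 1.f;
    rayhit.ray.tnear = 0.f; rayhit.ray.tfar = INFINITY;
    rayhit.ray.time = 0.f;  rayhit.ray.mask = (unsigned int)-1;
    rayhit.ray.id = 0;      rayhit.ray.flags = 0;
    rayhit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rayhit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;

    rtcIntersect1(scene, &ctx, &rayhit);

    if (rayhit.hit.geomID != RTC_INVALID_GEOMETRY_ID) {
      /* hit found: rayhit.ray.tfar holds the hit distance, rayhit.hit.u/v the barycentrics */
    }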
*/ +RTC_API void rtcOccluded16(const int* valid, RTCScene scene, struct RTCIntersectContext* context, struct RTCRay16* ray); + +/* Tests a stream of M rays for occlusion with the scene. */ +RTC_API void rtcOccluded1M(RTCScene scene, struct RTCIntersectContext* context, struct RTCRay* ray, unsigned int M, size_t byteStride); + +/* Tests a stream of pointers to M rays for occlusion with the scene. */ +RTC_API void rtcOccluded1Mp(RTCScene scene, struct RTCIntersectContext* context, struct RTCRay** ray, unsigned int M); + +/* Tests a stream of M ray packets of size N in SOA format for occlusion with the scene. */ +RTC_API void rtcOccludedNM(RTCScene scene, struct RTCIntersectContext* context, struct RTCRayN* ray, unsigned int N, unsigned int M, size_t byteStride); + +/* Tests a stream of M ray packets of size N in SOA format for occlusion with the scene. */ +RTC_API void rtcOccludedNp(RTCScene scene, struct RTCIntersectContext* context, const struct RTCRayNp* ray, unsigned int N); + +/*! collision callback */ +struct RTCCollision { unsigned int geomID0; unsigned int primID0; unsigned int geomID1; unsigned int primID1; }; +typedef void (*RTCCollideFunc) (void* userPtr, struct RTCCollision* collisions, unsigned int num_collisions); + +/*! Performs collision detection of two scenes */ +RTC_API void rtcCollide (RTCScene scene0, RTCScene scene1, RTCCollideFunc callback, void* userPtr); + +#if defined(__cplusplus) + +/* Helper for easily combining scene flags */ +inline RTCSceneFlags operator|(RTCSceneFlags a, RTCSceneFlags b) { + return (RTCSceneFlags)((size_t)a | (size_t)b); +} + +#endif + +RTC_NAMESPACE_END + diff --git a/thirdparty/embree/kernels/builders/bvh_builder_hair.h b/thirdparty/embree/kernels/builders/bvh_builder_hair.h new file mode 100644 index 000000000000..755ce255fb47 --- /dev/null +++ b/thirdparty/embree/kernels/builders/bvh_builder_hair.h @@ -0,0 +1,411 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../bvh/bvh.h" +#include "../geometry/primitive.h" +#include "../builders/bvh_builder_sah.h" +#include "../builders/heuristic_binning_array_aligned.h" +#include "../builders/heuristic_binning_array_unaligned.h" +#include "../builders/heuristic_strand_array.h" + +#define NUM_HAIR_OBJECT_BINS 32 + +namespace embree +{ + namespace isa + { + struct BVHBuilderHair + { + /*! settings for builder */ + struct Settings + { + /*! 
default settings */ + Settings () + : branchingFactor(2), maxDepth(32), logBlockSize(0), minLeafSize(1), maxLeafSize(7), finished_range_threshold(inf) {} + + public: + size_t branchingFactor; //!< branching factor of BVH to build + size_t maxDepth; //!< maximum depth of BVH to build + size_t logBlockSize; //!< log2 of blocksize for SAH heuristic + size_t minLeafSize; //!< minimum size of a leaf + size_t maxLeafSize; //!< maximum size of a leaf + size_t finished_range_threshold; //!< finished range threshold + }; + + template + + class BuilderT + { + ALIGNED_CLASS_(16); + friend struct BVHBuilderHair; + + typedef FastAllocator::CachedAllocator Allocator; + typedef HeuristicArrayBinningSAH HeuristicBinningSAH; + typedef UnalignedHeuristicArrayBinningSAH UnalignedHeuristicBinningSAH; + typedef HeuristicStrandSplit HeuristicStrandSplitSAH; + + static const size_t MAX_BRANCHING_FACTOR = 8; //!< maximum supported BVH branching factor + static const size_t MIN_LARGE_LEAF_LEVELS = 8; //!< create balanced tree if we are that many levels before the maximum tree depth + static const size_t SINGLE_THREADED_THRESHOLD = 4096; //!< threshold to switch to single threaded build + + static const size_t travCostAligned = 1; + static const size_t travCostUnaligned = 5; + static const size_t intCost = 6; + + BuilderT (Scene* scene, + PrimRef* prims, + const CreateAllocFunc& createAlloc, + const CreateAABBNodeFunc& createAABBNode, + const SetAABBNodeFunc& setAABBNode, + const CreateOBBNodeFunc& createOBBNode, + const SetOBBNodeFunc& setOBBNode, + const CreateLeafFunc& createLeaf, + const ProgressMonitor& progressMonitor, + const ReportFinishedRangeFunc& reportFinishedRange, + const Settings settings) + + : cfg(settings), + prims(prims), + createAlloc(createAlloc), + createAABBNode(createAABBNode), + setAABBNode(setAABBNode), + createOBBNode(createOBBNode), + setOBBNode(setOBBNode), + createLeaf(createLeaf), + progressMonitor(progressMonitor), + reportFinishedRange(reportFinishedRange), + alignedHeuristic(prims), unalignedHeuristic(scene,prims), strandHeuristic(scene,prims) {} + + /*! checks if all primitives are from the same geometry */ + __forceinline bool sameGeometry(const PrimInfoRange& range) + { + if (range.size() == 0) return true; + unsigned int firstGeomID = prims[range.begin()].geomID(); + for (size_t i=range.begin()+1; i cfg.maxDepth) + throw_RTCError(RTC_ERROR_UNKNOWN,"depth limit reached"); + + /* create leaf for few primitives */ + if (pinfo.size() <= cfg.maxLeafSize && sameGeometry(pinfo)) + return createLeaf(prims,pinfo,alloc); + + /* fill all children by always splitting the largest one */ + PrimInfoRange children[MAX_BRANCHING_FACTOR]; + unsigned numChildren = 1; + children[0] = pinfo; + + do { + + /* find best child with largest bounding box area */ + int bestChild = -1; + size_t bestSize = 0; + for (unsigned i=0; i bestSize) { + bestSize = children[i].size(); + bestChild = i; + } + } + if (bestChild == -1) break; + + /*! 
split best child into left and right child */ + __aligned(64) PrimInfoRange left, right; + if (!sameGeometry(children[bestChild])) { + alignedHeuristic.splitByGeometry(children[bestChild],left,right); + } else { + alignedHeuristic.splitFallback(children[bestChild],left,right); + } + + /* add new children left and right */ + children[bestChild] = children[numChildren-1]; + children[numChildren-1] = left; + children[numChildren+0] = right; + numChildren++; + + } while (numChildren < cfg.branchingFactor); + + /* create node */ + auto node = createAABBNode(alloc); + + for (size_t i=0; i> cfg.logBlockSize; + const float leafSAH = intCost*float(blocks)*halfArea(pinfo.geomBounds); + + /* try standard binning in aligned space */ + float alignedObjectSAH = inf; + HeuristicBinningSAH::Split alignedObjectSplit; + if (aligned) { + alignedObjectSplit = alignedHeuristic.find(pinfo,cfg.logBlockSize); + alignedObjectSAH = travCostAligned*halfArea(pinfo.geomBounds) + intCost*alignedObjectSplit.splitSAH(); + bestSAH = min(alignedObjectSAH,bestSAH); + } + + /* try standard binning in unaligned space */ + UnalignedHeuristicBinningSAH::Split unalignedObjectSplit; + LinearSpace3fa uspace; + float unalignedObjectSAH = inf; + if (bestSAH > 0.7f*leafSAH) { + uspace = unalignedHeuristic.computeAlignedSpace(pinfo); + const PrimInfoRange sinfo = unalignedHeuristic.computePrimInfo(pinfo,uspace); + unalignedObjectSplit = unalignedHeuristic.find(sinfo,cfg.logBlockSize,uspace); + unalignedObjectSAH = travCostUnaligned*halfArea(pinfo.geomBounds) + intCost*unalignedObjectSplit.splitSAH(); + bestSAH = min(unalignedObjectSAH,bestSAH); + } + + /* try splitting into two strands */ + HeuristicStrandSplitSAH::Split strandSplit; + float strandSAH = inf; + if (bestSAH > 0.7f*leafSAH && pinfo.size() <= 256) { + strandSplit = strandHeuristic.find(pinfo,cfg.logBlockSize); + strandSAH = travCostUnaligned*halfArea(pinfo.geomBounds) + intCost*strandSplit.splitSAH(); + bestSAH = min(strandSAH,bestSAH); + } + + /* fallback if SAH heuristics failed */ + if (unlikely(!std::isfinite(bestSAH))) + { + alignedHeuristic.deterministic_order(pinfo); + alignedHeuristic.splitFallback(pinfo,linfo,rinfo); + } + + /* perform aligned split if this is best */ + else if (bestSAH == alignedObjectSAH) { + alignedHeuristic.split(alignedObjectSplit,pinfo,linfo,rinfo); + } + + /* perform unaligned split if this is best */ + else if (bestSAH == unalignedObjectSAH) { + unalignedHeuristic.split(unalignedObjectSplit,uspace,pinfo,linfo,rinfo); + aligned = false; + } + + /* perform strand split if this is best */ + else if (bestSAH == strandSAH) { + strandHeuristic.split(strandSplit,pinfo,linfo,rinfo); + aligned = false; + } + + /* can never happen */ + else + assert(false); + } + + /*! 
recursive build */ + NodeRef recurse(size_t depth, const PrimInfoRange& pinfo, Allocator alloc, bool toplevel, bool alloc_barrier) + { + /* get thread local allocator */ + if (!alloc) + alloc = createAlloc(); + + /* call memory monitor function to signal progress */ + if (toplevel && pinfo.size() <= SINGLE_THREADED_THRESHOLD) + progressMonitor(pinfo.size()); + + PrimInfoRange children[MAX_BRANCHING_FACTOR]; + + /* create leaf node */ + if (depth+MIN_LARGE_LEAF_LEVELS >= cfg.maxDepth || pinfo.size() <= cfg.minLeafSize) { + alignedHeuristic.deterministic_order(pinfo); + return createLargeLeaf(depth,pinfo,alloc); + } + + /* fill all children by always splitting the one with the largest surface area */ + size_t numChildren = 1; + children[0] = pinfo; + bool aligned = true; + + do { + + /* find best child with largest bounding box area */ + ssize_t bestChild = -1; + float bestArea = neg_inf; + for (size_t i=0; i bestArea) { + bestArea = area(children[i].geomBounds); + bestChild = i; + } + } + if (bestChild == -1) break; + + /*! split best child into left and right child */ + PrimInfoRange left, right; + split(children[bestChild],left,right,aligned); + + /* add new children left and right */ + children[bestChild] = children[numChildren-1]; + children[numChildren-1] = left; + children[numChildren+0] = right; + numChildren++; + + } while (numChildren < cfg.branchingFactor); + + NodeRef node; + + /* create aligned node */ + if (aligned) + { + node = createAABBNode(alloc); + + /* spawn tasks or ... */ + if (pinfo.size() > SINGLE_THREADED_THRESHOLD) + { + parallel_for(size_t(0), numChildren, [&] (const range& r) { + for (size_t i=r.begin(); i cfg.finished_range_threshold && children[i].size() <= cfg.finished_range_threshold; + setAABBNode(node,i,recurse(depth+1,children[i],nullptr,true,child_alloc_barrier),children[i].geomBounds); + _mm_mfence(); // to allow non-temporal stores during build + } + }); + } + /* ... continue sequentially */ + else { + for (size_t i=0; i cfg.finished_range_threshold && children[i].size() <= cfg.finished_range_threshold; + setAABBNode(node,i,recurse(depth+1,children[i],alloc,false,child_alloc_barrier),children[i].geomBounds); + } + } + } + + /* create unaligned node */ + else + { + node = createOBBNode(alloc); + + /* spawn tasks or ... */ + if (pinfo.size() > SINGLE_THREADED_THRESHOLD) + { + parallel_for(size_t(0), numChildren, [&] (const range& r) { + for (size_t i=r.begin(); i cfg.finished_range_threshold && children[i].size() <= cfg.finished_range_threshold; + setOBBNode(node,i,recurse(depth+1,children[i],nullptr,true,child_alloc_barrier),obounds); + _mm_mfence(); // to allow non-temporal stores during build + } + }); + } + /* ... 
continue sequentially */ + else + { + for (size_t i=0; i cfg.finished_range_threshold && children[i].size() <= cfg.finished_range_threshold; + setOBBNode(node,i,recurse(depth+1,children[i],alloc,false,child_alloc_barrier),obounds); + } + } + } + + /* reports a finished range of primrefs */ + if (unlikely(alloc_barrier)) + reportFinishedRange(pinfo); + + return node; + } + + private: + Settings cfg; + PrimRef* prims; + const CreateAllocFunc& createAlloc; + const CreateAABBNodeFunc& createAABBNode; + const SetAABBNodeFunc& setAABBNode; + const CreateOBBNodeFunc& createOBBNode; + const SetOBBNodeFunc& setOBBNode; + const CreateLeafFunc& createLeaf; + const ProgressMonitor& progressMonitor; + const ReportFinishedRangeFunc& reportFinishedRange; + + private: + HeuristicBinningSAH alignedHeuristic; + UnalignedHeuristicBinningSAH unalignedHeuristic; + HeuristicStrandSplitSAH strandHeuristic; + }; + + template + + static NodeRef build (const CreateAllocFunc& createAlloc, + const CreateAABBNodeFunc& createAABBNode, + const SetAABBNodeFunc& setAABBNode, + const CreateOBBNodeFunc& createOBBNode, + const SetOBBNodeFunc& setOBBNode, + const CreateLeafFunc& createLeaf, + const ProgressMonitor& progressMonitor, + const ReportFinishedRangeFunc& reportFinishedRange, + Scene* scene, + PrimRef* prims, + const PrimInfo& pinfo, + const Settings settings) + { + typedef BuilderT Builder; + + Builder builder(scene,prims,createAlloc, + createAABBNode,setAABBNode, + createOBBNode,setOBBNode, + createLeaf,progressMonitor,reportFinishedRange,settings); + + NodeRef root = builder.recurse(1,pinfo,nullptr,true,false); + _mm_mfence(); // to allow non-temporal stores during build + return root; + } + }; + } +} diff --git a/thirdparty/embree/kernels/builders/bvh_builder_morton.h b/thirdparty/embree/kernels/builders/bvh_builder_morton.h new file mode 100644 index 000000000000..92be2f7e65c2 --- /dev/null +++ b/thirdparty/embree/kernels/builders/bvh_builder_morton.h @@ -0,0 +1,501 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/builder.h" +#include "../../common/algorithms/parallel_reduce.h" + +namespace embree +{ + namespace isa + { + struct BVHBuilderMorton + { + static const size_t MAX_BRANCHING_FACTOR = 8; //!< maximum supported BVH branching factor + static const size_t MIN_LARGE_LEAF_LEVELS = 8; //!< create balanced tree of we are that many levels before the maximum tree depth + + /*! settings for morton builder */ + struct Settings + { + /*! default settings */ + Settings () + : branchingFactor(2), maxDepth(32), minLeafSize(1), maxLeafSize(7), singleThreadThreshold(1024) {} + + /*! 
initialize settings from API settings */ + Settings (const RTCBuildArguments& settings) + : branchingFactor(2), maxDepth(32), minLeafSize(1), maxLeafSize(7), singleThreadThreshold(1024) + { + if (RTC_BUILD_ARGUMENTS_HAS(settings,maxBranchingFactor)) branchingFactor = settings.maxBranchingFactor; + if (RTC_BUILD_ARGUMENTS_HAS(settings,maxDepth )) maxDepth = settings.maxDepth; + if (RTC_BUILD_ARGUMENTS_HAS(settings,minLeafSize )) minLeafSize = settings.minLeafSize; + if (RTC_BUILD_ARGUMENTS_HAS(settings,maxLeafSize )) maxLeafSize = settings.maxLeafSize; + + minLeafSize = min(minLeafSize,maxLeafSize); + } + + Settings (size_t branchingFactor, size_t maxDepth, size_t minLeafSize, size_t maxLeafSize, size_t singleThreadThreshold) + : branchingFactor(branchingFactor), maxDepth(maxDepth), minLeafSize(minLeafSize), maxLeafSize(maxLeafSize), singleThreadThreshold(singleThreadThreshold) + { + minLeafSize = min(minLeafSize,maxLeafSize); + } + + public: + size_t branchingFactor; //!< branching factor of BVH to build + size_t maxDepth; //!< maximum depth of BVH to build + size_t minLeafSize; //!< minimum size of a leaf + size_t maxLeafSize; //!< maximum size of a leaf + size_t singleThreadThreshold; //!< threshold when we switch to single threaded build + }; + + /*! Build primitive consisting of morton code and primitive ID. */ + struct __aligned(8) BuildPrim + { + union { + struct { + unsigned int code; //!< morton code + unsigned int index; //!< i'th primitive + }; + uint64_t t; + }; + + /*! interface for radix sort */ + __forceinline operator unsigned() const { return code; } + + /*! interface for standard sort */ + __forceinline bool operator<(const BuildPrim &m) const { return code < m.code; } + }; + + /*! maps bounding box to morton code */ + struct MortonCodeMapping + { + static const size_t LATTICE_BITS_PER_DIM = 10; + static const size_t LATTICE_SIZE_PER_DIM = size_t(1) << LATTICE_BITS_PER_DIM; + + vfloat4 base; + vfloat4 scale; + + __forceinline MortonCodeMapping(const BBox3fa& bounds) + { + base = (vfloat4)bounds.lower; + const vfloat4 diag = (vfloat4)bounds.upper - (vfloat4)bounds.lower; + scale = select(diag > vfloat4(1E-19f), rcp(diag) * vfloat4(LATTICE_SIZE_PER_DIM * 0.99f),vfloat4(0.0f)); + } + + __forceinline const vint4 bin (const BBox3fa& box) const + { + const vfloat4 lower = (vfloat4)box.lower; + const vfloat4 upper = (vfloat4)box.upper; + const vfloat4 centroid = lower+upper; + return vint4((centroid-base)*scale); + } + + __forceinline unsigned int code (const BBox3fa& box) const + { + const vint4 binID = bin(box); + const unsigned int x = extract<0>(binID); + const unsigned int y = extract<1>(binID); + const unsigned int z = extract<2>(binID); + const unsigned int xyz = bitInterleave(x,y,z); + return xyz; + } + }; + +#if defined (__AVX2__) + + /*! for AVX2 there is a fast scalar bitInterleave */ + struct MortonCodeGenerator + { + __forceinline MortonCodeGenerator(const MortonCodeMapping& mapping, BuildPrim* dest) + : mapping(mapping), dest(dest) {} + + __forceinline void operator() (const BBox3fa& b, const unsigned index) + { + dest->index = index; + dest->code = mapping.code(b); + dest++; + } + + public: + const MortonCodeMapping mapping; + BuildPrim* dest; + size_t currentID; + }; + +#else + + /*! 
before AVX2 is it better to use the SSE version of bitInterleave */ + struct MortonCodeGenerator + { + __forceinline MortonCodeGenerator(const MortonCodeMapping& mapping, BuildPrim* dest) + : mapping(mapping), dest(dest), currentID(0), slots(0), ax(0), ay(0), az(0), ai(0) {} + + __forceinline ~MortonCodeGenerator() + { + if (slots != 0) + { + const vint4 code = bitInterleave(ax,ay,az); + for (size_t i=0; i(binID); + ay[slots] = extract<1>(binID); + az[slots] = extract<2>(binID); + ai[slots] = index; + slots++; + currentID++; + + if (slots == 4) + { + const vint4 code = bitInterleave(ax,ay,az); + vint4::storeu(&dest[currentID-4],unpacklo(code,ai)); + vint4::storeu(&dest[currentID-2],unpackhi(code,ai)); + slots = 0; + } + } + + public: + const MortonCodeMapping mapping; + BuildPrim* dest; + size_t currentID; + size_t slots; + vint4 ax, ay, az, ai; + }; + +#endif + + template< + typename ReductionTy, + typename Allocator, + typename CreateAllocator, + typename CreateNodeFunc, + typename SetNodeBoundsFunc, + typename CreateLeafFunc, + typename CalculateBounds, + typename ProgressMonitor> + + class BuilderT : private Settings + { + ALIGNED_CLASS_(16); + + public: + + BuilderT (CreateAllocator& createAllocator, + CreateNodeFunc& createNode, + SetNodeBoundsFunc& setBounds, + CreateLeafFunc& createLeaf, + CalculateBounds& calculateBounds, + ProgressMonitor& progressMonitor, + const Settings& settings) + + : Settings(settings), + createAllocator(createAllocator), + createNode(createNode), + setBounds(setBounds), + createLeaf(createLeaf), + calculateBounds(calculateBounds), + progressMonitor(progressMonitor), + morton(nullptr) {} + + ReductionTy createLargeLeaf(size_t depth, const range& current, Allocator alloc) + { + /* this should never occur but is a fatal error */ + if (depth > maxDepth) + throw_RTCError(RTC_ERROR_UNKNOWN,"depth limit reached"); + + /* create leaf for few primitives */ + if (current.size() <= maxLeafSize) + return createLeaf(current,alloc); + + /* fill all children by always splitting the largest one */ + range children[MAX_BRANCHING_FACTOR]; + size_t numChildren = 1; + children[0] = current; + + do { + + /* find best child with largest number of primitives */ + size_t bestChild = -1; + size_t bestSize = 0; + for (size_t i=0; i bestSize) { + bestSize = children[i].size(); + bestChild = i; + } + } + if (bestChild == size_t(-1)) break; + + /*! split best child into left and right child */ + auto split = children[bestChild].split(); + + /* add new children left and right */ + children[bestChild] = children[numChildren-1]; + children[numChildren-1] = split.first; + children[numChildren+0] = split.second; + numChildren++; + + } while (numChildren < branchingFactor); + + /* create node */ + auto node = createNode(alloc,numChildren); + + /* recurse into each child */ + ReductionTy bounds[MAX_BRANCHING_FACTOR]; + for (size_t i=0; i& current) const + { + /* fast path for small ranges */ + if (likely(current.size() < 1024)) + { + /*! 
recalculate centroid bounds */ + BBox3fa centBounds(empty); + for (size_t i=current.begin(); i& r ) { + BBox3fa centBounds = empty; + for (size_t i=r.begin(); i& r ) { + for (size_t i=r.begin(); i& current, range& left, range& right) const + { + const unsigned int code_start = morton[current.begin()].code; + const unsigned int code_end = morton[current.end()-1].code; + unsigned int bitpos = lzcnt(code_start^code_end); + + /* if all items mapped to same morton code, then re-create new morton codes for the items */ + if (unlikely(bitpos == 32)) + { + recreateMortonCodes(current); + const unsigned int code_start = morton[current.begin()].code; + const unsigned int code_end = morton[current.end()-1].code; + bitpos = lzcnt(code_start^code_end); + + /* if the morton code is still the same, goto fall back split */ + if (unlikely(bitpos == 32)) { + current.split(left,right); + return; + } + } + + /* split the items at the topmost different morton code bit */ + const unsigned int bitpos_diff = 31-bitpos; + const unsigned int bitmask = 1 << bitpos_diff; + + /* find location where bit differs using binary search */ + unsigned begin = current.begin(); + unsigned end = current.end(); + while (begin + 1 != end) { + const unsigned mid = (begin+end)/2; + const unsigned bit = morton[mid].code & bitmask; + if (bit == 0) begin = mid; else end = mid; + } + unsigned center = end; +#if defined(DEBUG) + for (unsigned int i=begin; i& current, Allocator alloc, bool toplevel) + { + /* get thread local allocator */ + if (!alloc) + alloc = createAllocator(); + + /* call memory monitor function to signal progress */ + if (toplevel && current.size() <= singleThreadThreshold) + progressMonitor(current.size()); + + /* create leaf node */ + if (unlikely(depth+MIN_LARGE_LEAF_LEVELS >= maxDepth || current.size() <= minLeafSize)) + return createLargeLeaf(depth,current,alloc); + + /* fill all children by always splitting the one with the largest surface area */ + range children[MAX_BRANCHING_FACTOR]; + split(current,children[0],children[1]); + size_t numChildren = 2; + + while (numChildren < branchingFactor) + { + /* find best child with largest number of primitives */ + int bestChild = -1; + unsigned bestItems = 0; + for (unsigned int i=0; i bestItems) { + bestItems = children[i].size(); + bestChild = i; + } + } + if (bestChild == -1) break; + + /*! split best child into left and right child */ + range left, right; + split(children[bestChild],left,right); + + /* add new children left and right */ + children[bestChild] = children[numChildren-1]; + children[numChildren-1] = left; + children[numChildren+0] = right; + numChildren++; + } + + /* create leaf node if no split is possible */ + if (unlikely(numChildren == 1)) + return createLeaf(current,alloc); + + /* allocate node */ + auto node = createNode(alloc,numChildren); + + /* process top parts of tree parallel */ + ReductionTy bounds[MAX_BRANCHING_FACTOR]; + if (current.size() > singleThreadThreshold) + { + /*! 
parallel_for is faster than spawing sub-tasks */ + parallel_for(size_t(0), numChildren, [&] (const range& r) { + for (size_t i=r.begin(); i(0,(unsigned)numPrimitives), nullptr, true); + _mm_mfence(); // to allow non-temporal stores during build + return root; + } + + public: + CreateAllocator& createAllocator; + CreateNodeFunc& createNode; + SetNodeBoundsFunc& setBounds; + CreateLeafFunc& createLeaf; + CalculateBounds& calculateBounds; + ProgressMonitor& progressMonitor; + + public: + BuildPrim* morton; + }; + + + template< + typename ReductionTy, + typename CreateAllocFunc, + typename CreateNodeFunc, + typename SetBoundsFunc, + typename CreateLeafFunc, + typename CalculateBoundsFunc, + typename ProgressMonitor> + + static ReductionTy build(CreateAllocFunc createAllocator, + CreateNodeFunc createNode, + SetBoundsFunc setBounds, + CreateLeafFunc createLeaf, + CalculateBoundsFunc calculateBounds, + ProgressMonitor progressMonitor, + BuildPrim* src, + BuildPrim* tmp, + size_t numPrimitives, + const Settings& settings) + { + typedef BuilderT< + ReductionTy, + decltype(createAllocator()), + CreateAllocFunc, + CreateNodeFunc, + SetBoundsFunc, + CreateLeafFunc, + CalculateBoundsFunc, + ProgressMonitor> Builder; + + Builder builder(createAllocator, + createNode, + setBounds, + createLeaf, + calculateBounds, + progressMonitor, + settings); + + return builder.build(src,tmp,numPrimitives); + } + }; + } +} diff --git a/thirdparty/embree/kernels/builders/bvh_builder_msmblur.h b/thirdparty/embree/kernels/builders/bvh_builder_msmblur.h new file mode 100644 index 000000000000..4c138dacdbd4 --- /dev/null +++ b/thirdparty/embree/kernels/builders/bvh_builder_msmblur.h @@ -0,0 +1,692 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#define MBLUR_NUM_TEMPORAL_BINS 2 +#define MBLUR_NUM_OBJECT_BINS 32 + +#include "../bvh/bvh.h" +#include "../common/primref_mb.h" +#include "heuristic_binning_array_aligned.h" +#include "heuristic_timesplit_array.h" + +namespace embree +{ + namespace isa + { + template + struct SharedVector + { + __forceinline SharedVector() {} + + __forceinline SharedVector(T* ptr, size_t refCount = 1) + : prims(ptr), refCount(refCount) {} + + __forceinline void incRef() { + refCount++; + } + + __forceinline void decRef() + { + if (--refCount == 0) + delete prims; + } + + T* prims; + size_t refCount; + }; + + template + struct LocalChildListT + { + typedef SharedVector> SharedPrimRefVector; + + __forceinline LocalChildListT (const BuildRecord& record) + : numChildren(1), numSharedPrimVecs(1) + { + /* the local root will be freed in the ancestor where it was created (thus refCount is 2) */ + children[0] = record; + primvecs[0] = new (&sharedPrimVecs[0]) SharedPrimRefVector(record.prims.prims, 2); + } + + __forceinline ~LocalChildListT() + { + for (size_t i = 0; i < numChildren; i++) + primvecs[i]->decRef(); + } + + __forceinline BuildRecord& operator[] ( const size_t i ) { + return children[i]; + } + + __forceinline size_t size() const { + return numChildren; + } + + __forceinline void split(ssize_t bestChild, const BuildRecord& lrecord, const BuildRecord& rrecord, std::unique_ptr> new_vector) + { + SharedPrimRefVector* bsharedPrimVec = primvecs[bestChild]; + if (lrecord.prims.prims == bsharedPrimVec->prims) { + primvecs[bestChild] = bsharedPrimVec; + bsharedPrimVec->incRef(); + } + else { + primvecs[bestChild] = new (&sharedPrimVecs[numSharedPrimVecs++]) SharedPrimRefVector(lrecord.prims.prims); + } + + if (rrecord.prims.prims == 
bsharedPrimVec->prims) { + primvecs[numChildren] = bsharedPrimVec; + bsharedPrimVec->incRef(); + } + else { + primvecs[numChildren] = new (&sharedPrimVecs[numSharedPrimVecs++]) SharedPrimRefVector(rrecord.prims.prims); + } + bsharedPrimVec->decRef(); + new_vector.release(); + + children[bestChild] = lrecord; + children[numChildren] = rrecord; + numChildren++; + } + + public: + array_t children; + array_t primvecs; + size_t numChildren; + + array_t sharedPrimVecs; + size_t numSharedPrimVecs; + }; + + template + struct RecalculatePrimRef + { + Scene* scene; + + __forceinline RecalculatePrimRef (Scene* scene) + : scene(scene) {} + + __forceinline PrimRefMB operator() (const PrimRefMB& prim, const BBox1f time_range) const + { + const unsigned geomID = prim.geomID(); + const unsigned primID = prim.primID(); + const Mesh* mesh = scene->get(geomID); + const LBBox3fa lbounds = mesh->linearBounds(primID, time_range); + const range tbounds = mesh->timeSegmentRange(time_range); + return PrimRefMB (lbounds, tbounds.size(), mesh->time_range, mesh->numTimeSegments(), geomID, primID); + } + + // __noinline is workaround for ICC16 bug under MacOSX + __noinline PrimRefMB operator() (const PrimRefMB& prim, const BBox1f time_range, const LinearSpace3fa& space) const + { + const unsigned geomID = prim.geomID(); + const unsigned primID = prim.primID(); + const Mesh* mesh = scene->get(geomID); + const LBBox3fa lbounds = mesh->linearBounds(space, primID, time_range); + const range tbounds = mesh->timeSegmentRange(time_range); + return PrimRefMB (lbounds, tbounds.size(), mesh->time_range, mesh->numTimeSegments(), geomID, primID); + } + + __forceinline LBBox3fa linearBounds(const PrimRefMB& prim, const BBox1f time_range) const { + return scene->get(prim.geomID())->linearBounds(prim.primID(), time_range); + } + + // __noinline is workaround for ICC16 bug under MacOSX + __noinline LBBox3fa linearBounds(const PrimRefMB& prim, const BBox1f time_range, const LinearSpace3fa& space) const { + return scene->get(prim.geomID())->linearBounds(space, prim.primID(), time_range); + } + }; + + struct VirtualRecalculatePrimRef + { + Scene* scene; + + __forceinline VirtualRecalculatePrimRef (Scene* scene) + : scene(scene) {} + + __forceinline PrimRefMB operator() (const PrimRefMB& prim, const BBox1f time_range) const + { + const unsigned geomID = prim.geomID(); + const unsigned primID = prim.primID(); + const Geometry* mesh = scene->get(geomID); + const LBBox3fa lbounds = mesh->vlinearBounds(primID, time_range); + const range tbounds = mesh->timeSegmentRange(time_range); + return PrimRefMB (lbounds, tbounds.size(), mesh->time_range, mesh->numTimeSegments(), geomID, primID); + } + + __forceinline PrimRefMB operator() (const PrimRefMB& prim, const BBox1f time_range, const LinearSpace3fa& space) const + { + const unsigned geomID = prim.geomID(); + const unsigned primID = prim.primID(); + const Geometry* mesh = scene->get(geomID); + const LBBox3fa lbounds = mesh->vlinearBounds(space, primID, time_range); + const range tbounds = mesh->timeSegmentRange(time_range); + return PrimRefMB (lbounds, tbounds.size(), mesh->time_range, mesh->numTimeSegments(), geomID, primID); + } + + __forceinline LBBox3fa linearBounds(const PrimRefMB& prim, const BBox1f time_range) const { + return scene->get(prim.geomID())->vlinearBounds(prim.primID(), time_range); + } + + __forceinline LBBox3fa linearBounds(const PrimRefMB& prim, const BBox1f time_range, const LinearSpace3fa& space) const { + return scene->get(prim.geomID())->vlinearBounds(space, 
prim.primID(), time_range); + } + }; + + struct BVHBuilderMSMBlur + { + /*! settings for msmblur builder */ + struct Settings + { + /*! default settings */ + Settings () + : branchingFactor(2), maxDepth(32), logBlockSize(0), minLeafSize(1), maxLeafSize(8), + travCost(1.0f), intCost(1.0f), singleLeafTimeSegment(false), + singleThreadThreshold(1024) {} + + + Settings (size_t sahBlockSize, size_t minLeafSize, size_t maxLeafSize, float travCost, float intCost, size_t singleThreadThreshold) + : branchingFactor(2), maxDepth(32), logBlockSize(bsr(sahBlockSize)), minLeafSize(minLeafSize), maxLeafSize(maxLeafSize), + travCost(travCost), intCost(intCost), singleThreadThreshold(singleThreadThreshold) + { + minLeafSize = min(minLeafSize,maxLeafSize); + } + + public: + size_t branchingFactor; //!< branching factor of BVH to build + size_t maxDepth; //!< maximum depth of BVH to build + size_t logBlockSize; //!< log2 of blocksize for SAH heuristic + size_t minLeafSize; //!< minimum size of a leaf + size_t maxLeafSize; //!< maximum size of a leaf + float travCost; //!< estimated cost of one traversal step + float intCost; //!< estimated cost of one primitive intersection + bool singleLeafTimeSegment; //!< split time to single time range + size_t singleThreadThreshold; //!< threshold when we switch to single threaded build + }; + + struct BuildRecord + { + public: + __forceinline BuildRecord () {} + + __forceinline BuildRecord (size_t depth) + : depth(depth) {} + + __forceinline BuildRecord (const SetMB& prims, size_t depth) + : depth(depth), prims(prims) {} + + __forceinline friend bool operator< (const BuildRecord& a, const BuildRecord& b) { + return a.prims.size() < b.prims.size(); + } + + __forceinline size_t size() const { + return prims.size(); + } + + public: + size_t depth; //!< Depth of the root of this subtree. + SetMB prims; //!< The list of primitives. 
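The cost constants documented in the Settings block above (travCost, intCost, logBlockSize, minLeafSize, maxLeafSize) feed the usual surface-area-heuristic stopping rule that these builders apply at every recursion step. The following is a minimal standalone sketch of that rule only, under simplified assumptions; the helper names are hypothetical and the code is not taken from the Embree sources in this patch.

// Minimal SAH sketch (illustrative only; simplified relative to the builder above).
#include <cstddef>

struct SahSettings {
  float  travCost     = 1.0f;  // estimated cost of one traversal step
  float  intCost      = 1.0f;  // estimated cost of one primitive intersection
  size_t logBlockSize = 0;     // primitives are costed in blocks of 2^logBlockSize
  size_t minLeafSize  = 1;     // below this size a leaf is always created
  size_t maxLeafSize  = 7;     // above this size a split is always attempted
};

// Estimated cost of keeping numPrims primitives in a leaf with surface area 'area'.
inline float leafSAH(const SahSettings& s, float area, size_t numPrims) {
  const size_t blocks = (numPrims + (size_t(1) << s.logBlockSize) - 1) >> s.logBlockSize;
  return s.intCost * area * float(blocks);
}

// Estimated cost of splitting: one traversal step over the parent plus the
// (area-weighted) SAH costs of the two children.
inline float splitSAH(const SahSettings& s, float parentArea, float leftSAH, float rightSAH) {
  return s.travCost * parentArea + leftSAH + rightSAH;
}

// Create a leaf when the node is small enough and no split is cheaper.
inline bool createLeafInsteadOfSplit(const SahSettings& s, size_t numPrims, float leaf, float split) {
  return numPrims <= s.minLeafSize || (numPrims <= s.maxLeafSize && leaf <= split);
}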
+ }; + + struct BuildRecordSplit : public BuildRecord + { + __forceinline BuildRecordSplit () {} + + __forceinline BuildRecordSplit (size_t depth) + : BuildRecord(depth) {} + + __forceinline BuildRecordSplit (const BuildRecord& record, const BinSplit& split) + : BuildRecord(record), split(split) {} + + BinSplit split; + }; + + template< + typename NodeRef, + typename RecalculatePrimRef, + typename Allocator, + typename CreateAllocFunc, + typename CreateNodeFunc, + typename SetNodeFunc, + typename CreateLeafFunc, + typename ProgressMonitor> + + class BuilderT + { + ALIGNED_CLASS_(16); + static const size_t MAX_BRANCHING_FACTOR = 16; //!< maximum supported BVH branching factor + static const size_t MIN_LARGE_LEAF_LEVELS = 8; //!< create balanced tree if we are that many levels before the maximum tree depth + + typedef BVHNodeRecordMB4D NodeRecordMB4D; + typedef BinSplit Split; + typedef mvector* PrimRefVector; + typedef SharedVector> SharedPrimRefVector; + typedef LocalChildListT LocalChildList; + typedef LocalChildListT LocalChildListSplit; + + public: + + BuilderT (MemoryMonitorInterface* device, + const RecalculatePrimRef recalculatePrimRef, + const CreateAllocFunc createAlloc, + const CreateNodeFunc createNode, + const SetNodeFunc setNode, + const CreateLeafFunc createLeaf, + const ProgressMonitor progressMonitor, + const Settings& settings) + : cfg(settings), + heuristicObjectSplit(), + heuristicTemporalSplit(device, recalculatePrimRef), + recalculatePrimRef(recalculatePrimRef), createAlloc(createAlloc), createNode(createNode), setNode(setNode), createLeaf(createLeaf), + progressMonitor(progressMonitor) + { + if (cfg.branchingFactor > MAX_BRANCHING_FACTOR) + throw_RTCError(RTC_ERROR_UNKNOWN,"bvh_builder: branching factor too large"); + } + + /*! finds the best split */ + const Split find(const SetMB& set) + { + /* first try standard object split */ + const Split object_split = heuristicObjectSplit.find(set,cfg.logBlockSize); + const float object_split_sah = object_split.splitSAH(); + + /* test temporal splits only when object split was bad */ + const float leaf_sah = set.leafSAH(cfg.logBlockSize); + if (object_split_sah < 0.50f*leaf_sah) + return object_split; + + /* do temporal splits only if the the time range is big enough */ + if (set.time_range.size() > 1.01f/float(set.max_num_time_segments)) + { + const Split temporal_split = heuristicTemporalSplit.find(set,cfg.logBlockSize); + const float temporal_split_sah = temporal_split.splitSAH(); + + /* take temporal split if it improved SAH */ + if (temporal_split_sah < object_split_sah) + return temporal_split; + } + + return object_split; + } + + /*! array partitioning */ + __forceinline std::unique_ptr> split(const Split& split, const SetMB& set, SetMB& lset, SetMB& rset) + { + /* perform object split */ + if (likely(split.data == Split::SPLIT_OBJECT)) { + heuristicObjectSplit.split(split,set,lset,rset); + } + /* perform temporal split */ + else if (likely(split.data == Split::SPLIT_TEMPORAL)) { + return heuristicTemporalSplit.split(split,set,lset,rset); + } + /* perform fallback split */ + else if (unlikely(split.data == Split::SPLIT_FALLBACK)) { + set.deterministic_order(); + splitFallback(set,lset,rset); + } + /* split by geometry */ + else if (unlikely(split.data == Split::SPLIT_GEOMID)) { + set.deterministic_order(); + splitByGeometry(set,lset,rset); + } + else + assert(false); + + return std::unique_ptr>(); + } + + /*! 
finds the best fallback split */ + __noinline Split findFallback(const SetMB& set) + { + /* split if primitives are not from same geometry */ + if (!sameGeometry(set)) + return Split(0.0f,Split::SPLIT_GEOMID); + + /* if a leaf can only hold a single time-segment, we might have to do additional temporal splits */ + if (cfg.singleLeafTimeSegment) + { + /* test if one primitive has more than one time segment in time range, if so split time */ + for (size_t i=set.begin(); i itime_range = prim.timeSegmentRange(set.time_range); + const int localTimeSegments = itime_range.size(); + assert(localTimeSegments > 0); + if (localTimeSegments > 1) { + const int icenter = (itime_range.begin() + itime_range.end())/2; + const float splitTime = prim.timeStep(icenter); + return Split(0.0f,(unsigned)Split::SPLIT_TEMPORAL,0,splitTime); + } + } + } + + /* otherwise return fallback split */ + return Split(0.0f,Split::SPLIT_FALLBACK); + } + + /*! performs fallback split */ + void splitFallback(const SetMB& set, SetMB& lset, SetMB& rset) + { + mvector& prims = *set.prims; + + const size_t begin = set.begin(); + const size_t end = set.end(); + const size_t center = (begin + end)/2; + + PrimInfoMB linfo = empty; + for (size_t i=begin; i(begin,center),set.time_range); + new (&rset) SetMB(rinfo,set.prims,range(center,end ),set.time_range); + } + + /*! checks if all primitives are from the same geometry */ + __forceinline bool sameGeometry(const SetMB& set) + { + if (set.size() == 0) return true; + mvector& prims = *set.prims; + const size_t begin = set.begin(); + const size_t end = set.end(); + unsigned int firstGeomID = prims[begin].geomID(); + for (size_t i=begin+1; i 1); + + mvector& prims = *set.prims; + const size_t begin = set.begin(); + const size_t end = set.end(); + + PrimInfoMB left(empty); + PrimInfoMB right(empty); + unsigned int geomID = prims[begin].geomID(); + size_t center = serial_partitioning(prims.data(),begin,end,left,right, + [&] ( const PrimRefMB& prim ) { return prim.geomID() == geomID; }, + [ ] ( PrimInfoMB& dst, const PrimRefMB& prim ) { dst.add_primref(prim); }); + + new (&lset) SetMB(left, set.prims,range(begin,center),set.time_range); + new (&rset) SetMB(right,set.prims,range(center,end ),set.time_range); + } + + const NodeRecordMB4D createLargeLeaf(const BuildRecord& in, Allocator alloc) + { + /* this should never occur but is a fatal error */ + if (in.depth > cfg.maxDepth) + throw_RTCError(RTC_ERROR_UNKNOWN,"depth limit reached"); + + /* replace already found split by fallback split */ + const BuildRecordSplit current(BuildRecord(in.prims,in.depth),findFallback(in.prims)); + + /* special case when directly creating leaf without any splits that could shrink time_range */ + bool force_split = false; + if (current.depth == 1 && current.size() > 0) + { + BBox1f c = empty; + BBox1f p = current.prims.time_range; + for (size_t i=current.prims.begin(); i& prims = *current.prims.prims; + c.extend(prims[i].time_range); + } + + force_split = c.lower > p.lower || c.upper < p.upper; + } + + /* create leaf for few primitives */ + if (current.size() <= cfg.maxLeafSize && current.split.data < Split::SPLIT_ENFORCE && !force_split) + return createLeaf(current,alloc); + + /* fill all children by always splitting the largest one */ + bool hasTimeSplits = false; + NodeRecordMB4D values[MAX_BRANCHING_FACTOR]; + LocalChildListSplit children(current); + + do { + /* find best child with largest bounding box area */ + size_t bestChild = -1; + size_t bestSize = 0; + for (size_t i=0; i bestSize) { + bestSize = 
children[i].size(); + bestChild = i; + } + } + if (bestChild == -1) break; + + /* perform best found split */ + BuildRecordSplit& brecord = children[bestChild]; + BuildRecordSplit lrecord(current.depth+1); + BuildRecordSplit rrecord(current.depth+1); + std::unique_ptr> new_vector = split(brecord.split,brecord.prims,lrecord.prims,rrecord.prims); + hasTimeSplits |= new_vector != nullptr; + + /* find new splits */ + lrecord.split = findFallback(lrecord.prims); + rrecord.split = findFallback(rrecord.prims); + children.split(bestChild,lrecord,rrecord,std::move(new_vector)); + + } while (children.size() < cfg.branchingFactor); + + /* detect time_ranges that have shrunken */ + for (size_t i=0; i p.lower || c.upper < p.upper; + } + + /* create node */ + auto node = createNode(children.children.data(),children.numChildren,alloc,hasTimeSplits); + + /* recurse into each child and perform reduction */ + LBBox3fa gbounds = empty; + for (size_t i=0; i= 0) && (splitSAH >= 0))); + + /*! create a leaf node when threshold reached or SAH tells us to stop */ + if (current.size() <= cfg.minLeafSize || current.depth+MIN_LARGE_LEAF_LEVELS >= cfg.maxDepth || (current.size() <= cfg.maxLeafSize && leafSAH <= splitSAH)) { + current.prims.deterministic_order(); + return createLargeLeaf(current,alloc); + } + + /*! perform initial split */ + SetMB lprims,rprims; + std::unique_ptr> new_vector = split(csplit,current.prims,lprims,rprims); + bool hasTimeSplits = new_vector != nullptr; + NodeRecordMB4D values[MAX_BRANCHING_FACTOR]; + LocalChildList children(current); + { + BuildRecord lrecord(lprims,current.depth+1); + BuildRecord rrecord(rprims,current.depth+1); + children.split(0,lrecord,rrecord,std::move(new_vector)); + } + + /*! split until node is full or SAH tells us to stop */ + while (children.size() < cfg.branchingFactor) + { + /*! find best child to split */ + float bestArea = neg_inf; + ssize_t bestChild = -1; + for (size_t i=0; i bestArea) { + bestChild = i; bestArea = expectedApproxHalfArea(children[i].prims.geomBounds); + } + } + if (bestChild == -1) break; + + /* perform split */ + BuildRecord& brecord = children[bestChild]; + BuildRecord lrecord(current.depth+1); + BuildRecord rrecord(current.depth+1); + Split csplit = find(brecord.prims); + std::unique_ptr> new_vector = split(csplit,brecord.prims,lrecord.prims,rrecord.prims); + hasTimeSplits |= new_vector != nullptr; + children.split(bestChild,lrecord,rrecord,std::move(new_vector)); + } + + /* detect time_ranges that have shrunken */ + for (size_t i=0; i p.lower || c.upper < p.upper; + } + + /* sort buildrecords for simpler shadow ray traversal */ + //std::sort(&children[0],&children[children.size()],std::greater()); // FIXME: reduces traversal performance of bvh8.triangle4 (need to verified) !! + + /*! create an inner node */ + auto node = createNode(children.children.data(), children.numChildren, alloc, hasTimeSplits); + LBBox3fa gbounds = empty; + + /* spawn tasks */ + if (unlikely(current.size() > cfg.singleThreadThreshold)) + { + /*! 
parallel_for is faster than spawing sub-tasks */ + parallel_for(size_t(0), children.size(), [&] (const range& r) { + for (size_t i=r.begin(); i=0; i--) { + values[i] = recurse(children[i],alloc,false); + gbounds.extend(values[i].lbounds); + } + } + + setNode(current,children.children.data(),node,values,children.numChildren); + + /* calculate geometry bounds of this node */ + if (unlikely(hasTimeSplits)) + return NodeRecordMB4D(node,current.prims.linearBounds(recalculatePrimRef),current.prims.time_range); + else + return NodeRecordMB4D(node,gbounds,current.prims.time_range); + } + + /*! builder entry function */ + __forceinline const NodeRecordMB4D operator() (mvector& prims, const PrimInfoMB& pinfo) + { + const SetMB set(pinfo,&prims); + auto ret = recurse(BuildRecord(set,1),nullptr,true); + _mm_mfence(); // to allow non-temporal stores during build + return ret; + } + + private: + Settings cfg; + HeuristicArrayBinningMB heuristicObjectSplit; + HeuristicMBlurTemporalSplit heuristicTemporalSplit; + const RecalculatePrimRef recalculatePrimRef; + const CreateAllocFunc createAlloc; + const CreateNodeFunc createNode; + const SetNodeFunc setNode; + const CreateLeafFunc createLeaf; + const ProgressMonitor progressMonitor; + }; + + template + + static const BVHNodeRecordMB4D build(mvector& prims, + const PrimInfoMB& pinfo, + MemoryMonitorInterface* device, + const RecalculatePrimRef recalculatePrimRef, + const CreateAllocFunc createAlloc, + const CreateNodeFunc createNode, + const SetNodeFunc setNode, + const CreateLeafFunc createLeaf, + const ProgressMonitorFunc progressMonitor, + const Settings& settings) + { + typedef BuilderT< + NodeRef, + RecalculatePrimRef, + decltype(createAlloc()), + CreateAllocFunc, + CreateNodeFunc, + SetNodeFunc, + CreateLeafFunc, + ProgressMonitorFunc> Builder; + + Builder builder(device, + recalculatePrimRef, + createAlloc, + createNode, + setNode, + createLeaf, + progressMonitor, + settings); + + + return builder(prims,pinfo); + } + }; + } +} diff --git a/thirdparty/embree/kernels/builders/bvh_builder_msmblur_hair.h b/thirdparty/embree/kernels/builders/bvh_builder_msmblur_hair.h new file mode 100644 index 000000000000..e477c313a381 --- /dev/null +++ b/thirdparty/embree/kernels/builders/bvh_builder_msmblur_hair.h @@ -0,0 +1,526 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../bvh/bvh.h" +#include "../geometry/primitive.h" +#include "../builders/bvh_builder_msmblur.h" +#include "../builders/heuristic_binning_array_aligned.h" +#include "../builders/heuristic_binning_array_unaligned.h" +#include "../builders/heuristic_timesplit_array.h" + +namespace embree +{ + namespace isa + { + struct BVHBuilderHairMSMBlur + { + /*! settings for msmblur builder */ + struct Settings + { + /*! 
default settings */ + Settings () + : branchingFactor(2), maxDepth(32), logBlockSize(0), minLeafSize(1), maxLeafSize(8) {} + + public: + size_t branchingFactor; //!< branching factor of BVH to build + size_t maxDepth; //!< maximum depth of BVH to build + size_t logBlockSize; //!< log2 of blocksize for SAH heuristic + size_t minLeafSize; //!< minimum size of a leaf + size_t maxLeafSize; //!< maximum size of a leaf + }; + + struct BuildRecord + { + public: + __forceinline BuildRecord () {} + + __forceinline BuildRecord (size_t depth) + : depth(depth) {} + + __forceinline BuildRecord (const SetMB& prims, size_t depth) + : depth(depth), prims(prims) {} + + __forceinline size_t size() const { + return prims.size(); + } + + public: + size_t depth; //!< depth of the root of this subtree + SetMB prims; //!< the list of primitives + }; + + template + + class BuilderT + { + ALIGNED_CLASS_(16); + + static const size_t MAX_BRANCHING_FACTOR = 8; //!< maximum supported BVH branching factor + static const size_t MIN_LARGE_LEAF_LEVELS = 8; //!< create balanced tree if we are that many levels before the maximum tree depth + static const size_t SINGLE_THREADED_THRESHOLD = 4096; //!< threshold to switch to single threaded build + + typedef BVHNodeRecordMB NodeRecordMB; + typedef BVHNodeRecordMB4D NodeRecordMB4D; + + typedef FastAllocator::CachedAllocator Allocator; + typedef LocalChildListT LocalChildList; + + typedef HeuristicMBlurTemporalSplit HeuristicTemporal; + typedef HeuristicArrayBinningMB HeuristicBinning; + typedef UnalignedHeuristicArrayBinningMB UnalignedHeuristicBinning; + + public: + + BuilderT (Scene* scene, + const RecalculatePrimRef& recalculatePrimRef, + const CreateAllocFunc& createAlloc, + const CreateAABBNodeMBFunc& createAABBNodeMB, + const SetAABBNodeMBFunc& setAABBNodeMB, + const CreateOBBNodeMBFunc& createOBBNodeMB, + const SetOBBNodeMBFunc& setOBBNodeMB, + const CreateLeafFunc& createLeaf, + const ProgressMonitor& progressMonitor, + const Settings settings) + + : cfg(settings), + scene(scene), + recalculatePrimRef(recalculatePrimRef), + createAlloc(createAlloc), + createAABBNodeMB(createAABBNodeMB), setAABBNodeMB(setAABBNodeMB), + createOBBNodeMB(createOBBNodeMB), setOBBNodeMB(setOBBNodeMB), + createLeaf(createLeaf), + progressMonitor(progressMonitor), + unalignedHeuristic(scene), + temporalSplitHeuristic(scene->device,recalculatePrimRef) {} + + private: + + /*! 
checks if all primitives are from the same geometry */ + __forceinline bool sameGeometry(const SetMB& set) + { + mvector& prims = *set.prims; + unsigned int firstGeomID = prims[set.begin()].geomID(); + for (size_t i=set.begin()+1; i& prims = *set.prims; + + const size_t begin = set.begin(); + const size_t end = set.end(); + const size_t center = (begin + end)/2; + + PrimInfoMB linfo = empty; + for (size_t i=begin; i(begin,center),set.time_range); + new (&rset) SetMB(rinfo,set.prims,range(center,end ),set.time_range); + } + + void splitByGeometry(const SetMB& set, SetMB& lset, SetMB& rset) + { + assert(set.size() > 1); + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfoMB linfo(empty); + PrimInfoMB rinfo(empty); + unsigned int geomID = (*set.prims)[begin].geomID(); + size_t center = serial_partitioning(set.prims->data(),begin,end,linfo,rinfo, + [&] ( const PrimRefMB& prim ) { return prim.geomID() == geomID; }, + [ ] ( PrimInfoMB& a, const PrimRefMB& ref ) { a.add_primref(ref); }); + + new (&lset) SetMB(linfo,set.prims,range(begin,center),set.time_range); + new (&rset) SetMB(rinfo,set.prims,range(center,end ),set.time_range); + } + + /*! creates a large leaf that could be larger than supported by the BVH */ + NodeRecordMB4D createLargeLeaf(BuildRecord& current, Allocator alloc) + { + /* this should never occur but is a fatal error */ + if (current.depth > cfg.maxDepth) + throw_RTCError(RTC_ERROR_UNKNOWN,"depth limit reached"); + + /* special case when directly creating leaf without any splits that could shrink time_range */ + bool force_split = false; + if (current.depth == 1 && current.size() > 0) + { + BBox1f c = empty; + BBox1f p = current.prims.time_range; + for (size_t i=current.prims.begin(); i& prims = *current.prims.prims; + c.extend(prims[i].time_range); + } + + force_split = c.lower > p.lower || c.upper < p.upper; + } + + /* create leaf for few primitives */ + if (current.size() <= cfg.maxLeafSize && sameGeometry(current.prims) && !force_split) + return createLeaf(current.prims,alloc); + + /* fill all children by always splitting the largest one */ + LocalChildList children(current); + NodeRecordMB4D values[MAX_BRANCHING_FACTOR]; + + do { + + /* find best child with largest bounding box area */ + int bestChild = -1; + size_t bestSize = 0; + for (unsigned i=0; i bestSize) { + bestSize = children[i].size(); + bestChild = i; + } + } + if (bestChild == -1) break; + + /*! 
split best child into left and right child */ + BuildRecord left(current.depth+1); + BuildRecord right(current.depth+1); + if (!sameGeometry(children[bestChild].prims)) { + splitByGeometry(children[bestChild].prims,left.prims,right.prims); + } else { + splitFallback(children[bestChild].prims,left.prims,right.prims); + } + children.split(bestChild,left,right,std::unique_ptr>()); + + } while (children.size() < cfg.branchingFactor); + + + /* detect time_ranges that have shrunken */ + bool timesplit = false; + for (size_t i=0; i p.lower || c.upper < p.upper; + } + + /* create node */ + NodeRef node = createAABBNodeMB(children.children.data(),children.numChildren,alloc,timesplit); + + LBBox3fa bounds = empty; + for (size_t i=0; i> split(const BuildRecord& current, BuildRecord& lrecord, BuildRecord& rrecord, bool& aligned, bool& timesplit) + { + /* variable to track the SAH of the best splitting approach */ + float bestSAH = inf; + const float leafSAH = current.prims.leafSAH(cfg.logBlockSize); + + /* perform standard binning in aligned space */ + HeuristicBinning::Split alignedObjectSplit = alignedHeuristic.find(current.prims,cfg.logBlockSize); + float alignedObjectSAH = alignedObjectSplit.splitSAH(); + bestSAH = min(alignedObjectSAH,bestSAH); + + /* perform standard binning in unaligned space */ + UnalignedHeuristicBinning::Split unalignedObjectSplit; + LinearSpace3fa uspace; + float unalignedObjectSAH = inf; + if (alignedObjectSAH > 0.7f*leafSAH) { + uspace = unalignedHeuristic.computeAlignedSpaceMB(scene,current.prims); + const SetMB sset = current.prims.primInfo(recalculatePrimRef,uspace); + unalignedObjectSplit = unalignedHeuristic.find(sset,cfg.logBlockSize,uspace); + unalignedObjectSAH = 1.3f*unalignedObjectSplit.splitSAH(); // makes unaligned splits more expensive + bestSAH = min(unalignedObjectSAH,bestSAH); + } + + /* do temporal splits only if previous approaches failed to produce good SAH and the the time range is large enough */ + float temporal_split_sah = inf; + typename HeuristicTemporal::Split temporal_split; + if (bestSAH > 0.5f*leafSAH) { + if (current.prims.time_range.size() > 1.01f/float(current.prims.max_num_time_segments)) { + temporal_split = temporalSplitHeuristic.find(current.prims,cfg.logBlockSize); + temporal_split_sah = temporal_split.splitSAH(); + bestSAH = min(temporal_split_sah,bestSAH); + } + } + + /* perform fallback split if SAH heuristics failed */ + if (unlikely(!std::isfinite(bestSAH))) { + current.prims.deterministic_order(); + splitFallback(current.prims,lrecord.prims,rrecord.prims); + } + /* perform aligned split if this is best */ + else if (likely(bestSAH == alignedObjectSAH)) { + alignedHeuristic.split(alignedObjectSplit,current.prims,lrecord.prims,rrecord.prims); + } + /* perform unaligned split if this is best */ + else if (likely(bestSAH == unalignedObjectSAH)) { + unalignedHeuristic.split(unalignedObjectSplit,uspace,current.prims,lrecord.prims,rrecord.prims); + aligned = false; + } + /* perform temporal split if this is best */ + else if (likely(bestSAH == temporal_split_sah)) { + timesplit = true; + return temporalSplitHeuristic.split(temporal_split,current.prims,lrecord.prims,rrecord.prims); + } + else + assert(false); + + return std::unique_ptr>(); + } + + /*! 
recursive build */ + NodeRecordMB4D recurse(BuildRecord& current, Allocator alloc, bool toplevel) + { + /* get thread local allocator */ + if (!alloc) + alloc = createAlloc(); + + /* call memory monitor function to signal progress */ + if (toplevel && current.size() <= SINGLE_THREADED_THRESHOLD) + progressMonitor(current.size()); + + /* create leaf node */ + if (current.depth+MIN_LARGE_LEAF_LEVELS >= cfg.maxDepth || current.size() <= cfg.minLeafSize) { + current.prims.deterministic_order(); + return createLargeLeaf(current,alloc); + } + + /* fill all children by always splitting the one with the largest surface area */ + NodeRecordMB4D values[MAX_BRANCHING_FACTOR]; + LocalChildList children(current); + bool aligned = true; + bool timesplit = false; + + do { + + /* find best child with largest bounding box area */ + ssize_t bestChild = -1; + float bestArea = neg_inf; + for (size_t i=0; i bestArea) { + bestArea = children[i].prims.halfArea(); + bestChild = i; + } + } + if (bestChild == -1) break; + + /*! split best child into left and right child */ + BuildRecord left(current.depth+1); + BuildRecord right(current.depth+1); + std::unique_ptr> new_vector = split(children[bestChild],left,right,aligned,timesplit); + children.split(bestChild,left,right,std::move(new_vector)); + + } while (children.size() < cfg.branchingFactor); + + /* detect time_ranges that have shrunken */ + for (size_t i=0; i p.lower || c.upper < p.upper; + } + + /* create time split node */ + if (timesplit) + { + const NodeRef node = createAABBNodeMB(children.children.data(),children.numChildren,alloc,true); + + /* spawn tasks or ... */ + if (current.size() > SINGLE_THREADED_THRESHOLD) + { + parallel_for(size_t(0), children.size(), [&] (const range& r) { + for (size_t i=r.begin(); i SINGLE_THREADED_THRESHOLD) + { + LBBox3fa cbounds[MAX_BRANCHING_FACTOR]; + parallel_for(size_t(0), children.size(), [&] (const range& r) { + for (size_t i=r.begin(); i SINGLE_THREADED_THRESHOLD) + { + parallel_for(size_t(0), children.size(), [&] (const range& r) { + for (size_t i=r.begin(); i& prims, const PrimInfoMB& pinfo) + { + BuildRecord record(SetMB(pinfo,&prims),1); + auto root = recurse(record,nullptr,true); + _mm_mfence(); // to allow non-temporal stores during build + return root; + } + + private: + Settings cfg; + Scene* scene; + const RecalculatePrimRef& recalculatePrimRef; + const CreateAllocFunc& createAlloc; + const CreateAABBNodeMBFunc& createAABBNodeMB; + const SetAABBNodeMBFunc& setAABBNodeMB; + const CreateOBBNodeMBFunc& createOBBNodeMB; + const SetOBBNodeMBFunc& setOBBNodeMB; + const CreateLeafFunc& createLeaf; + const ProgressMonitor& progressMonitor; + + private: + HeuristicBinning alignedHeuristic; + UnalignedHeuristicBinning unalignedHeuristic; + HeuristicTemporal temporalSplitHeuristic; + }; + + template + + static BVHNodeRecordMB4D build (Scene* scene, mvector& prims, const PrimInfoMB& pinfo, + const RecalculatePrimRef& recalculatePrimRef, + const CreateAllocFunc& createAlloc, + const CreateAABBNodeMBFunc& createAABBNodeMB, + const SetAABBNodeMBFunc& setAABBNodeMB, + const CreateOBBNodeMBFunc& createOBBNodeMB, + const SetOBBNodeMBFunc& setOBBNodeMB, + const CreateLeafFunc& createLeaf, + const ProgressMonitor& progressMonitor, + const Settings settings) + { + typedef BuilderT Builder; + + Builder builder(scene,recalculatePrimRef,createAlloc, + createAABBNodeMB,setAABBNodeMB, + createOBBNodeMB,setOBBNodeMB, + createLeaf,progressMonitor,settings); + + return builder(prims,pinfo); + } + }; + } +} diff --git 
a/thirdparty/embree/kernels/builders/bvh_builder_sah.h b/thirdparty/embree/kernels/builders/bvh_builder_sah.h new file mode 100644 index 000000000000..79ccdf946f5e --- /dev/null +++ b/thirdparty/embree/kernels/builders/bvh_builder_sah.h @@ -0,0 +1,669 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "heuristic_binning_array_aligned.h" +#include "heuristic_spatial_array.h" +#include "heuristic_openmerge_array.h" + +#if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL +# define NUM_OBJECT_BINS 16 +# define NUM_SPATIAL_BINS 16 +#else +# define NUM_OBJECT_BINS 32 +# define NUM_SPATIAL_BINS 16 +#endif + +namespace embree +{ + namespace isa + { + MAYBE_UNUSED static const float travCost = 1.0f; + MAYBE_UNUSED static const size_t DEFAULT_SINGLE_THREAD_THRESHOLD = 1024; + + struct GeneralBVHBuilder + { + static const size_t MAX_BRANCHING_FACTOR = 16; //!< maximum supported BVH branching factor + static const size_t MIN_LARGE_LEAF_LEVELS = 8; //!< create balanced tree of we are that many levels before the maximum tree depth + + + /*! settings for SAH builder */ + struct Settings + { + /*! default settings */ + Settings () + : branchingFactor(2), maxDepth(32), logBlockSize(0), minLeafSize(1), maxLeafSize(7), + travCost(1.0f), intCost(1.0f), singleThreadThreshold(1024), primrefarrayalloc(inf) {} + + /*! initialize settings from API settings */ + Settings (const RTCBuildArguments& settings) + : branchingFactor(2), maxDepth(32), logBlockSize(0), minLeafSize(1), maxLeafSize(7), + travCost(1.0f), intCost(1.0f), singleThreadThreshold(1024), primrefarrayalloc(inf) + { + if (RTC_BUILD_ARGUMENTS_HAS(settings,maxBranchingFactor)) branchingFactor = settings.maxBranchingFactor; + if (RTC_BUILD_ARGUMENTS_HAS(settings,maxDepth )) maxDepth = settings.maxDepth; + if (RTC_BUILD_ARGUMENTS_HAS(settings,sahBlockSize )) logBlockSize = bsr(settings.sahBlockSize); + if (RTC_BUILD_ARGUMENTS_HAS(settings,minLeafSize )) minLeafSize = settings.minLeafSize; + if (RTC_BUILD_ARGUMENTS_HAS(settings,maxLeafSize )) maxLeafSize = settings.maxLeafSize; + if (RTC_BUILD_ARGUMENTS_HAS(settings,traversalCost )) travCost = settings.traversalCost; + if (RTC_BUILD_ARGUMENTS_HAS(settings,intersectionCost )) intCost = settings.intersectionCost; + + minLeafSize = min(minLeafSize,maxLeafSize); + } + + Settings (size_t sahBlockSize, size_t minLeafSize, size_t maxLeafSize, float travCost, float intCost, size_t singleThreadThreshold, size_t primrefarrayalloc = inf) + : branchingFactor(2), maxDepth(32), logBlockSize(bsr(sahBlockSize)), minLeafSize(minLeafSize), maxLeafSize(maxLeafSize), + travCost(travCost), intCost(intCost), singleThreadThreshold(singleThreadThreshold), primrefarrayalloc(primrefarrayalloc) + { + minLeafSize = min(minLeafSize,maxLeafSize); + } + + public: + size_t branchingFactor; //!< branching factor of BVH to build + size_t maxDepth; //!< maximum depth of BVH to build + size_t logBlockSize; //!< log2 of blocksize for SAH heuristic + size_t minLeafSize; //!< minimum size of a leaf + size_t maxLeafSize; //!< maximum size of a leaf + float travCost; //!< estimated cost of one traversal step + float intCost; //!< estimated cost of one primitive intersection + size_t singleThreadThreshold; //!< threshold when we switch to single threaded build + size_t primrefarrayalloc; //!< builder uses prim ref array to allocate nodes and leaves when a subtree of that size is finished + }; + + /*! 
recursive state of builder */ + template + struct BuildRecordT + { + public: + __forceinline BuildRecordT () {} + + __forceinline BuildRecordT (size_t depth) + : depth(depth), alloc_barrier(false), prims(empty) {} + + __forceinline BuildRecordT (size_t depth, const Set& prims) + : depth(depth), alloc_barrier(false), prims(prims) {} + + __forceinline BBox3fa bounds() const { return prims.geomBounds; } + + __forceinline friend bool operator< (const BuildRecordT& a, const BuildRecordT& b) { return a.prims.size() < b.prims.size(); } + __forceinline friend bool operator> (const BuildRecordT& a, const BuildRecordT& b) { return a.prims.size() > b.prims.size(); } + + __forceinline size_t size() const { return prims.size(); } + + public: + size_t depth; //!< Depth of the root of this subtree. + bool alloc_barrier; //!< barrier used to reuse primref-array blocks to allocate nodes + Set prims; //!< The list of primitives. + }; + + template + struct DefaultCanCreateLeafFunc + { + __forceinline bool operator()(const PrimRef*, const Set&) const { return true; } + }; + + template + struct DefaultCanCreateLeafSplitFunc + { + __forceinline void operator()(PrimRef*, const Set&, Set&, Set&) const { } + }; + + template + + class BuilderT + { + friend struct GeneralBVHBuilder; + + BuilderT (PrimRef* prims, + Heuristic& heuristic, + const CreateAllocFunc& createAlloc, + const CreateNodeFunc& createNode, + const UpdateNodeFunc& updateNode, + const CreateLeafFunc& createLeaf, + const CanCreateLeafFunc& canCreateLeaf, + const CanCreateLeafSplitFunc& canCreateLeafSplit, + const ProgressMonitor& progressMonitor, + const Settings& settings) : + cfg(settings), + prims(prims), + heuristic(heuristic), + createAlloc(createAlloc), + createNode(createNode), + updateNode(updateNode), + createLeaf(createLeaf), + canCreateLeaf(canCreateLeaf), + canCreateLeafSplit(canCreateLeafSplit), + progressMonitor(progressMonitor) + { + if (cfg.branchingFactor > MAX_BRANCHING_FACTOR) + throw_RTCError(RTC_ERROR_UNKNOWN,"bvh_builder: branching factor too large"); + } + + const ReductionTy createLargeLeaf(const BuildRecord& current, Allocator alloc) + { + /* this should never occur but is a fatal error */ + if (current.depth > cfg.maxDepth) + throw_RTCError(RTC_ERROR_UNKNOWN,"depth limit reached"); + + /* create leaf for few primitives */ + if (current.prims.size() <= cfg.maxLeafSize && canCreateLeaf(prims,current.prims)) + return createLeaf(prims,current.prims,alloc); + + /* fill all children by always splitting the largest one */ + ReductionTy values[MAX_BRANCHING_FACTOR]; + BuildRecord children[MAX_BRANCHING_FACTOR]; + size_t numChildren = 1; + children[0] = current; + do { + + /* find best child with largest bounding box area */ + size_t bestChild = -1; + size_t bestSize = 0; + for (size_t i=0; i bestSize) { + bestSize = children[i].prims.size(); + bestChild = i; + } + } + if (bestChild == (size_t)-1) break; + + /*! 
split best child into left and right child */ + BuildRecord left(current.depth+1); + BuildRecord right(current.depth+1); + if (!canCreateLeaf(prims,children[bestChild].prims)) { + canCreateLeafSplit(prims,children[bestChild].prims,left.prims,right.prims); + } else { + heuristic.splitFallback(children[bestChild].prims,left.prims,right.prims); + } + + /* add new children left and right */ + children[bestChild] = children[numChildren-1]; + children[numChildren-1] = left; + children[numChildren+0] = right; + numChildren++; + + } while (numChildren < cfg.branchingFactor); + + /* set barrier for primrefarrayalloc */ + if (unlikely(current.size() > cfg.primrefarrayalloc)) + for (size_t i=0; i= 0) && (splitSAH >= 0))); + + /*! create a leaf node when threshold reached or SAH tells us to stop */ + if (current.prims.size() <= cfg.minLeafSize || current.depth+MIN_LARGE_LEAF_LEVELS >= cfg.maxDepth || (current.prims.size() <= cfg.maxLeafSize && leafSAH <= splitSAH)) { + heuristic.deterministic_order(current.prims); + return createLargeLeaf(current,alloc); + } + + /*! perform initial split */ + Set lprims,rprims; + heuristic.split(split,current.prims,lprims,rprims); + + /*! initialize child list with initial split */ + ReductionTy values[MAX_BRANCHING_FACTOR]; + BuildRecord children[MAX_BRANCHING_FACTOR]; + children[0] = BuildRecord(current.depth+1,lprims); + children[1] = BuildRecord(current.depth+1,rprims); + size_t numChildren = 2; + + /*! split until node is full or SAH tells us to stop */ + while (numChildren < cfg.branchingFactor) + { + /*! find best child to split */ + float bestArea = neg_inf; + ssize_t bestChild = -1; + for (size_t i=0; i bestArea) { + bestChild = i; + bestArea = halfArea(children[i].prims.geomBounds); + } + } + if (bestChild == -1) break; + + /* perform best found split */ + BuildRecord& brecord = children[bestChild]; + BuildRecord lrecord(current.depth+1); + BuildRecord rrecord(current.depth+1); + auto split = heuristic.find(brecord.prims,cfg.logBlockSize); + heuristic.split(split,brecord.prims,lrecord.prims,rrecord.prims); + children[bestChild ] = lrecord; + children[numChildren] = rrecord; + numChildren++; + } + + /* set barrier for primrefarrayalloc */ + if (unlikely(current.size() > cfg.primrefarrayalloc)) + for (size_t i=0; i()); + + /*! create an inner node */ + auto node = createNode(children,numChildren,alloc); + + /* spawn tasks */ + if (current.size() > cfg.singleThreadThreshold) + { + /*! 
parallel_for is faster than spawing sub-tasks */ + parallel_for(size_t(0), numChildren, [&] (const range& r) { // FIXME: no range here + for (size_t i=r.begin(); i + + __noinline static ReductionTy build(Heuristic& heuristic, + PrimRef* prims, + const Set& set, + CreateAllocFunc createAlloc, + CreateNodeFunc createNode, UpdateNodeFunc updateNode, + const CreateLeafFunc& createLeaf, + const ProgressMonitor& progressMonitor, + const Settings& settings) + { + typedef BuildRecordT BuildRecord; + + typedef BuilderT< + BuildRecord, + Heuristic, + Set, + PrimRef, + ReductionTy, + decltype(createAlloc()), + CreateAllocFunc, + CreateNodeFunc, + UpdateNodeFunc, + CreateLeafFunc, + DefaultCanCreateLeafFunc, + DefaultCanCreateLeafSplitFunc, + ProgressMonitor> Builder; + + /* instantiate builder */ + Builder builder(prims, + heuristic, + createAlloc, + createNode, + updateNode, + createLeaf, + DefaultCanCreateLeafFunc(), + DefaultCanCreateLeafSplitFunc(), + progressMonitor, + settings); + + /* build hierarchy */ + BuildRecord record(1,set); + const ReductionTy root = builder.recurse(record,nullptr,true); + _mm_mfence(); // to allow non-temporal stores during build + return root; + } + + template< + typename ReductionTy, + typename Heuristic, + typename Set, + typename PrimRef, + typename CreateAllocFunc, + typename CreateNodeFunc, + typename UpdateNodeFunc, + typename CreateLeafFunc, + typename CanCreateLeafFunc, + typename CanCreateLeafSplitFunc, + typename ProgressMonitor> + + __noinline static ReductionTy build(Heuristic& heuristic, + PrimRef* prims, + const Set& set, + CreateAllocFunc createAlloc, + CreateNodeFunc createNode, UpdateNodeFunc updateNode, + const CreateLeafFunc& createLeaf, + const CanCreateLeafFunc& canCreateLeaf, + const CanCreateLeafSplitFunc& canCreateLeafSplit, + const ProgressMonitor& progressMonitor, + const Settings& settings) + { + typedef BuildRecordT BuildRecord; + + typedef BuilderT< + BuildRecord, + Heuristic, + Set, + PrimRef, + ReductionTy, + decltype(createAlloc()), + CreateAllocFunc, + CreateNodeFunc, + UpdateNodeFunc, + CreateLeafFunc, + CanCreateLeafFunc, + CanCreateLeafSplitFunc, + ProgressMonitor> Builder; + + /* instantiate builder */ + Builder builder(prims, + heuristic, + createAlloc, + createNode, + updateNode, + createLeaf, + canCreateLeaf, + canCreateLeafSplit, + progressMonitor, + settings); + + /* build hierarchy */ + BuildRecord record(1,set); + const ReductionTy root = builder.recurse(record,nullptr,true); + _mm_mfence(); // to allow non-temporal stores during build + return root; + } + }; + + /* SAH builder that operates on an array of BuildRecords */ + struct BVHBuilderBinnedSAH + { + typedef PrimInfoRange Set; + typedef HeuristicArrayBinningSAH Heuristic; + typedef GeneralBVHBuilder::BuildRecordT BuildRecord; + typedef GeneralBVHBuilder::Settings Settings; + + /*! 
special builder that propagates reduction over the tree */ + template< + typename ReductionTy, + typename CreateAllocFunc, + typename CreateNodeFunc, + typename UpdateNodeFunc, + typename CreateLeafFunc, + typename ProgressMonitor> + + static ReductionTy build(CreateAllocFunc createAlloc, + CreateNodeFunc createNode, UpdateNodeFunc updateNode, + const CreateLeafFunc& createLeaf, + const ProgressMonitor& progressMonitor, + PrimRef* prims, const PrimInfo& pinfo, + const Settings& settings) + { + Heuristic heuristic(prims); + return GeneralBVHBuilder::build( + heuristic, + prims, + PrimInfoRange(0,pinfo.size(),pinfo), + createAlloc, + createNode, + updateNode, + createLeaf, + progressMonitor, + settings); + } + + /*! special builder that propagates reduction over the tree */ + template< + typename ReductionTy, + typename CreateAllocFunc, + typename CreateNodeFunc, + typename UpdateNodeFunc, + typename CreateLeafFunc, + typename CanCreateLeafFunc, + typename CanCreateLeafSplitFunc, + typename ProgressMonitor> + + static ReductionTy build(CreateAllocFunc createAlloc, + CreateNodeFunc createNode, UpdateNodeFunc updateNode, + const CreateLeafFunc& createLeaf, + const CanCreateLeafFunc& canCreateLeaf, + const CanCreateLeafSplitFunc& canCreateLeafSplit, + const ProgressMonitor& progressMonitor, + PrimRef* prims, const PrimInfo& pinfo, + const Settings& settings) + { + Heuristic heuristic(prims); + return GeneralBVHBuilder::build( + heuristic, + prims, + PrimInfoRange(0,pinfo.size(),pinfo), + createAlloc, + createNode, + updateNode, + createLeaf, + canCreateLeaf, + canCreateLeafSplit, + progressMonitor, + settings); + } + }; + + /* Spatial SAH builder that operates on an double-buffered array of BuildRecords */ + struct BVHBuilderBinnedFastSpatialSAH + { + typedef PrimInfoExtRange Set; + typedef Split2,SpatialBinSplit > Split; + typedef GeneralBVHBuilder::BuildRecordT BuildRecord; + typedef GeneralBVHBuilder::Settings Settings; + + static const unsigned int GEOMID_MASK = 0xFFFFFFFF >> RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS; + static const unsigned int SPLITS_MASK = 0xFFFFFFFF << (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS); + + template + struct CreateLeafExt + { + __forceinline CreateLeafExt (const UserCreateLeaf userCreateLeaf) + : userCreateLeaf(userCreateLeaf) {} + + // __noinline is workaround for ICC2016 compiler bug + template + __noinline ReductionTy operator() (PrimRef* prims, const range& range, Allocator alloc) const + { + for (size_t i=range.begin(); i + + static ReductionTy build(CreateAllocFunc createAlloc, + CreateNodeFunc createNode, + UpdateNodeFunc updateNode, + const CreateLeafFunc& createLeaf, + SplitPrimitiveFunc splitPrimitive, + ProgressMonitor progressMonitor, + PrimRef* prims, + const size_t extSize, + const PrimInfo& pinfo, + const Settings& settings) + { + typedef HeuristicArraySpatialSAH Heuristic; + Heuristic heuristic(splitPrimitive,prims,pinfo); + + /* calculate total surface area */ // FIXME: this sum is not deterministic + const float A = (float) parallel_reduce(size_t(0),pinfo.size(),0.0, [&] (const range& r) -> double { + + double A = 0.0f; + for (size_t i=r.begin(); i()); + + + /* calculate maximum number of spatial splits per primitive */ + const unsigned int maxSplits = ((size_t)1 << RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)-1; + const float f = 10.0f; + + const float invA = 1.0f / A; + parallel_for( size_t(0), pinfo.size(), [&](const range& r) { + + for (size_t i=r.begin(); i( + heuristic, + prims, + PrimInfoExtRange(0,pinfo.size(),extSize,pinfo), + 
createAlloc, + createNode, + updateNode, + CreateLeafExt(createLeaf), + progressMonitor, + settings); + } + }; + + /* Open/Merge SAH builder that operates on an array of BuildRecords */ + struct BVHBuilderBinnedOpenMergeSAH + { + static const size_t NUM_OBJECT_BINS_HQ = 32; + typedef PrimInfoExtRange Set; + typedef BinSplit Split; + typedef GeneralBVHBuilder::BuildRecordT BuildRecord; + typedef GeneralBVHBuilder::Settings Settings; + + /*! special builder that propagates reduction over the tree */ + template< + typename ReductionTy, + typename BuildRef, + typename CreateAllocFunc, + typename CreateNodeFunc, + typename UpdateNodeFunc, + typename CreateLeafFunc, + typename NodeOpenerFunc, + typename ProgressMonitor> + + static ReductionTy build(CreateAllocFunc createAlloc, + CreateNodeFunc createNode, + UpdateNodeFunc updateNode, + const CreateLeafFunc& createLeaf, + NodeOpenerFunc nodeOpenerFunc, + ProgressMonitor progressMonitor, + BuildRef* prims, + const size_t extSize, + const PrimInfo& pinfo, + const Settings& settings) + { + typedef HeuristicArrayOpenMergeSAH Heuristic; + Heuristic heuristic(nodeOpenerFunc,prims,settings.branchingFactor); + + return GeneralBVHBuilder::build( + heuristic, + prims, + PrimInfoExtRange(0,pinfo.size(),extSize,pinfo), + createAlloc, + createNode, + updateNode, + createLeaf, + progressMonitor, + settings); + } + }; + } +} diff --git a/thirdparty/embree/kernels/builders/heuristic_binning.h b/thirdparty/embree/kernels/builders/heuristic_binning.h new file mode 100644 index 000000000000..a4d3b68e46a6 --- /dev/null +++ b/thirdparty/embree/kernels/builders/heuristic_binning.h @@ -0,0 +1,972 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "priminfo.h" +#include "../../common/algorithms/parallel_reduce.h" +#include "../../common/algorithms/parallel_partition.h" + +namespace embree +{ + namespace isa + { + /*! mapping into bins */ + template + struct BinMapping + { + public: + __forceinline BinMapping() {} + + /*! calculates the mapping */ + __forceinline BinMapping(size_t N, const BBox3fa& centBounds) + { + num = min(BINS,size_t(4.0f + 0.05f*N)); + assert(num >= 1); + const vfloat4 eps = 1E-34f; + const vfloat4 diag = max(eps, (vfloat4) centBounds.size()); + scale = select(diag > eps,vfloat4(0.99f*num)/diag,vfloat4(0.0f)); + ofs = (vfloat4) centBounds.lower; + } + + /*! calculates the mapping */ + __forceinline BinMapping(const BBox3fa& centBounds) + { + num = BINS; + const vfloat4 eps = 1E-34f; + const vfloat4 diag = max(eps, (vfloat4) centBounds.size()); + scale = select(diag > eps,vfloat4(0.99f*num)/diag,vfloat4(0.0f)); + ofs = (vfloat4) centBounds.lower; + } + + /*! calculates the mapping */ + template + __forceinline BinMapping(const PrimInfo& pinfo) + { + const vfloat4 eps = 1E-34f; + num = min(BINS,size_t(4.0f + 0.05f*pinfo.size())); + const vfloat4 diag = max(eps,(vfloat4) pinfo.centBounds.size()); + scale = select(diag > eps,vfloat4(0.99f*num)/diag,vfloat4(0.0f)); + ofs = (vfloat4) pinfo.centBounds.lower; + } + + /*! returns number of bins */ + __forceinline size_t size() const { return num; } + + /*! slower but safe binning */ + __forceinline Vec3ia bin(const Vec3fa& p) const + { + const vint4 i = floori((vfloat4(p)-ofs)*scale); +#if 1 + assert(i[0] >= 0 && (size_t)i[0] < num); + assert(i[1] >= 0 && (size_t)i[1] < num); + assert(i[2] >= 0 && (size_t)i[2] < num); + return Vec3ia(i); +#else + return Vec3ia(clamp(i,vint4(0),vint4(num-1))); +#endif + } + + /*! 
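For intuition, the vectorized BinMapping above is an affine map from a centroid to a per-axis bin index: bin = floor((c - lower) * scale), with scale chosen so that the largest centroid still lands inside the last bin. A scalar sketch of the same mapping, assuming plain floats instead of vfloat4 (ScalarBinMapping and its members are illustrative names, not embree API):

#include <algorithm>
#include <cmath>
#include <cstddef>

struct ScalarBinMapping {
    static constexpr std::size_t BINS = 32;
    std::size_t num;      // number of bins actually used for this node
    float ofs[3];         // lower corner of the centroid bounds
    float scale[3];       // bins per world unit; 0 for degenerate axes

    ScalarBinMapping(std::size_t N, const float lower[3], const float upper[3]) {
        num = std::min(BINS, std::size_t(4.0f + 0.05f * N));   // fewer bins for small sets
        for (int d = 0; d < 3; ++d) {
            const float diag = upper[d] - lower[d];
            ofs[d]   = lower[d];
            scale[d] = (diag > 1e-34f) ? 0.99f * float(num) / diag : 0.0f;
        }
    }

    // Maps a centroid coordinate to a bin index on one axis, clamped to the valid range.
    std::size_t bin(float c, int dim) const {
        const int i = int(std::floor((c - ofs[dim]) * scale[dim]));
        return std::size_t(std::clamp(i, 0, int(num) - 1));
    }
};

The 0.99f factor keeps the largest centroid strictly below num, which is presumably why the bin_unsafe variants below can omit the clamp on the fast path.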
faster but unsafe binning */ + __forceinline Vec3ia bin_unsafe(const Vec3fa& p) const { + return Vec3ia(floori((vfloat4(p)-ofs)*scale)); + } + + /*! faster but unsafe binning */ + template + __forceinline Vec3ia bin_unsafe(const PrimRef& p) const { + return bin_unsafe(p.binCenter()); + } + + /*! faster but unsafe binning */ + template + __forceinline Vec3ia bin_unsafe(const PrimRef& p, const BinBoundsAndCenter& binBoundsAndCenter) const { + return bin_unsafe(binBoundsAndCenter.binCenter(p)); + } + + template + __forceinline bool bin_unsafe(const PrimRef& ref, + const vint4& vSplitPos, + const vbool4& splitDimMask) const // FIXME: rename to isLeft + { + return any(((vint4)bin_unsafe(center2(ref.bounds())) < vSplitPos) & splitDimMask); + } + /*! calculates left spatial position of bin */ + __forceinline float pos(const size_t bin, const size_t dim) const { + return madd(float(bin),1.0f / scale[dim],ofs[dim]); + } + + /*! returns true if the mapping is invalid in some dimension */ + __forceinline bool invalid(const size_t dim) const { + return scale[dim] == 0.0f; + } + + /*! stream output */ + friend embree_ostream operator<<(embree_ostream cout, const BinMapping& mapping) { + return cout << "BinMapping { num = " << mapping.num << ", ofs = " << mapping.ofs << ", scale = " << mapping.scale << "}"; + } + + public: + size_t num; + vfloat4 ofs,scale; //!< linear function that maps to bin ID + }; + + /*! stores all information to perform some split */ + template + struct BinSplit + { + enum + { + SPLIT_OBJECT = 0, + SPLIT_FALLBACK = 1, + SPLIT_ENFORCE = 2, // splits with larger ID are enforced in createLargeLeaf even if we could create a leaf already + SPLIT_TEMPORAL = 2, + SPLIT_GEOMID = 3, + }; + + /*! construct an invalid split by default */ + __forceinline BinSplit() + : sah(inf), dim(-1), pos(0), data(0) {} + + __forceinline BinSplit(float sah, unsigned data, int dim = 0, float fpos = 0) + : sah(sah), dim(dim), fpos(fpos), data(data) {} + + /*! constructs specified split */ + __forceinline BinSplit(float sah, int dim, int pos, const BinMapping& mapping) + : sah(sah), dim(dim), pos(pos), data(0), mapping(mapping) {} + + /*! tests if this split is valid */ + __forceinline bool valid() const { return dim != -1; } + + /*! calculates surface area heuristic for performing the split */ + __forceinline float splitSAH() const { return sah; } + + /*! stream output */ + friend embree_ostream operator<<(embree_ostream cout, const BinSplit& split) { + return cout << "BinSplit { sah = " << split.sah << ", dim = " << split.dim << ", pos = " << split.pos << "}"; + } + + public: + float sah; //!< SAH cost of the split + int dim; //!< split dimension + union { int pos; float fpos; }; //!< bin index for splitting + unsigned int data; //!< extra optional split data + BinMapping mapping; //!< mapping into bins + }; + + /*! stores extended information about the split */ + template + struct SplitInfoT + { + + __forceinline SplitInfoT () {} + + __forceinline SplitInfoT (size_t leftCount, const BBox& leftBounds, size_t rightCount, const BBox& rightBounds) + : leftCount(leftCount), rightCount(rightCount), leftBounds(leftBounds), rightBounds(rightBounds) {} + + public: + size_t leftCount,rightCount; + BBox leftBounds,rightBounds; + }; + + typedef SplitInfoT SplitInfo; + typedef SplitInfoT SplitInfo2; + + /*! 
stores all binning information */ + template + struct __aligned(64) BinInfoT + { + typedef BinSplit Split; + typedef vbool4 vbool; + typedef vint4 vint; + typedef vfloat4 vfloat; + + __forceinline BinInfoT() { + } + + __forceinline BinInfoT(EmptyTy) { + clear(); + } + + /*! bin access function */ + __forceinline BBox &bounds(const size_t binID, const size_t dimID) { return _bounds[binID][dimID]; } + __forceinline const BBox &bounds(const size_t binID, const size_t dimID) const { return _bounds[binID][dimID]; } + + __forceinline unsigned int &counts(const size_t binID, const size_t dimID) { return _counts[binID][dimID]; } + __forceinline const unsigned int &counts(const size_t binID, const size_t dimID) const { return _counts[binID][dimID]; } + + __forceinline vuint4 &counts(const size_t binID) { return _counts[binID]; } + __forceinline const vuint4 &counts(const size_t binID) const { return _counts[binID]; } + + /*! clears the bin info */ + __forceinline void clear() + { + for (size_t i=0; i& mapping) + { + if (unlikely(N == 0)) return; + size_t i; + for (i=0; i(bin0); bounds(b00,0).extend(prim0); + const unsigned int b01 = extract<1>(bin0); bounds(b01,1).extend(prim0); + const unsigned int b02 = extract<2>(bin0); bounds(b02,2).extend(prim0); + const unsigned int s0 = (unsigned int)prims[i+0].size(); + counts(b00,0)+=s0; + counts(b01,1)+=s0; + counts(b02,2)+=s0; + + /*! increase bounds of bins for odd primitive */ + const unsigned int b10 = extract<0>(bin1); bounds(b10,0).extend(prim1); + const unsigned int b11 = extract<1>(bin1); bounds(b11,1).extend(prim1); + const unsigned int b12 = extract<2>(bin1); bounds(b12,2).extend(prim1); + const unsigned int s1 = (unsigned int)prims[i+1].size(); + counts(b10,0)+=s1; + counts(b11,1)+=s1; + counts(b12,2)+=s1; + } + /*! for uneven number of primitives */ + if (i < N) + { + /*! map primitive to bin */ + BBox prim0; Vec3fa center0; + prims[i].binBoundsAndCenter(prim0,center0); + const vint4 bin0 = (vint4)mapping.bin(center0); + + /*! increase bounds of bins */ + const unsigned int s0 = (unsigned int)prims[i].size(); + const int b00 = extract<0>(bin0); counts(b00,0)+=s0; bounds(b00,0).extend(prim0); + const int b01 = extract<1>(bin0); counts(b01,1)+=s0; bounds(b01,1).extend(prim0); + const int b02 = extract<2>(bin0); counts(b02,2)+=s0; bounds(b02,2).extend(prim0); + } + } + + /*! bins an array of primitives */ + template + __forceinline void bin (const PrimRef* prims, size_t N, const BinMapping& mapping, const BinBoundsAndCenter& binBoundsAndCenter) + { + if (N == 0) return; + + size_t i; + for (i=0; i(bin0); counts(b00,0)+=s0; bounds(b00,0).extend(prim0); + const int b01 = extract<1>(bin0); counts(b01,1)+=s0; bounds(b01,1).extend(prim0); + const int b02 = extract<2>(bin0); counts(b02,2)+=s0; bounds(b02,2).extend(prim0); + + /*! increase bounds of bins for odd primitive */ + const unsigned int s1 = prims[i+1].size(); + const int b10 = extract<0>(bin1); counts(b10,0)+=s1; bounds(b10,0).extend(prim1); + const int b11 = extract<1>(bin1); counts(b11,1)+=s1; bounds(b11,1).extend(prim1); + const int b12 = extract<2>(bin1); counts(b12,2)+=s1; bounds(b12,2).extend(prim1); + } + + /*! for uneven number of primitives */ + if (i < N) + { + /*! map primitive to bin */ + BBox prim0; Vec3fa center0; binBoundsAndCenter.binBoundsAndCenter(prims[i+0],prim0,center0); + const vint4 bin0 = (vint4)mapping.bin(center0); + + /*! 
increase bounds of bins */ + const unsigned int s0 = prims[i+0].size(); + const int b00 = extract<0>(bin0); counts(b00,0)+=s0; bounds(b00,0).extend(prim0); + const int b01 = extract<1>(bin0); counts(b01,1)+=s0; bounds(b01,1).extend(prim0); + const int b02 = extract<2>(bin0); counts(b02,2)+=s0; bounds(b02,2).extend(prim0); + } + } + + __forceinline void bin(const PrimRef* prims, size_t begin, size_t end, const BinMapping& mapping) { + bin(prims+begin,end-begin,mapping); + } + + template + __forceinline void bin(const PrimRef* prims, size_t begin, size_t end, const BinMapping& mapping, const BinBoundsAndCenter& binBoundsAndCenter) { + bin(prims+begin,end-begin,mapping,binBoundsAndCenter); + } + + /*! merges in other binning information */ + __forceinline void merge (const BinInfoT& other, size_t numBins) + { + + for (size_t i=0; i& mapping, const size_t blocks_shift) const + { + /* sweep from right to left and compute parallel prefix of merged bounds */ + vfloat4 rAreas[BINS]; + vuint4 rCounts[BINS]; + vuint4 count = 0; BBox bx = empty; BBox by = empty; BBox bz = empty; + for (size_t i=mapping.size()-1; i>0; i--) + { + count += counts(i); + rCounts[i] = count; + bx.extend(bounds(i,0)); rAreas[i][0] = expectedApproxHalfArea(bx); + by.extend(bounds(i,1)); rAreas[i][1] = expectedApproxHalfArea(by); + bz.extend(bounds(i,2)); rAreas[i][2] = expectedApproxHalfArea(bz); + rAreas[i][3] = 0.0f; + } + /* sweep from left to right and compute SAH */ + vuint4 blocks_add = (1 << blocks_shift)-1; + vuint4 ii = 1; vfloat4 vbestSAH = pos_inf; vuint4 vbestPos = 0; + count = 0; bx = empty; by = empty; bz = empty; + for (size_t i=1; i> (unsigned int)(blocks_shift); // if blocks_shift >=1 then lCount < 4B and could be represented with an vint4, which would allow for faster vfloat4 conversions. + const vuint4 rCount = (rCounts[i]+blocks_add) >> (unsigned int)(blocks_shift); + const vfloat4 sah = madd(lArea,vfloat4(lCount),rArea*vfloat4(rCount)); + //const vfloat4 sah = madd(lArea,vfloat4(vint4(lCount)),rArea*vfloat4(vint4(rCount))); + + vbestPos = select(sah < vbestSAH,ii ,vbestPos); + vbestSAH = select(sah < vbestSAH,sah,vbestSAH); + } + + /* find best dimension */ + float bestSAH = inf; + int bestDim = -1; + int bestPos = 0; + for (int dim=0; dim<3; dim++) + { + /* ignore zero sized dimensions */ + if (unlikely(mapping.invalid(dim))) + continue; + + /* test if this is a better dimension */ + if (vbestSAH[dim] < bestSAH && vbestPos[dim] != 0) { + bestDim = dim; + bestPos = vbestPos[dim]; + bestSAH = vbestSAH[dim]; + } + } + return Split(bestSAH,bestDim,bestPos,mapping); + } + + /*! calculates extended split information */ + __forceinline void getSplitInfo(const BinMapping& mapping, const Split& split, SplitInfoT& info) const + { + if (split.dim == -1) { + new (&info) SplitInfoT(0,empty,0,empty); + return; + } + + size_t leftCount = 0; + BBox leftBounds = empty; + for (size_t i=0; i<(size_t)split.pos; i++) { + leftCount += counts(i,split.dim); + leftBounds.extend(bounds(i,split.dim)); + } + size_t rightCount = 0; + BBox rightBounds = empty; + for (size_t i=split.pos; i(leftCount,leftBounds,rightCount,rightBounds); + } + + /*! gets the number of primitives left of the split */ + __forceinline size_t getLeftCount(const BinMapping& mapping, const Split& split) const + { + if (unlikely(split.dim == -1)) return -1; + + size_t leftCount = 0; + for (size_t i = 0; i < (size_t)split.pos; i++) { + leftCount += counts(i, split.dim); + } + return leftCount; + } + + /*! 
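The best() search above is a two-pass sweep per axis: a right-to-left pass accumulates suffix bounds and counts per bin, then a left-to-right pass scores every bin boundary with the SAH cost lArea*lCount + rArea*rCount and keeps the minimum. A single-axis scalar sketch of that sweep (Box1, bestSplit1D and the 1D half-area stand-in are illustrative, not the vfloat4 code above):

#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

struct Box1 {
    float lo = +1e30f, hi = -1e30f;
    void  extend(const Box1& b) { lo = std::min(lo, b.lo); hi = std::max(hi, b.hi); }
    float halfArea() const { return hi > lo ? hi - lo : 0.0f; }   // 1D stand-in for surface area
};

struct Best1D { float sah = std::numeric_limits<float>::infinity(); std::size_t pos = 0; };

Best1D bestSplit1D(const Box1 bounds[], const std::size_t counts[], std::size_t num)
{
    // Suffix pass: rArea[i] / rCount[i] summarize bins [i, num).
    std::vector<float> rArea(num, 0.0f);
    std::vector<std::size_t> rCount(num, 0);
    Box1 box; std::size_t cnt = 0;
    for (std::size_t i = num; i-- > 1; ) {
        cnt += counts[i]; box.extend(bounds[i]);
        rArea[i] = box.halfArea(); rCount[i] = cnt;
    }
    // Prefix pass: accumulate bins [0, i) and score a split between bins i-1 and i.
    Best1D best; box = Box1{}; cnt = 0;
    for (std::size_t i = 1; i < num; ++i) {
        cnt += counts[i - 1]; box.extend(bounds[i - 1]);
        const float sah = box.halfArea() * float(cnt) + rArea[i] * float(rCount[i]);
        if (sah < best.sah) { best.sah = sah; best.pos = i; }
    }
    return best;
}

The real code additionally rounds the counts up to block multiples via blocks_shift and evaluates all three axes at once in SIMD registers.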
gets the number of primitives right of the split */ + __forceinline size_t getRightCount(const BinMapping& mapping, const Split& split) const + { + if (unlikely(split.dim == -1)) return -1; + + size_t rightCount = 0; + for (size_t i = (size_t)split.pos; i + struct BinMapping<16> + { + public: + __forceinline BinMapping() {} + + /*! calculates the mapping */ + template + __forceinline BinMapping(const PrimInfo& pinfo) + { + num = 16; + const vfloat4 eps = 1E-34f; + const vfloat4 diag = max(eps,(vfloat4) pinfo.centBounds.size()); + scale = select(diag > eps,vfloat4(0.99f*num)/diag,vfloat4(0.0f)); + ofs = (vfloat4) pinfo.centBounds.lower; + scale16 = scale; + ofs16 = ofs; + } + + /*! returns number of bins */ + __forceinline size_t size() const { return num; } + + __forceinline vint16 bin16(const Vec3fa& p) const { + return vint16(vint4(floori((vfloat4(p)-ofs)*scale))); + } + + __forceinline vint16 bin16(const vfloat16& p) const { + return floori((p-ofs16)*scale16); + } + + __forceinline int bin_unsafe(const PrimRef& ref, + const vint16& vSplitPos, + const vbool16& splitDimMask) const // FIXME: rename to isLeft + { + const vfloat16 lower(*(vfloat4*)&ref.lower); + const vfloat16 upper(*(vfloat4*)&ref.upper); + const vfloat16 p = lower + upper; + const vint16 i = floori((p-ofs16)*scale16); + return lt(splitDimMask,i,vSplitPos); + } + + /*! returns true if the mapping is invalid in some dimension */ + __forceinline bool invalid(const size_t dim) const { + return scale[dim] == 0.0f; + } + + public: + size_t num; + vfloat4 ofs,scale; //!< linear function that maps to bin ID + vfloat16 ofs16,scale16; //!< linear function that maps to bin ID + }; + + /* 16 bins in-register binner */ + template + struct __aligned(64) BinInfoT<16,PrimRef,BBox3fa> + { + typedef BinSplit<16> Split; + typedef vbool16 vbool; + typedef vint16 vint; + typedef vfloat16 vfloat; + + __forceinline BinInfoT() { + } + + __forceinline BinInfoT(EmptyTy) { + clear(); + } + + /*! 
clears the bin info */ + __forceinline void clear() + { + lower[0] = lower[1] = lower[2] = pos_inf; + upper[0] = upper[1] = upper[2] = neg_inf; + count[0] = count[1] = count[2] = 0; + } + + + static __forceinline vfloat16 prefix_area_rl(const vfloat16 min_x, + const vfloat16 min_y, + const vfloat16 min_z, + const vfloat16 max_x, + const vfloat16 max_y, + const vfloat16 max_z) + { + const vfloat16 r_min_x = reverse_prefix_min(min_x); + const vfloat16 r_min_y = reverse_prefix_min(min_y); + const vfloat16 r_min_z = reverse_prefix_min(min_z); + const vfloat16 r_max_x = reverse_prefix_max(max_x); + const vfloat16 r_max_y = reverse_prefix_max(max_y); + const vfloat16 r_max_z = reverse_prefix_max(max_z); + const vfloat16 dx = r_max_x - r_min_x; + const vfloat16 dy = r_max_y - r_min_y; + const vfloat16 dz = r_max_z - r_min_z; + const vfloat16 area_rl = madd(dx,dy,madd(dx,dz,dy*dz)); + return area_rl; + } + + static __forceinline vfloat16 prefix_area_lr(const vfloat16 min_x, + const vfloat16 min_y, + const vfloat16 min_z, + const vfloat16 max_x, + const vfloat16 max_y, + const vfloat16 max_z) + { + const vfloat16 r_min_x = prefix_min(min_x); + const vfloat16 r_min_y = prefix_min(min_y); + const vfloat16 r_min_z = prefix_min(min_z); + const vfloat16 r_max_x = prefix_max(max_x); + const vfloat16 r_max_y = prefix_max(max_y); + const vfloat16 r_max_z = prefix_max(max_z); + const vfloat16 dx = r_max_x - r_min_x; + const vfloat16 dy = r_max_y - r_min_y; + const vfloat16 dz = r_max_z - r_min_z; + const vfloat16 area_lr = madd(dx,dy,madd(dx,dz,dy*dz)); + return area_lr; + } + + + /*! bins an array of primitives */ + __forceinline void bin (const PrimRef* prims, size_t N, const BinMapping<16>& mapping) + { + if (unlikely(N == 0)) return; + + const vfloat16 init_min(pos_inf); + const vfloat16 init_max(neg_inf); + + vfloat16 min_x0,min_x1,min_x2; + vfloat16 min_y0,min_y1,min_y2; + vfloat16 min_z0,min_z1,min_z2; + vfloat16 max_x0,max_x1,max_x2; + vfloat16 max_y0,max_y1,max_y2; + vfloat16 max_z0,max_z1,max_z2; + vuint16 count0,count1,count2; + + min_x0 = init_min; + min_x1 = init_min; + min_x2 = init_min; + min_y0 = init_min; + min_y1 = init_min; + min_y2 = init_min; + min_z0 = init_min; + min_z1 = init_min; + min_z2 = init_min; + + max_x0 = init_max; + max_x1 = init_max; + max_x2 = init_max; + max_y0 = init_max; + max_y1 = init_max; + max_y2 = init_max; + max_z0 = init_max; + max_z1 = init_max; + max_z2 = init_max; + + count0 = zero; + count1 = zero; + count2 = zero; + + const vint16 step16(step); + size_t i; + for (i=0; i(binA); + const vint16 bin1 = shuffle<1>(binA); + const vint16 bin2 = shuffle<2>(binA); + + const vbool16 m_update_x = step16 == bin0; + const vbool16 m_update_y = step16 == bin1; + const vbool16 m_update_z = step16 == bin2; + + assert(popcnt((size_t)m_update_x) == 1); + assert(popcnt((size_t)m_update_y) == 1); + assert(popcnt((size_t)m_update_z) == 1); + + min_x0 = mask_min(m_update_x,min_x0,min_x0,b_min_x); + min_y0 = mask_min(m_update_x,min_y0,min_y0,b_min_y); + min_z0 = mask_min(m_update_x,min_z0,min_z0,b_min_z); + // ------------------------------------------------------------------------ + max_x0 = mask_max(m_update_x,max_x0,max_x0,b_max_x); + max_y0 = mask_max(m_update_x,max_y0,max_y0,b_max_y); + max_z0 = mask_max(m_update_x,max_z0,max_z0,b_max_z); + // ------------------------------------------------------------------------ + min_x1 = mask_min(m_update_y,min_x1,min_x1,b_min_x); + min_y1 = mask_min(m_update_y,min_y1,min_y1,b_min_y); + min_z1 = 
mask_min(m_update_y,min_z1,min_z1,b_min_z); + // ------------------------------------------------------------------------ + max_x1 = mask_max(m_update_y,max_x1,max_x1,b_max_x); + max_y1 = mask_max(m_update_y,max_y1,max_y1,b_max_y); + max_z1 = mask_max(m_update_y,max_z1,max_z1,b_max_z); + // ------------------------------------------------------------------------ + min_x2 = mask_min(m_update_z,min_x2,min_x2,b_min_x); + min_y2 = mask_min(m_update_z,min_y2,min_y2,b_min_y); + min_z2 = mask_min(m_update_z,min_z2,min_z2,b_min_z); + // ------------------------------------------------------------------------ + max_x2 = mask_max(m_update_z,max_x2,max_x2,b_max_x); + max_y2 = mask_max(m_update_z,max_y2,max_y2,b_max_y); + max_z2 = mask_max(m_update_z,max_z2,max_z2,b_max_z); + // ------------------------------------------------------------------------ + count0 = mask_add(m_update_x,count0,count0,vuint16(1)); + count1 = mask_add(m_update_y,count1,count1,vuint16(1)); + count2 = mask_add(m_update_z,count2,count2,vuint16(1)); + } + + + /* B */ + { + const vfloat16 b_min_x = prims[i+1].lower.x; + const vfloat16 b_min_y = prims[i+1].lower.y; + const vfloat16 b_min_z = prims[i+1].lower.z; + const vfloat16 b_max_x = prims[i+1].upper.x; + const vfloat16 b_max_y = prims[i+1].upper.y; + const vfloat16 b_max_z = prims[i+1].upper.z; + + const vint16 bin0 = shuffle<0>(binB); + const vint16 bin1 = shuffle<1>(binB); + const vint16 bin2 = shuffle<2>(binB); + + const vbool16 m_update_x = step16 == bin0; + const vbool16 m_update_y = step16 == bin1; + const vbool16 m_update_z = step16 == bin2; + + assert(popcnt((size_t)m_update_x) == 1); + assert(popcnt((size_t)m_update_y) == 1); + assert(popcnt((size_t)m_update_z) == 1); + + min_x0 = mask_min(m_update_x,min_x0,min_x0,b_min_x); + min_y0 = mask_min(m_update_x,min_y0,min_y0,b_min_y); + min_z0 = mask_min(m_update_x,min_z0,min_z0,b_min_z); + // ------------------------------------------------------------------------ + max_x0 = mask_max(m_update_x,max_x0,max_x0,b_max_x); + max_y0 = mask_max(m_update_x,max_y0,max_y0,b_max_y); + max_z0 = mask_max(m_update_x,max_z0,max_z0,b_max_z); + // ------------------------------------------------------------------------ + min_x1 = mask_min(m_update_y,min_x1,min_x1,b_min_x); + min_y1 = mask_min(m_update_y,min_y1,min_y1,b_min_y); + min_z1 = mask_min(m_update_y,min_z1,min_z1,b_min_z); + // ------------------------------------------------------------------------ + max_x1 = mask_max(m_update_y,max_x1,max_x1,b_max_x); + max_y1 = mask_max(m_update_y,max_y1,max_y1,b_max_y); + max_z1 = mask_max(m_update_y,max_z1,max_z1,b_max_z); + // ------------------------------------------------------------------------ + min_x2 = mask_min(m_update_z,min_x2,min_x2,b_min_x); + min_y2 = mask_min(m_update_z,min_y2,min_y2,b_min_y); + min_z2 = mask_min(m_update_z,min_z2,min_z2,b_min_z); + // ------------------------------------------------------------------------ + max_x2 = mask_max(m_update_z,max_x2,max_x2,b_max_x); + max_y2 = mask_max(m_update_z,max_y2,max_y2,b_max_y); + max_z2 = mask_max(m_update_z,max_z2,max_z2,b_max_z); + // ------------------------------------------------------------------------ + count0 = mask_add(m_update_x,count0,count0,vuint16(1)); + count1 = mask_add(m_update_y,count1,count1,vuint16(1)); + count2 = mask_add(m_update_z,count2,count2,vuint16(1)); + } + + } + + if (i < N) + { + const BBox3fa prim0 = prims[i].bounds(); + const vfloat16 center0 = vfloat16((vfloat4)prim0.lower) + vfloat16((vfloat4)prim0.upper); + const vint16 bin = 
mapping.bin16(center0); + + const vfloat16 b_min_x = prims[i].lower.x; + const vfloat16 b_min_y = prims[i].lower.y; + const vfloat16 b_min_z = prims[i].lower.z; + const vfloat16 b_max_x = prims[i].upper.x; + const vfloat16 b_max_y = prims[i].upper.y; + const vfloat16 b_max_z = prims[i].upper.z; + + const vint16 bin0 = shuffle<0>(bin); + const vint16 bin1 = shuffle<1>(bin); + const vint16 bin2 = shuffle<2>(bin); + + const vbool16 m_update_x = step16 == bin0; + const vbool16 m_update_y = step16 == bin1; + const vbool16 m_update_z = step16 == bin2; + + assert(popcnt((size_t)m_update_x) == 1); + assert(popcnt((size_t)m_update_y) == 1); + assert(popcnt((size_t)m_update_z) == 1); + + min_x0 = mask_min(m_update_x,min_x0,min_x0,b_min_x); + min_y0 = mask_min(m_update_x,min_y0,min_y0,b_min_y); + min_z0 = mask_min(m_update_x,min_z0,min_z0,b_min_z); + // ------------------------------------------------------------------------ + max_x0 = mask_max(m_update_x,max_x0,max_x0,b_max_x); + max_y0 = mask_max(m_update_x,max_y0,max_y0,b_max_y); + max_z0 = mask_max(m_update_x,max_z0,max_z0,b_max_z); + // ------------------------------------------------------------------------ + min_x1 = mask_min(m_update_y,min_x1,min_x1,b_min_x); + min_y1 = mask_min(m_update_y,min_y1,min_y1,b_min_y); + min_z1 = mask_min(m_update_y,min_z1,min_z1,b_min_z); + // ------------------------------------------------------------------------ + max_x1 = mask_max(m_update_y,max_x1,max_x1,b_max_x); + max_y1 = mask_max(m_update_y,max_y1,max_y1,b_max_y); + max_z1 = mask_max(m_update_y,max_z1,max_z1,b_max_z); + // ------------------------------------------------------------------------ + min_x2 = mask_min(m_update_z,min_x2,min_x2,b_min_x); + min_y2 = mask_min(m_update_z,min_y2,min_y2,b_min_y); + min_z2 = mask_min(m_update_z,min_z2,min_z2,b_min_z); + // ------------------------------------------------------------------------ + max_x2 = mask_max(m_update_z,max_x2,max_x2,b_max_x); + max_y2 = mask_max(m_update_z,max_y2,max_y2,b_max_y); + max_z2 = mask_max(m_update_z,max_z2,max_z2,b_max_z); + // ------------------------------------------------------------------------ + count0 = mask_add(m_update_x,count0,count0,vuint16(1)); + count1 = mask_add(m_update_y,count1,count1,vuint16(1)); + count2 = mask_add(m_update_z,count2,count2,vuint16(1)); + } + + lower[0] = Vec3vf16( min_x0, min_y0, min_z0 ); + lower[1] = Vec3vf16( min_x1, min_y1, min_z1 ); + lower[2] = Vec3vf16( min_x2, min_y2, min_z2 ); + + upper[0] = Vec3vf16( max_x0, max_y0, max_z0 ); + upper[1] = Vec3vf16( max_x1, max_y1, max_z1 ); + upper[2] = Vec3vf16( max_x2, max_y2, max_z2 ); + + count[0] = count0; + count[1] = count1; + count[2] = count2; + } + + __forceinline void bin(const PrimRef* prims, size_t begin, size_t end, const BinMapping<16>& mapping) { + bin(prims+begin,end-begin,mapping); + } + + /*! merges in other binning information */ + __forceinline void merge (const BinInfoT& other, size_t numBins) + { + for (size_t i=0; i<3; i++) + { + lower[i] = min(lower[i],other.lower[i]); + upper[i] = max(upper[i],other.upper[i]); + count[i] += other.count[i]; + } + } + + /*! reducesr binning information */ + static __forceinline const BinInfoT reduce (const BinInfoT& a, const BinInfoT& b) + { + BinInfoT c; + for (size_t i=0; i<3; i++) + { + c.counts[i] = a.counts[i] + b.counts[i]; + c.lower[i] = min(a.lower[i],b.lower[i]); + c.upper[i] = max(a.upper[i],b.upper[i]); + } + return c; + } + + /*! 
finds the best split by scanning binning information */ + __forceinline Split best(const BinMapping<16>& mapping, const size_t blocks_shift) const + { + /* find best dimension */ + float bestSAH = inf; + int bestDim = -1; + int bestPos = 0; + const vuint16 blocks_add = (1 << blocks_shift)-1; + const vfloat16 inf(pos_inf); + for (size_t dim=0; dim<3; dim++) + { + /* ignore zero sized dimensions */ + if (unlikely(mapping.invalid(dim))) + continue; + + const vfloat16 rArea16 = prefix_area_rl(lower[dim].x,lower[dim].y,lower[dim].z, upper[dim].x,upper[dim].y,upper[dim].z); + const vfloat16 lArea16 = prefix_area_lr(lower[dim].x,lower[dim].y,lower[dim].z, upper[dim].x,upper[dim].y,upper[dim].z); + const vuint16 lCount16 = prefix_sum(count[dim]); + const vuint16 rCount16 = reverse_prefix_sum(count[dim]); + + /* compute best split in this dimension */ + const vfloat16 leftArea = lArea16; + const vfloat16 rightArea = align_shift_right<1>(zero,rArea16); + const vuint16 lC = lCount16; + const vuint16 rC = align_shift_right<1>(zero,rCount16); + const vuint16 leftCount = ( lC + blocks_add) >> blocks_shift; + const vuint16 rightCount = ( rC + blocks_add) >> blocks_shift; + const vbool16 valid = (leftArea < inf) & (rightArea < inf) & vbool16(0x7fff); // handles inf entries + const vfloat16 sah = select(valid,madd(leftArea,vfloat16(leftCount),rightArea*vfloat16(rightCount)),vfloat16(pos_inf)); + /* test if this is a better dimension */ + if (any(sah < vfloat16(bestSAH))) + { + const size_t index = select_min(sah); + assert(index < 15); + assert(sah[index] < bestSAH); + bestDim = dim; + bestPos = index+1; + bestSAH = sah[index]; + } + } + + return Split(bestSAH,bestDim,bestPos,mapping); + + } + + /*! calculates extended split information */ + __forceinline void getSplitInfo(const BinMapping<16>& mapping, const Split& split, SplitInfo& info) const + { + if (split.dim == -1) { + new (&info) SplitInfo(0,empty,0,empty); + return; + } + // FIXME: horizontal reduction! + + size_t leftCount = 0; + BBox3fa leftBounds = empty; + for (size_t i=0; i<(size_t)split.pos; i++) { + leftCount += count[split.dim][i]; + Vec3fa bounds_lower(lower[split.dim].x[i],lower[split.dim].y[i],lower[split.dim].z[i]); + Vec3fa bounds_upper(upper[split.dim].x[i],upper[split.dim].y[i],upper[split.dim].z[i]); + leftBounds.extend(BBox3fa(bounds_lower,bounds_upper)); + } + size_t rightCount = 0; + BBox3fa rightBounds = empty; + for (size_t i=split.pos; i& mapping, const Split& split) const + { + if (unlikely(split.dim == -1)) return -1; + + size_t leftCount = 0; + for (size_t i = 0; i < (size_t)split.pos; i++) { + leftCount += count[split.dim][i]; + } + return leftCount; + } + + /*! 
gets the number of primitives right of the split */ + __forceinline size_t getRightCount(const BinMapping<16>& mapping, const Split& split) const + { + if (unlikely(split.dim == -1)) return -1; + + size_t rightCount = 0; + for (size_t i = (size_t)split.pos; i + __forceinline void bin_parallel(BinInfoT& binner, const PrimRef* prims, size_t begin, size_t end, size_t blockSize, size_t parallelThreshold, const BinMapping& mapping) + { + if (likely(end-begin < parallelThreshold)) { + binner.bin(prims,begin,end,mapping); + } else { + binner = parallel_reduce(begin,end,blockSize,binner, + [&](const range& r) -> BinInfoT { BinInfoT binner(empty); binner.bin(prims + r.begin(), r.size(), mapping); return binner; }, + [&](const BinInfoT& b0, const BinInfoT& b1) -> BinInfoT { BinInfoT r = b0; r.merge(b1, mapping.size()); return r; }); + } + } + + template + __forceinline void bin_parallel(BinInfoT& binner, const PrimRef* prims, size_t begin, size_t end, size_t blockSize, size_t parallelThreshold, const BinMapping& mapping, const BinBoundsAndCenter& binBoundsAndCenter) + { + if (likely(end-begin < parallelThreshold)) { + binner.bin(prims,begin,end,mapping,binBoundsAndCenter); + } else { + binner = parallel_reduce(begin,end,blockSize,binner, + [&](const range& r) -> BinInfoT { BinInfoT binner(empty); binner.bin(prims + r.begin(), r.size(), mapping, binBoundsAndCenter); return binner; }, + [&](const BinInfoT& b0, const BinInfoT& b1) -> BinInfoT { BinInfoT r = b0; r.merge(b1, mapping.size()); return r; }); + } + } + + template + __forceinline void bin_serial_or_parallel(BinInfoT& binner, const PrimRef* prims, size_t begin, size_t end, size_t blockSize, const BinMapping& mapping) + { + if (!parallel) { + binner.bin(prims,begin,end,mapping); + } else { + binner = parallel_reduce(begin,end,blockSize,binner, + [&](const range& r) -> BinInfoT { BinInfoT binner(empty); binner.bin(prims + r.begin(), r.size(), mapping); return binner; }, + [&](const BinInfoT& b0, const BinInfoT& b1) -> BinInfoT { BinInfoT r = b0; r.merge(b1, mapping.size()); return r; }); + } + } + + template + __forceinline void bin_serial_or_parallel(BinInfoT& binner, const PrimRef* prims, size_t begin, size_t end, size_t blockSize, const BinMapping& mapping, const BinBoundsAndCenter& binBoundsAndCenter) + { + if (!parallel) { + binner.bin(prims,begin,end,mapping,binBoundsAndCenter); + } else { + binner = parallel_reduce(begin,end,blockSize,binner, + [&](const range& r) -> BinInfoT { BinInfoT binner(empty); binner.bin(prims + r.begin(), r.size(), mapping, binBoundsAndCenter); return binner; }, + [&](const BinInfoT& b0, const BinInfoT& b1) -> BinInfoT { BinInfoT r = b0; r.merge(b1, mapping.size()); return r; }); + } + } +} diff --git a/thirdparty/embree/kernels/builders/heuristic_binning_array_aligned.h b/thirdparty/embree/kernels/builders/heuristic_binning_array_aligned.h new file mode 100644 index 000000000000..a4c272f01511 --- /dev/null +++ b/thirdparty/embree/kernels/builders/heuristic_binning_array_aligned.h @@ -0,0 +1,205 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "heuristic_binning.h" + +namespace embree +{ + namespace isa + { + struct PrimInfoRange : public CentGeomBBox3fa, public range + { + __forceinline PrimInfoRange () { + } + + __forceinline PrimInfoRange(const PrimInfo& pinfo) + : CentGeomBBox3fa(pinfo), range(pinfo.begin,pinfo.end) {} + + __forceinline PrimInfoRange(EmptyTy) + : CentGeomBBox3fa(EmptyTy()), range(0,0) {} + + __forceinline PrimInfoRange (size_t 
begin, size_t end, const CentGeomBBox3fa& centGeomBounds) + : CentGeomBBox3fa(centGeomBounds), range(begin,end) {} + + __forceinline float leafSAH() const { + return expectedApproxHalfArea(geomBounds)*float(size()); + } + + __forceinline float leafSAH(size_t block_shift) const { + return expectedApproxHalfArea(geomBounds)*float((size()+(size_t(1)<> block_shift); + } + }; + + /*! Performs standard object binning */ + template + struct HeuristicArrayBinningSAH + { + typedef BinSplit Split; + typedef BinInfoT Binner; + typedef range Set; + +#if defined(__AVX512ER__) // KNL + static const size_t PARALLEL_THRESHOLD = 4*768; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 768; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 768; +#else + static const size_t PARALLEL_THRESHOLD = 3 * 1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 1024; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; +#endif + __forceinline HeuristicArrayBinningSAH () + : prims(nullptr) {} + + /*! remember prim array */ + __forceinline HeuristicArrayBinningSAH (PrimRef* prims) + : prims(prims) {} + + /*! finds the best split */ + __noinline const Split find(const PrimInfoRange& pinfo, const size_t logBlockSize) + { + if (likely(pinfo.size() < PARALLEL_THRESHOLD)) + return find_template(pinfo,logBlockSize); + else + return find_template(pinfo,logBlockSize); + } + + template + __forceinline const Split find_template(const PrimInfoRange& pinfo, const size_t logBlockSize) + { + Binner binner(empty); + const BinMapping mapping(pinfo); + bin_serial_or_parallel(binner,prims,pinfo.begin(),pinfo.end(),PARALLEL_FIND_BLOCK_SIZE,mapping); + return binner.best(mapping,logBlockSize); + } + + /*! array partitioning */ + __forceinline void split(const Split& split, const PrimInfoRange& pinfo, PrimInfoRange& linfo, PrimInfoRange& rinfo) + { + if (likely(pinfo.size() < PARALLEL_THRESHOLD)) + split_template(split,pinfo,linfo,rinfo); + else + split_template(split,pinfo,linfo,rinfo); + } + + template + __forceinline void split_template(const Split& split, const PrimInfoRange& set, PrimInfoRange& lset, PrimInfoRange& rset) + { + if (!split.valid()) { + deterministic_order(set); + return splitFallback(set,lset,rset); + } + + const size_t begin = set.begin(); + const size_t end = set.end(); + CentGeomBBox3fa local_left(empty); + CentGeomBBox3fa local_right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + const typename Binner::vint vSplitPos(splitPos); + const typename Binner::vbool vSplitMask(splitDimMask); + auto isLeft = [&] (const PrimRef &ref) { return split.mapping.bin_unsafe(ref,vSplitPos,vSplitMask); }; + + size_t center = 0; + if (!parallel) + center = serial_partitioning(prims,begin,end,local_left,local_right,isLeft, + [] (CentGeomBBox3fa& pinfo,const PrimRef& ref) { pinfo.extend_center2(ref); }); + else + center = parallel_partitioning( + prims,begin,end,EmptyTy(),local_left,local_right,isLeft, + [] (CentGeomBBox3fa& pinfo,const PrimRef& ref) { pinfo.extend_center2(ref); }, + [] (CentGeomBBox3fa& pinfo0,const CentGeomBBox3fa& pinfo1) { pinfo0.merge(pinfo1); }, + PARALLEL_PARTITION_BLOCK_SIZE); + + new (&lset) PrimInfoRange(begin,center,local_left); + new (&rset) PrimInfoRange(center,end,local_right); + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + } + + void deterministic_order(const PrimInfoRange& pinfo) + { + /* required as parallel partition destroys original primitive 
order */ + std::sort(&prims[pinfo.begin()],&prims[pinfo.end()]); + } + + void splitFallback(const PrimInfoRange& pinfo, PrimInfoRange& linfo, PrimInfoRange& rinfo) + { + const size_t begin = pinfo.begin(); + const size_t end = pinfo.end(); + const size_t center = (begin + end)/2; + + CentGeomBBox3fa left(empty); + for (size_t i=begin; i& range, PrimInfoRange& linfo, PrimInfoRange& rinfo) + { + assert(range.size() > 1); + CentGeomBBox3fa left(empty); + CentGeomBBox3fa right(empty); + unsigned int geomID = prims[range.begin()].geomID(); + size_t center = serial_partitioning(prims,range.begin(),range.end(),left,right, + [&] ( const PrimRef& prim ) { return prim.geomID() == geomID; }, + [ ] ( CentGeomBBox3fa& a, const PrimRef& ref ) { a.extend_center2(ref); }); + + new (&linfo) PrimInfoRange(range.begin(),center,left); + new (&rinfo) PrimInfoRange(center,range.end(),right); + } + + private: + PrimRef* const prims; + }; + + /*! Performs standard object binning */ + template + struct HeuristicArrayBinningMB + { + typedef BinSplit Split; + typedef typename PrimRefMB::BBox BBox; + typedef BinInfoT ObjectBinner; + static const size_t PARALLEL_THRESHOLD = 3 * 1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 1024; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; + + /*! finds the best split */ + const Split find(const SetMB& set, const size_t logBlockSize) + { + ObjectBinner binner(empty); + const BinMapping mapping(set.size(),set.centBounds); + bin_parallel(binner,set.prims->data(),set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,PARALLEL_THRESHOLD,mapping); + Split osplit = binner.best(mapping,logBlockSize); + osplit.sah *= set.time_range.size(); + if (!osplit.valid()) osplit.data = Split::SPLIT_FALLBACK; // use fallback split + return osplit; + } + + /*! array partitioning */ + __forceinline void split(const Split& split, const SetMB& set, SetMB& lset, SetMB& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfoMB left = empty; + PrimInfoMB right = empty; + const vint4 vSplitPos(split.pos); + const vbool4 vSplitMask(1 << split.dim); + auto isLeft = [&] (const PrimRefMB &ref) { return any(((vint4)split.mapping.bin_unsafe(ref) < vSplitPos) & vSplitMask); }; + auto reduction = [] (PrimInfoMB& pinfo, const PrimRefMB& ref) { pinfo.add_primref(ref); }; + auto reduction2 = [] (PrimInfoMB& pinfo0,const PrimInfoMB& pinfo1) { pinfo0.merge(pinfo1); }; + size_t center = parallel_partitioning(set.prims->data(),begin,end,EmptyTy(),left,right,isLeft,reduction,reduction2,PARALLEL_PARTITION_BLOCK_SIZE,PARALLEL_THRESHOLD); + new (&lset) SetMB(left, set.prims,range(begin,center),set.time_range); + new (&rset) SetMB(right,set.prims,range(center,end ),set.time_range); + } + }; + } +} diff --git a/thirdparty/embree/kernels/builders/heuristic_binning_array_unaligned.h b/thirdparty/embree/kernels/builders/heuristic_binning_array_unaligned.h new file mode 100644 index 000000000000..137024458621 --- /dev/null +++ b/thirdparty/embree/kernels/builders/heuristic_binning_array_unaligned.h @@ -0,0 +1,302 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "heuristic_binning.h" + +namespace embree +{ + namespace isa + { + /*! Performs standard object binning */ + template + struct UnalignedHeuristicArrayBinningSAH + { + typedef BinSplit Split; + typedef BinInfoT Binner; + typedef range Set; + + __forceinline UnalignedHeuristicArrayBinningSAH () // FIXME: required? + : scene(nullptr), prims(nullptr) {} + + /*! 
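The split() paths above ultimately partition the primitive array in place around the chosen plane; embree's serial_partitioning also folds the per-side bounds reduction into the same pass. A compact sketch of just the in-place partition step, assuming an illustrative Prim type and a predicate built from the bin mapping (none of these names are embree's):

#include <cstddef>
#include <utility>
#include <vector>

struct Prim { float centroid[3]; /* ... payload ... */ };

// Moves primitives satisfying isLeft to the front of [begin, end) and returns the
// first index of the right child range (std::partition-style, but index based).
template <typename IsLeft>
std::size_t serialPartition(std::vector<Prim>& prims, std::size_t begin, std::size_t end, IsLeft isLeft)
{
    std::size_t l = begin, r = end;
    while (l < r) {
        if (isLeft(prims[l])) ++l;                              // already on the correct side
        else                  std::swap(prims[l], prims[--r]);  // move to the right block
    }
    return l;
}

// Usage: given a split (dim, pos) and a bin mapping m (see the scalar sketch above),
//   std::size_t center = serialPartition(prims, begin, end,
//       [&](const Prim& p) { return m.bin(p.centroid[dim], dim) < pos; });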
remember prim array */ + __forceinline UnalignedHeuristicArrayBinningSAH (Scene* scene, PrimRef* prims) + : scene(scene), prims(prims) {} + + const LinearSpace3fa computeAlignedSpace(const range& set) + { + Vec3fa axis(0,0,1); + uint64_t bestGeomPrimID = -1; + + /*! find curve with minimum ID that defines valid direction */ + for (size_t i=set.begin(); i= bestGeomPrimID) continue; + const Vec3fa axis1 = scene->get(geomID)->computeDirection(primID); + if (sqr_length(axis1) > 1E-18f) { + axis = normalize(axis1); + bestGeomPrimID = geomprimID; + } + } + return frame(axis).transposed(); + } + + const PrimInfo computePrimInfo(const range& set, const LinearSpace3fa& space) + { + auto computeBounds = [&](const range& r) -> CentGeomBBox3fa + { + CentGeomBBox3fa bounds(empty); + for (size_t i=r.begin(); iget(prims[i].geomID()); + bounds.extend(mesh->vbounds(space,prims[i].primID())); + } + return bounds; + }; + + const CentGeomBBox3fa bounds = parallel_reduce(set.begin(), set.end(), size_t(1024), size_t(4096), + CentGeomBBox3fa(empty), computeBounds, CentGeomBBox3fa::merge2); + + return PrimInfo(set.begin(),set.end(),bounds); + } + + struct BinBoundsAndCenter + { + __forceinline BinBoundsAndCenter(Scene* scene, const LinearSpace3fa& space) + : scene(scene), space(space) {} + + /*! returns center for binning */ + __forceinline Vec3fa binCenter(const PrimRef& ref) const + { + Geometry* mesh = (Geometry*) scene->get(ref.geomID()); + BBox3fa bounds = mesh->vbounds(space,ref.primID()); + return embree::center2(bounds); + } + + /*! returns bounds and centroid used for binning */ + __forceinline void binBoundsAndCenter(const PrimRef& ref, BBox3fa& bounds_o, Vec3fa& center_o) const + { + Geometry* mesh = (Geometry*) scene->get(ref.geomID()); + BBox3fa bounds = mesh->vbounds(space,ref.primID()); + bounds_o = bounds; + center_o = embree::center2(bounds); + } + + private: + Scene* scene; + const LinearSpace3fa space; + }; + + /*! finds the best split */ + __forceinline const Split find(const PrimInfoRange& pinfo, const size_t logBlockSize, const LinearSpace3fa& space) + { + if (likely(pinfo.size() < 10000)) + return find_template(pinfo,logBlockSize,space); + else + return find_template(pinfo,logBlockSize,space); + } + + /*! finds the best split */ + template + const Split find_template(const PrimInfoRange& set, const size_t logBlockSize, const LinearSpace3fa& space) + { + Binner binner(empty); + const BinMapping mapping(set); + BinBoundsAndCenter binBoundsAndCenter(scene,space); + bin_serial_or_parallel(binner,prims,set.begin(),set.end(),size_t(4096),mapping,binBoundsAndCenter); + return binner.best(mapping,logBlockSize); + } + + /*! array partitioning */ + __forceinline void split(const Split& split, const LinearSpace3fa& space, const Set& set, PrimInfoRange& lset, PrimInfoRange& rset) + { + if (likely(set.size() < 10000)) + split_template(split,space,set,lset,rset); + else + split_template(split,space,set,lset,rset); + } + + /*! 
array partitioning */ + template + __forceinline void split_template(const Split& split, const LinearSpace3fa& space, const Set& set, PrimInfoRange& lset, PrimInfoRange& rset) + { + if (!split.valid()) { + deterministic_order(set); + return splitFallback(set,lset,rset); + } + + const size_t begin = set.begin(); + const size_t end = set.end(); + CentGeomBBox3fa local_left(empty); + CentGeomBBox3fa local_right(empty); + const int splitPos = split.pos; + const int splitDim = split.dim; + BinBoundsAndCenter binBoundsAndCenter(scene,space); + + size_t center = 0; + if (likely(set.size() < 10000)) + center = serial_partitioning(prims,begin,end,local_left,local_right, + [&] (const PrimRef& ref) { return split.mapping.bin_unsafe(ref,binBoundsAndCenter)[splitDim] < splitPos; }, + [] (CentGeomBBox3fa& pinfo,const PrimRef& ref) { pinfo.extend_center2(ref); }); + else + center = parallel_partitioning(prims,begin,end,EmptyTy(),local_left,local_right, + [&] (const PrimRef& ref) { return split.mapping.bin_unsafe(ref,binBoundsAndCenter)[splitDim] < splitPos; }, + [] (CentGeomBBox3fa& pinfo,const PrimRef& ref) { pinfo.extend_center2(ref); }, + [] (CentGeomBBox3fa& pinfo0,const CentGeomBBox3fa& pinfo1) { pinfo0.merge(pinfo1); }, + 128); + + new (&lset) PrimInfoRange(begin,center,local_left); + new (&rset) PrimInfoRange(center,end,local_right); + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + } + + void deterministic_order(const range& set) + { + /* required as parallel partition destroys original primitive order */ + std::sort(&prims[set.begin()],&prims[set.end()]); + } + + void splitFallback(const range& set, PrimInfoRange& lset, PrimInfoRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + const size_t center = (begin + end)/2; + + CentGeomBBox3fa left(empty); + for (size_t i=begin; i + struct UnalignedHeuristicArrayBinningMB + { + typedef BinSplit Split; + typedef typename PrimRefMB::BBox BBox; + typedef BinInfoT ObjectBinner; + + static const size_t PARALLEL_THRESHOLD = 3 * 1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 1024; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; + + UnalignedHeuristicArrayBinningMB(Scene* scene) + : scene(scene) {} + + const LinearSpace3fa computeAlignedSpaceMB(Scene* scene, const SetMB& set) + { + Vec3fa axis0(0,0,1); + uint64_t bestGeomPrimID = -1; + + /*! find curve with minimum ID that defines valid direction */ + for (size_t i=set.begin(); i= bestGeomPrimID) continue; + + const Geometry* mesh = scene->get(geomID); + const range tbounds = mesh->timeSegmentRange(set.time_range); + if (tbounds.size() == 0) continue; + + const size_t t = (tbounds.begin()+tbounds.end())/2; + const Vec3fa axis1 = mesh->computeDirection(primID,t); + if (sqr_length(axis1) > 1E-18f) { + axis0 = normalize(axis1); + bestGeomPrimID = geomprimID; + } + } + + return frame(axis0).transposed(); + } + + struct BinBoundsAndCenter + { + __forceinline BinBoundsAndCenter(Scene* scene, BBox1f time_range, const LinearSpace3fa& space) + : scene(scene), time_range(time_range), space(space) {} + + /*! returns center for binning */ + template + __forceinline Vec3fa binCenter(const PrimRef& ref) const + { + Geometry* mesh = scene->get(ref.geomID()); + LBBox3fa lbounds = mesh->vlinearBounds(space,ref.primID(),time_range); + return center2(lbounds.interpolate(0.5f)); + } + + /*! 
returns bounds and centroid used for binning */ + __noinline void binBoundsAndCenter (const PrimRefMB& ref, BBox3fa& bounds_o, Vec3fa& center_o) const // __noinline is workaround for ICC16 bug under MacOSX + { + Geometry* mesh = scene->get(ref.geomID()); + LBBox3fa lbounds = mesh->vlinearBounds(space,ref.primID(),time_range); + bounds_o = lbounds.interpolate(0.5f); + center_o = center2(bounds_o); + } + + /*! returns bounds and centroid used for binning */ + __noinline void binBoundsAndCenter (const PrimRefMB& ref, LBBox3fa& bounds_o, Vec3fa& center_o) const // __noinline is workaround for ICC16 bug under MacOSX + { + Geometry* mesh = scene->get(ref.geomID()); + LBBox3fa lbounds = mesh->vlinearBounds(space,ref.primID(),time_range); + bounds_o = lbounds; + center_o = center2(lbounds.interpolate(0.5f)); + } + + private: + Scene* scene; + BBox1f time_range; + const LinearSpace3fa space; + }; + + /*! finds the best split */ + const Split find(const SetMB& set, const size_t logBlockSize, const LinearSpace3fa& space) + { + BinBoundsAndCenter binBoundsAndCenter(scene,set.time_range,space); + ObjectBinner binner(empty); + const BinMapping mapping(set.size(),set.centBounds); + bin_parallel(binner,set.prims->data(),set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,PARALLEL_THRESHOLD,mapping,binBoundsAndCenter); + Split osplit = binner.best(mapping,logBlockSize); + osplit.sah *= set.time_range.size(); + if (!osplit.valid()) osplit.data = Split::SPLIT_FALLBACK; // use fallback split + return osplit; + } + + /*! array partitioning */ + __forceinline void split(const Split& split, const LinearSpace3fa& space, const SetMB& set, SetMB& lset, SetMB& rset) + { + BinBoundsAndCenter binBoundsAndCenter(scene,set.time_range,space); + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfoMB left = empty; + PrimInfoMB right = empty; + const vint4 vSplitPos(split.pos); + const vbool4 vSplitMask(1 << split.dim); + auto isLeft = [&] (const PrimRefMB &ref) { return any(((vint4)split.mapping.bin_unsafe(ref,binBoundsAndCenter) < vSplitPos) & vSplitMask); }; + auto reduction = [] (PrimInfoMB& pinfo, const PrimRefMB& ref) { pinfo.add_primref(ref); }; + auto reduction2 = [] (PrimInfoMB& pinfo0,const PrimInfoMB& pinfo1) { pinfo0.merge(pinfo1); }; + size_t center = parallel_partitioning(set.prims->data(),begin,end,EmptyTy(),left,right,isLeft,reduction,reduction2,PARALLEL_PARTITION_BLOCK_SIZE,PARALLEL_THRESHOLD); + new (&lset) SetMB(left,set.prims,range(begin,center),set.time_range); + new (&rset) SetMB(right,set.prims,range(center,end ),set.time_range); + } + + private: + Scene* scene; + }; + } +} diff --git a/thirdparty/embree/kernels/builders/heuristic_openmerge_array.h b/thirdparty/embree/kernels/builders/heuristic_openmerge_array.h new file mode 100644 index 000000000000..21f18c0208f5 --- /dev/null +++ b/thirdparty/embree/kernels/builders/heuristic_openmerge_array.h @@ -0,0 +1,443 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +// TODO: +// - adjust parallel build thresholds +// - openNodesBasedOnExtend should consider max extended size + +#pragma once + +#include "heuristic_binning.h" +#include "heuristic_spatial.h" + +/* stop opening of all bref.geomIDs are the same */ +#define EQUAL_GEOMID_STOP_CRITERIA 1 + +/* 10% spatial extend threshold */ +#define MAX_EXTEND_THRESHOLD 0.1f + +/* maximum is 8 children */ +#define MAX_OPENED_CHILD_NODES 8 + +/* open until all build refs are below threshold size in one step */ +#define USE_LOOP_OPENING 0 + +namespace embree 
+{ + namespace isa + { + /*! Performs standard object binning */ + template + struct HeuristicArrayOpenMergeSAH + { + typedef BinSplit Split; + typedef BinInfoT Binner; + + static const size_t PARALLEL_THRESHOLD = 1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 512; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; + + static const size_t MOVE_STEP_SIZE = 64; + static const size_t CREATE_SPLITS_STEP_SIZE = 128; + + __forceinline HeuristicArrayOpenMergeSAH () + : prims0(nullptr) {} + + /*! remember prim array */ + __forceinline HeuristicArrayOpenMergeSAH (const NodeOpenerFunc& nodeOpenerFunc, PrimRef* prims0, size_t max_open_size) + : prims0(prims0), nodeOpenerFunc(nodeOpenerFunc), max_open_size(max_open_size) + { + assert(max_open_size <= MAX_OPENED_CHILD_NODES); + } + + struct OpenHeuristic + { + __forceinline OpenHeuristic( const PrimInfoExtRange& pinfo ) + { + const Vec3fa diag = pinfo.geomBounds.size(); + dim = maxDim(diag); + assert(diag[dim] > 0.0f); + inv_max_extend = 1.0f / diag[dim]; + } + + __forceinline bool operator () ( PrimRef& prim ) const { + return !prim.node.isLeaf() && prim.bounds().size()[dim] * inv_max_extend > MAX_EXTEND_THRESHOLD; + } + + private: + size_t dim; + float inv_max_extend; + }; + + /*! compute extended ranges */ + __forceinline void setExtentedRanges(const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset, const size_t lweight, const size_t rweight) + { + assert(set.ext_range_size() > 0); + const float left_factor = (float)lweight / (lweight + rweight); + const size_t ext_range_size = set.ext_range_size(); + const size_t left_ext_range_size = min((size_t)(floorf(left_factor * ext_range_size)),ext_range_size); + const size_t right_ext_range_size = ext_range_size - left_ext_range_size; + lset.set_ext_range(lset.end() + left_ext_range_size); + rset.set_ext_range(rset.end() + right_ext_range_size); + } + + /*! move ranges */ + __forceinline void moveExtentedRange(const PrimInfoExtRange& set, const PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t left_ext_range_size = lset.ext_range_size(); + const size_t right_size = rset.size(); + + /* has the left child an extended range? */ + if (left_ext_range_size > 0) + { + /* left extended range smaller than right range ? 
*/ + if (left_ext_range_size < right_size) + { + /* only move a small part of the beginning of the right range to the end */ + parallel_for( rset.begin(), rset.begin()+left_ext_range_size, MOVE_STEP_SIZE, [&](const range& r) { + for (size_t i=r.begin(); i& r) { + for (size_t i=r.begin(); i getProperties(const PrimInfoExtRange& set) + { + const OpenHeuristic heuristic(set); + const unsigned int geomID = prims0[set.begin()].geomID(); + + auto body = [&] (const range& r) -> std::pair { + bool commonGeomID = true; + size_t opens = 0; + for (size_t i=r.begin(); i(opens,commonGeomID); + }; + auto reduction = [&] (const std::pair& b0, const std::pair& b1) -> std::pair { + return std::pair(b0.first+b1.first,b0.second && b1.second); + }; + return parallel_reduce(set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,PARALLEL_THRESHOLD,std::pair(0,true),body,reduction); + } + + // FIXME: should consider maximum available extended size + __noinline void openNodesBasedOnExtend(PrimInfoExtRange& set) + { + const OpenHeuristic heuristic(set); + const size_t ext_range_start = set.end(); + + if (false && set.size() < PARALLEL_THRESHOLD) + { + size_t extra_elements = 0; + for (size_t i=set.begin(); i ext_elements; + ext_elements.store(0); + PrimInfo info = parallel_reduce( set.begin(), set.end(), CREATE_SPLITS_STEP_SIZE, PrimInfo(empty), [&](const range& r) -> PrimInfo { + PrimInfo info(empty); + for (size_t i=r.begin(); i 0); + + if (unlikely(next_iteration_extra_elements == 0)) break; + } + } + + __noinline const Split find(PrimInfoExtRange& set, const size_t logBlockSize) + { + /* single element */ + if (set.size() <= 1) + return Split(); + + /* disable opening if there is no overlap */ + const size_t D = 4; + if (unlikely(set.has_ext_range() && set.size() <= D)) + { + bool disjoint = true; + for (size_t j=set.begin(); j p(0,false); + + /* disable opening when all primitives are from same geometry */ + if (unlikely(set.has_ext_range())) + { + p = getProperties(set); +#if EQUAL_GEOMID_STOP_CRITERIA == 1 + if (p.second) set.set_ext_range(set.end()); /* disable opening */ +#endif + } + + /* open nodes when we have sufficient space available */ + if (unlikely(set.has_ext_range())) + { +#if USE_LOOP_OPENING == 1 + openNodesBasedOnExtendLoop(set,p.first); +#else + if (p.first <= set.ext_range_size()) + openNodesBasedOnExtend(set); +#endif + + /* disable opening when unsufficient space for opening a node available */ + if (set.ext_range_size() < max_open_size-1) + set.set_ext_range(set.end()); /* disable opening */ + } + + /* find best split */ + return object_find(set,logBlockSize); + } + + + /*! finds the best object split */ + __forceinline const Split object_find(const PrimInfoExtRange& set,const size_t logBlockSize) + { + if (set.size() < PARALLEL_THRESHOLD) return sequential_object_find(set,logBlockSize); + else return parallel_object_find (set,logBlockSize); + } + + /*! finds the best object split */ + __noinline const Split sequential_object_find(const PrimInfoExtRange& set, const size_t logBlockSize) + { + Binner binner(empty); + const BinMapping mapping(set.centBounds); + binner.bin(prims0,set.begin(),set.end(),mapping); + return binner.best(mapping,logBlockSize); + } + + /*! 
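The opening decision above hinges on a single geometric test: a build reference is opened (replaced by references to its children) only if it is an inner node and its extent along the dominant axis of the current set exceeds MAX_EXTEND_THRESHOLD (10%) of the set's extent on that axis. A small sketch of that test with illustrative names (Extent3, shouldOpen):

#include <cstddef>

struct Extent3 { float size[3]; };

bool shouldOpen(bool isLeaf, const Extent3& refExtent, const Extent3& setExtent,
                float threshold = 0.1f /* MAX_EXTEND_THRESHOLD */)
{
    // Dominant axis of the whole set (OpenHeuristic precomputes dim and 1/extent once).
    std::size_t dim = 0;
    for (std::size_t d = 1; d < 3; ++d)
        if (setExtent.size[d] > setExtent.size[dim]) dim = d;
    return !isLeaf && refExtent.size[dim] > threshold * setExtent.size[dim];
}

Opening is skipped when all references come from the same geometry (EQUAL_GEOMID_STOP_CRITERIA) or when the reserved extended range cannot hold the references an opening step would add.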
finds the best split */ + __noinline const Split parallel_object_find(const PrimInfoExtRange& set, const size_t logBlockSize) + { + Binner binner(empty); + const BinMapping mapping(set.centBounds); + const BinMapping& _mapping = mapping; // CLANG 3.4 parser bug workaround + auto body = [&] (const range& r) -> Binner { + Binner binner(empty); binner.bin(prims0+r.begin(),r.size(),_mapping); return binner; + }; + auto reduction = [&] (const Binner& b0, const Binner& b1) -> Binner { + Binner r = b0; r.merge(b1,_mapping.size()); return r; + }; + binner = parallel_reduce(set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,binner,body,reduction); + return binner.best(mapping,logBlockSize); + } + + /*! array partitioning */ + __noinline void split(const Split& split, const PrimInfoExtRange& set_i, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + PrimInfoExtRange set = set_i; + + /* valid split */ + if (unlikely(!split.valid())) { + deterministic_order(set); + splitFallback(set,lset,rset); + return; + } + + std::pair ext_weights(0,0); + + /* object split */ + if (likely(set.size() < PARALLEL_THRESHOLD)) + ext_weights = sequential_object_split(split,set,lset,rset); + else + ext_weights = parallel_object_split(split,set,lset,rset); + + /* if we have an extended range, set extended child ranges and move right split range */ + if (unlikely(set.has_ext_range())) + { + setExtentedRanges(set,lset,rset,ext_weights.first,ext_weights.second); + moveExtentedRange(set,lset,rset); + } + } + + /*! array partitioning */ + std::pair sequential_object_split(const Split& split, const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfo local_left(empty); + PrimInfo local_right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + const vint4 vSplitPos(splitPos); + const vbool4 vSplitMask( (int)splitDimMask ); + + size_t center = serial_partitioning(prims0, + begin,end,local_left,local_right, + [&] (const PrimRef& ref) { return split.mapping.bin_unsafe(ref,vSplitPos,vSplitMask); }, + [] (PrimInfo& pinfo,const PrimRef& ref) { pinfo.add_center2(ref); }); + + new (&lset) PrimInfoExtRange(begin,center,center,local_left); + new (&rset) PrimInfoExtRange(center,end,end,local_right); + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + return std::pair(local_left.size(),local_right.size()); + } + + /*! 
array partitioning */ + __noinline std::pair parallel_object_split(const Split& split, const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfo left(empty); + PrimInfo right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + const vint4 vSplitPos(splitPos); + const vbool4 vSplitMask( (int)splitDimMask ); + auto isLeft = [&] (const PrimRef& ref) { return split.mapping.bin_unsafe(ref,vSplitPos,vSplitMask); }; + + const size_t center = parallel_partitioning( + prims0,begin,end,EmptyTy(),left,right,isLeft, + [] (PrimInfo& pinfo,const PrimRef& ref) { pinfo.add_center2(ref); }, + [] (PrimInfo& pinfo0,const PrimInfo& pinfo1) { pinfo0.merge(pinfo1); }, + PARALLEL_PARTITION_BLOCK_SIZE); + + new (&lset) PrimInfoExtRange(begin,center,center,left); + new (&rset) PrimInfoExtRange(center,end,end,right); + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + + return std::pair(left.size(),right.size()); + } + + void deterministic_order(const extended_range& set) + { + /* required as parallel partition destroys original primitive order */ + std::sort(&prims0[set.begin()],&prims0[set.end()]); + } + + __forceinline void splitFallback(const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + const size_t center = (begin + end)/2; + + PrimInfo left(empty); + for (size_t i=begin; i + struct SpatialBinMapping + { + public: + __forceinline SpatialBinMapping() {} + + /*! calculates the mapping */ + __forceinline SpatialBinMapping(const CentGeomBBox3fa& pinfo) + { + const vfloat4 lower = (vfloat4) pinfo.geomBounds.lower; + const vfloat4 upper = (vfloat4) pinfo.geomBounds.upper; + const vfloat4 eps = 128.0f*vfloat4(ulp)*max(abs(lower),abs(upper)); + const vfloat4 diag = max(eps,(vfloat4) pinfo.geomBounds.size()); + scale = select(upper-lower <= eps,vfloat4(0.0f),vfloat4(BINS)/diag); + ofs = (vfloat4) pinfo.geomBounds.lower; + inv_scale = 1.0f / scale; + } + + /*! slower but safe binning */ + __forceinline vint4 bin(const Vec3fa& p) const + { + const vint4 i = floori((vfloat4(p)-ofs)*scale); + return clamp(i,vint4(0),vint4(BINS-1)); + } + + __forceinline std::pair bin(const BBox3fa& b) const + { +#if defined(__AVX__) + const vfloat8 ofs8(ofs); + const vfloat8 scale8(scale); + const vint8 lu = floori((vfloat8::loadu(&b)-ofs8)*scale8); + const vint8 c_lu = clamp(lu,vint8(zero),vint8(BINS-1)); + return std::pair(extract4<0>(c_lu),extract4<1>(c_lu)); +#else + const vint4 lower = floori((vfloat4(b.lower)-ofs)*scale); + const vint4 upper = floori((vfloat4(b.upper)-ofs)*scale); + const vint4 c_lower = clamp(lower,vint4(0),vint4(BINS-1)); + const vint4 c_upper = clamp(upper,vint4(0),vint4(BINS-1)); + return std::pair(c_lower,c_upper); +#endif + } + + + /*! calculates left spatial position of bin */ + __forceinline float pos(const size_t bin, const size_t dim) const { + return madd(float(bin),inv_scale[dim],ofs[dim]); + } + + /*! calculates left spatial position of bin */ + template + __forceinline vfloat posN(const vfloat bin, const size_t dim) const { + return madd(bin,vfloat(inv_scale[dim]),vfloat(ofs[dim])); + } + + /*! 
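a scale of zero marks a dimension whose extent collapsed below the epsilon used when building the mapping;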
returns true if the mapping is invalid in some dimension */ + __forceinline bool invalid(const size_t dim) const { + return scale[dim] == 0.0f; + } + + public: + vfloat4 ofs,scale,inv_scale; //!< linear function that maps to bin ID + }; + + /*! stores all information required to perform some split */ + template + struct SpatialBinSplit + { + /*! construct an invalid split by default */ + __forceinline SpatialBinSplit() + : sah(inf), dim(-1), pos(0), left(-1), right(-1), factor(1.0f) {} + + /*! constructs specified split */ + __forceinline SpatialBinSplit(float sah, int dim, int pos, const SpatialBinMapping& mapping) + : sah(sah), dim(dim), pos(pos), left(-1), right(-1), factor(1.0f), mapping(mapping) {} + + /*! constructs specified split */ + __forceinline SpatialBinSplit(float sah, int dim, int pos, int left, int right, float factor, const SpatialBinMapping& mapping) + : sah(sah), dim(dim), pos(pos), left(left), right(right), factor(factor), mapping(mapping) {} + + /*! tests if this split is valid */ + __forceinline bool valid() const { return dim != -1; } + + /*! calculates surface area heuristic for performing the split */ + __forceinline float splitSAH() const { return sah; } + + /*! stream output */ + friend embree_ostream operator<<(embree_ostream cout, const SpatialBinSplit& split) { + return cout << "SpatialBinSplit { sah = " << split.sah << ", dim = " << split.dim << ", pos = " << split.pos << ", left = " << split.left << ", right = " << split.right << ", factor = " << split.factor << "}"; + } + + public: + float sah; //!< SAH cost of the split + int dim; //!< split dimension + int pos; //!< split position + int left; //!< number of elements on the left side + int right; //!< number of elements on the right side + float factor; //!< factor splitting the extended range + SpatialBinMapping mapping; //!< mapping into bins + }; + + /*! stores all binning information */ + template + struct __aligned(64) SpatialBinInfo + { + SpatialBinInfo() { + } + + __forceinline SpatialBinInfo(EmptyTy) { + clear(); + } + + /*! clears the bin info */ + __forceinline void clear() + { + for (size_t i=0; i + __forceinline void bin(const SplitPrimitive& splitPrimitive, const PrimRef* prims, size_t N, const SpatialBinMapping& mapping) + { + for (size_t i=0; i> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS); + + if (unlikely(splits == 1)) + { + const vint4 bin = mapping.bin(center(prim.bounds())); + for (size_t dim=0; dim<3; dim++) + { + assert(bin[dim] >= (int)0 && bin[dim] < (int)BINS); + numBegin[bin[dim]][dim]++; + numEnd [bin[dim]][dim]++; + bounds [bin[dim]][dim].extend(prim.bounds()); + } + } + else + { + const vint4 bin0 = mapping.bin(prim.bounds().lower); + const vint4 bin1 = mapping.bin(prim.bounds().upper); + + for (size_t dim=0; dim<3; dim++) + { + size_t bin; + PrimRef rest = prim; + size_t l = bin0[dim]; + size_t r = bin1[dim]; + + // same bin optimization + if (likely(l == r)) + { + numBegin[l][dim]++; + numEnd [l][dim]++; + bounds [l][dim].extend(prim.bounds()); + continue; + } + + for (bin=(size_t)bin0[dim]; bin<(size_t)bin1[dim]; bin++) + { + const float pos = mapping.pos(bin+1,dim); + + PrimRef left,right; + splitPrimitive(rest,(int)dim,pos,left,right); + if (unlikely(left.bounds().empty())) l++; + bounds[bin][dim].extend(left.bounds()); + rest = right; + } + if (unlikely(rest.bounds().empty())) r--; + numBegin[l][dim]++; + numEnd [r][dim]++; + bounds [bin][dim].extend(rest.bounds()); + } + } + } + } + + /*! 
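convenience overload that forwards to the pointer/count variant above;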
bins a range of primitives inside an array */ + template + void bin(const SplitPrimitive& splitPrimitive, const PrimRef* prims, size_t begin, size_t end, const SpatialBinMapping& mapping) { + bin(splitPrimitive,prims+begin,end-begin,mapping); + } + + /*! bins an array of primitives */ + template + __forceinline void bin2(const PrimitiveSplitterFactory& splitterFactory, const PrimRef* source, size_t begin, size_t end, const SpatialBinMapping& mapping) + { + for (size_t i=begin; i& mapping) + { + for (size_t i=begin; i best(const SpatialBinMapping& mapping, const size_t blocks_shift) const + { + /* sweep from right to left and compute parallel prefix of merged bounds */ + vfloat4 rAreas[BINS]; + vuint4 rCounts[BINS]; + vuint4 count = 0; BBox3fa bx = empty; BBox3fa by = empty; BBox3fa bz = empty; + for (size_t i=BINS-1; i>0; i--) + { + count += numEnd[i]; + rCounts[i] = count; + bx.extend(bounds[i][0]); rAreas[i][0] = halfArea(bx); + by.extend(bounds[i][1]); rAreas[i][1] = halfArea(by); + bz.extend(bounds[i][2]); rAreas[i][2] = halfArea(bz); + rAreas[i][3] = 0.0f; + } + + /* sweep from left to right and compute SAH */ + vuint4 blocks_add = (1 << blocks_shift)-1; + vuint4 ii = 1; vfloat4 vbestSAH = pos_inf; vuint4 vbestPos = 0; vuint4 vbestlCount = 0; vuint4 vbestrCount = 0; + count = 0; bx = empty; by = empty; bz = empty; + for (size_t i=1; i> (unsigned int)(blocks_shift); + const vuint4 rCount = (rCounts[i]+blocks_add) >> (unsigned int)(blocks_shift); + const vfloat4 sah = madd(lArea,vfloat4(lCount),rArea*vfloat4(rCount)); + // const vfloat4 sah = madd(lArea,vfloat4(vint4(lCount)),rArea*vfloat4(vint4(rCount))); + const vbool4 mask = sah < vbestSAH; + vbestPos = select(mask,ii ,vbestPos); + vbestSAH = select(mask,sah,vbestSAH); + vbestlCount = select(mask,count,vbestlCount); + vbestrCount = select(mask,rCounts[i],vbestrCount); + } + + /* find best dimension */ + float bestSAH = inf; + int bestDim = -1; + int bestPos = 0; + unsigned int bestlCount = 0; + unsigned int bestrCount = 0; + for (int dim=0; dim<3; dim++) + { + /* ignore zero sized dimensions */ + if (unlikely(mapping.invalid(dim))) + continue; + + /* test if this is a better dimension */ + if (vbestSAH[dim] < bestSAH && vbestPos[dim] != 0) { + bestDim = dim; + bestPos = vbestPos[dim]; + bestSAH = vbestSAH[dim]; + bestlCount = vbestlCount[dim]; + bestrCount = vbestrCount[dim]; + } + } + assert(bestSAH >= 0.0f); + + /* return invalid split if no split found */ + if (bestDim == -1) + return SpatialBinSplit(inf,-1,0,mapping); + + /* return best found split */ + return SpatialBinSplit(bestSAH,bestDim,bestPos,bestlCount,bestrCount,1.0f,mapping); + } + + private: + BBox3fa bounds[BINS][3]; //!< geometry bounds for each bin in each dimension + vuint4 numBegin[BINS]; //!< number of primitives starting in bin + vuint4 numEnd[BINS]; //!< number of primitives ending in bin + }; + } +} + diff --git a/thirdparty/embree/kernels/builders/heuristic_spatial_array.h b/thirdparty/embree/kernels/builders/heuristic_spatial_array.h new file mode 100644 index 000000000000..911dcf950cad --- /dev/null +++ b/thirdparty/embree/kernels/builders/heuristic_spatial_array.h @@ -0,0 +1,552 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "heuristic_binning.h" +#include "heuristic_spatial.h" + +namespace embree +{ + namespace isa + { +#if 0 +#define SPATIAL_ASPLIT_OVERLAP_THRESHOLD 0.2f +#define SPATIAL_ASPLIT_SAH_THRESHOLD 0.95f +#define SPATIAL_ASPLIT_AREA_THRESHOLD 0.0f +#else +#define 
SPATIAL_ASPLIT_OVERLAP_THRESHOLD 0.1f +#define SPATIAL_ASPLIT_SAH_THRESHOLD 0.99f +#define SPATIAL_ASPLIT_AREA_THRESHOLD 0.000005f +#endif + + struct PrimInfoExtRange : public CentGeomBBox3fa, public extended_range + { + __forceinline PrimInfoExtRange() { + } + + __forceinline PrimInfoExtRange(EmptyTy) + : CentGeomBBox3fa(EmptyTy()), extended_range(0,0,0) {} + + __forceinline PrimInfoExtRange(size_t begin, size_t end, size_t ext_end, const CentGeomBBox3fa& centGeomBounds) + : CentGeomBBox3fa(centGeomBounds), extended_range(begin,end,ext_end) {} + + __forceinline float leafSAH() const { + return expectedApproxHalfArea(geomBounds)*float(size()); + } + + __forceinline float leafSAH(size_t block_shift) const { + return expectedApproxHalfArea(geomBounds)*float((size()+(size_t(1)<> block_shift); + } + }; + + template + struct Split2 + { + __forceinline Split2 () {} + + __forceinline Split2 (const Split2& other) + { + spatial = other.spatial; + sah = other.sah; + if (spatial) spatialSplit() = other.spatialSplit(); + else objectSplit() = other.objectSplit(); + } + + __forceinline Split2& operator= (const Split2& other) + { + spatial = other.spatial; + sah = other.sah; + if (spatial) spatialSplit() = other.spatialSplit(); + else objectSplit() = other.objectSplit(); + return *this; + } + + __forceinline ObjectSplit& objectSplit() { return *( ObjectSplit*)data; } + __forceinline const ObjectSplit& objectSplit() const { return *(const ObjectSplit*)data; } + + __forceinline SpatialSplit& spatialSplit() { return *( SpatialSplit*)data; } + __forceinline const SpatialSplit& spatialSplit() const { return *(const SpatialSplit*)data; } + + __forceinline Split2 (const ObjectSplit& objectSplit, float sah) + : spatial(false), sah(sah) + { + new (data) ObjectSplit(objectSplit); + } + + __forceinline Split2 (const SpatialSplit& spatialSplit, float sah) + : spatial(true), sah(sah) + { + new (data) SpatialSplit(spatialSplit); + } + + __forceinline float splitSAH() const { + return sah; + } + + __forceinline bool valid() const { + return sah < float(inf); + } + + public: + __aligned(64) char data[sizeof(ObjectSplit) > sizeof(SpatialSplit) ? sizeof(ObjectSplit) : sizeof(SpatialSplit)]; + bool spatial; + float sah; + }; + + /*! Performs standard object binning */ + template + struct HeuristicArraySpatialSAH + { + typedef BinSplit ObjectSplit; + typedef BinInfoT ObjectBinner; + + typedef SpatialBinSplit SpatialSplit; + typedef SpatialBinInfo SpatialBinner; + + //typedef extended_range Set; + typedef Split2 Split; + +#if defined(__AVX512ER__) // KNL + static const size_t PARALLEL_THRESHOLD = 3*1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 768; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; +#else + static const size_t PARALLEL_THRESHOLD = 3*1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 1024; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; +#endif + + static const size_t MOVE_STEP_SIZE = 64; + static const size_t CREATE_SPLITS_STEP_SIZE = 64; + + __forceinline HeuristicArraySpatialSAH () + : prims0(nullptr) {} + + /*! remember prim array */ + __forceinline HeuristicArraySpatialSAH (const PrimitiveSplitterFactory& splitterFactory, PrimRef* prims0, const CentGeomBBox3fa& root_info) + : prims0(prims0), splitterFactory(splitterFactory), root_info(root_info) {} + + + /*! 
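distributes the parent's unused (extended) space between the two children in proportion to their primitive weights; used to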
compute extended ranges */ + __noinline void setExtentedRanges(const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset, const size_t lweight, const size_t rweight) + { + assert(set.ext_range_size() > 0); + const float left_factor = (float)lweight / (lweight + rweight); + const size_t ext_range_size = set.ext_range_size(); + const size_t left_ext_range_size = min((size_t)(floorf(left_factor * ext_range_size)),ext_range_size); + const size_t right_ext_range_size = ext_range_size - left_ext_range_size; + lset.set_ext_range(lset.end() + left_ext_range_size); + rset.set_ext_range(rset.end() + right_ext_range_size); + } + + /*! move ranges */ + __noinline void moveExtentedRange(const PrimInfoExtRange& set, const PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t left_ext_range_size = lset.ext_range_size(); + const size_t right_size = rset.size(); + + /* has the left child an extended range? */ + if (left_ext_range_size > 0) + { + /* left extended range smaller than right range ? */ + if (left_ext_range_size < right_size) + { + /* only move a small part of the beginning of the right range to the end */ + parallel_for( rset.begin(), rset.begin()+left_ext_range_size, MOVE_STEP_SIZE, [&](const range& r) { + for (size_t i=r.begin(); i& r) { + for (size_t i=r.begin(); i= SPATIAL_ASPLIT_AREA_THRESHOLD*safeArea(root_info.geomBounds) && + safeArea(overlap) >= SPATIAL_ASPLIT_OVERLAP_THRESHOLD*safeArea(set.geomBounds)) + { + const SpatialSplit spatial_split = spatial_find(set, logBlockSize); + const float spatial_split_sah = spatial_split.splitSAH(); + + /* valid spatial split, better SAH and number of splits do not exceed extended range */ + if (spatial_split_sah < SPATIAL_ASPLIT_SAH_THRESHOLD*object_split_sah && + spatial_split.left + spatial_split.right - set.size() <= set.ext_range_size()) + { + return Split(spatial_split,spatial_split_sah); + } + } + } + + return Split(object_split,object_split_sah); + } + + /*! finds the best object split */ + __forceinline const ObjectSplit object_find(const PrimInfoExtRange& set, const size_t logBlockSize, SplitInfo &info) + { + if (set.size() < PARALLEL_THRESHOLD) return sequential_object_find(set,logBlockSize,info); + else return parallel_object_find (set,logBlockSize,info); + } + + /*! finds the best object split */ + __noinline const ObjectSplit sequential_object_find(const PrimInfoExtRange& set, const size_t logBlockSize, SplitInfo &info) + { + ObjectBinner binner(empty); + const BinMapping mapping(set); + binner.bin(prims0,set.begin(),set.end(),mapping); + ObjectSplit s = binner.best(mapping,logBlockSize); + binner.getSplitInfo(mapping, s, info); + return s; + } + + /*! finds the best split */ + __noinline const ObjectSplit parallel_object_find(const PrimInfoExtRange& set, const size_t logBlockSize, SplitInfo &info) + { + ObjectBinner binner(empty); + const BinMapping mapping(set); + const BinMapping& _mapping = mapping; // CLANG 3.4 parser bug workaround + binner = parallel_reduce(set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,binner, + [&] (const range& r) -> ObjectBinner { ObjectBinner binner(empty); binner.bin(prims0+r.begin(),r.size(),_mapping); return binner; }, + [&] (const ObjectBinner& b0, const ObjectBinner& b1) -> ObjectBinner { ObjectBinner r = b0; r.merge(b1,_mapping.size()); return r; }); + ObjectSplit s = binner.best(mapping,logBlockSize); + binner.getSplitInfo(mapping, s, info); + return s; + } + + /*! 
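dispatches to the sequential or parallel spatial binner depending on the size of the range and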
finds the best spatial split */ + __forceinline const SpatialSplit spatial_find(const PrimInfoExtRange& set, const size_t logBlockSize) + { + if (set.size() < PARALLEL_THRESHOLD) return sequential_spatial_find(set, logBlockSize); + else return parallel_spatial_find (set, logBlockSize); + } + + /*! finds the best spatial split */ + __noinline const SpatialSplit sequential_spatial_find(const PrimInfoExtRange& set, const size_t logBlockSize) + { + SpatialBinner binner(empty); + const SpatialBinMapping mapping(set); + binner.bin2(splitterFactory,prims0,set.begin(),set.end(),mapping); + /* todo: best spatial split not exeeding the extended range does not provide any benefit ?*/ + return binner.best(mapping,logBlockSize); //,set.ext_size()); + } + + __noinline const SpatialSplit parallel_spatial_find(const PrimInfoExtRange& set, const size_t logBlockSize) + { + SpatialBinner binner(empty); + const SpatialBinMapping mapping(set); + const SpatialBinMapping& _mapping = mapping; // CLANG 3.4 parser bug workaround + binner = parallel_reduce(set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,binner, + [&] (const range& r) -> SpatialBinner { + SpatialBinner binner(empty); + binner.bin2(splitterFactory,prims0,r.begin(),r.end(),_mapping); + return binner; }, + [&] (const SpatialBinner& b0, const SpatialBinner& b1) -> SpatialBinner { return SpatialBinner::reduce(b0,b1); }); + /* todo: best spatial split not exeeding the extended range does not provide any benefit ?*/ + return binner.best(mapping,logBlockSize); //,set.ext_size()); + } + + + /*! subdivides primitives based on a spatial split */ + __noinline void create_spatial_splits(PrimInfoExtRange& set, const SpatialSplit& split, const SpatialBinMapping &mapping) + { + assert(set.has_ext_range()); + const size_t max_ext_range_size = set.ext_range_size(); + const size_t ext_range_start = set.end(); + + /* atomic counter for number of primref splits */ + std::atomic ext_elements; + ext_elements.store(0); + + const float fpos = split.mapping.pos(split.pos,split.dim); + + const unsigned int mask = 0xFFFFFFFF >> RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS; + + parallel_for( set.begin(), set.end(), CREATE_SPLITS_STEP_SIZE, [&](const range& r) { + for (size_t i=r.begin();i> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS); + + if (likely(splits <= 1)) continue; /* todo: does this ever happen ? 
*/ + + //int bin0 = split.mapping.bin(prims0[i].lower)[split.dim]; + //int bin1 = split.mapping.bin(prims0[i].upper)[split.dim]; + //if (unlikely(bin0 < split.pos && bin1 >= split.pos)) + if (unlikely(prims0[i].lower[split.dim] < fpos && prims0[i].upper[split.dim] > fpos)) + { + assert(splits > 1); + + PrimRef left,right; + const auto splitter = splitterFactory(prims0[i]); + splitter(prims0[i],split.dim,fpos,left,right); + + // no empty splits + if (unlikely(left.bounds().empty() || right.bounds().empty())) continue; + + left.lower.u = (left.lower.u & mask) | ((splits-1) << (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); + right.lower.u = (right.lower.u & mask) | ((splits-1) << (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); + + const size_t ID = ext_elements.fetch_add(1); + + /* break if the number of subdivided elements are greater than the maximum allowed size */ + if (unlikely(ID >= max_ext_range_size)) + break; + + /* only write within the correct bounds */ + assert(ID < max_ext_range_size); + prims0[i] = left; + prims0[ext_range_start+ID] = right; + } + } + }); + + const size_t numExtElements = min(max_ext_range_size,ext_elements.load()); + assert(set.end()+numExtElements<=set.ext_end()); + set._end += numExtElements; + } + + /*! array partitioning */ + void split(const Split& split, const PrimInfoExtRange& set_i, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + PrimInfoExtRange set = set_i; + + /* valid split */ + if (unlikely(!split.valid())) { + deterministic_order(set); + return splitFallback(set,lset,rset); + } + + std::pair ext_weights(0,0); + + if (unlikely(split.spatial)) + { + create_spatial_splits(set,split.spatialSplit(), split.spatialSplit().mapping); + + /* spatial split */ + if (likely(set.size() < PARALLEL_THRESHOLD)) + ext_weights = sequential_spatial_split(split.spatialSplit(),set,lset,rset); + else + ext_weights = parallel_spatial_split(split.spatialSplit(),set,lset,rset); + } + else + { + /* object split */ + if (likely(set.size() < PARALLEL_THRESHOLD)) + ext_weights = sequential_object_split(split.objectSplit(),set,lset,rset); + else + ext_weights = parallel_object_split(split.objectSplit(),set,lset,rset); + } + + /* if we have an extended range, set extended child ranges and move right split range */ + if (unlikely(set.has_ext_range())) + { + setExtentedRanges(set,lset,rset,ext_weights.first,ext_weights.second); + moveExtentedRange(set,lset,rset); + } + } + + /*! 
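single-threaded variant of the object-split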
array partitioning */ + std::pair sequential_object_split(const ObjectSplit& split, const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfo local_left(empty); + PrimInfo local_right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + const typename ObjectBinner::vint vSplitPos(splitPos); + const typename ObjectBinner::vbool vSplitMask(splitDimMask); + size_t center = serial_partitioning(prims0, + begin,end,local_left,local_right, + [&] (const PrimRef& ref) { + return split.mapping.bin_unsafe(ref,vSplitPos,vSplitMask); + }, + [] (PrimInfo& pinfo,const PrimRef& ref) { pinfo.add_center2(ref,ref.lower.u >> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); }); + const size_t left_weight = local_left.end; + const size_t right_weight = local_right.end; + + new (&lset) PrimInfoExtRange(begin,center,center,local_left); + new (&rset) PrimInfoExtRange(center,end,end,local_right); + + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + return std::pair(left_weight,right_weight); + } + + + /*! array partitioning */ + __noinline std::pair sequential_spatial_split(const SpatialSplit& split, const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfo local_left(empty); + PrimInfo local_right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + /* init spatial mapping */ + const SpatialBinMapping &mapping = split.mapping; + const vint4 vSplitPos(splitPos); + const vbool4 vSplitMask( (int)splitDimMask ); + + size_t center = serial_partitioning(prims0, + begin,end,local_left,local_right, + [&] (const PrimRef& ref) { + const Vec3fa c = ref.bounds().center(); + return any(((vint4)mapping.bin(c) < vSplitPos) & vSplitMask); + }, + [] (PrimInfo& pinfo,const PrimRef& ref) { pinfo.add_center2(ref,ref.lower.u >> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); }); + + const size_t left_weight = local_left.end; + const size_t right_weight = local_right.end; + + new (&lset) PrimInfoExtRange(begin,center,center,local_left); + new (&rset) PrimInfoExtRange(center,end,end,local_right); + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + return std::pair(left_weight,right_weight); + } + + + + /*! 
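parallel variant of the object-split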
array partitioning */ + __noinline std::pair parallel_object_split(const ObjectSplit& split, const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfo left(empty); + PrimInfo right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + const typename ObjectBinner::vint vSplitPos(splitPos); + const typename ObjectBinner::vbool vSplitMask(splitDimMask); + auto isLeft = [&] (const PrimRef &ref) { return split.mapping.bin_unsafe(ref,vSplitPos,vSplitMask); }; + + const size_t center = parallel_partitioning( + prims0,begin,end,EmptyTy(),left,right,isLeft, + [] (PrimInfo &pinfo,const PrimRef &ref) { pinfo.add_center2(ref,ref.lower.u >> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); }, + [] (PrimInfo &pinfo0,const PrimInfo &pinfo1) { pinfo0.merge(pinfo1); }, + PARALLEL_PARTITION_BLOCK_SIZE); + + const size_t left_weight = left.end; + const size_t right_weight = right.end; + + left.begin = begin; left.end = center; + right.begin = center; right.end = end; + + new (&lset) PrimInfoExtRange(begin,center,center,left); + new (&rset) PrimInfoExtRange(center,end,end,right); + + assert(area(left.geomBounds) >= 0.0f); + assert(area(right.geomBounds) >= 0.0f); + return std::pair(left_weight,right_weight); + } + + /*! array partitioning */ + __noinline std::pair parallel_spatial_split(const SpatialSplit& split, const PrimInfoExtRange& set, PrimInfoExtRange& lset, PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + PrimInfo left(empty); + PrimInfo right(empty); + const unsigned int splitPos = split.pos; + const unsigned int splitDim = split.dim; + const unsigned int splitDimMask = (unsigned int)1 << splitDim; + + /* init spatial mapping */ + const SpatialBinMapping& mapping = split.mapping; + const vint4 vSplitPos(splitPos); + const vbool4 vSplitMask( (int)splitDimMask ); + + auto isLeft = [&] (const PrimRef &ref) { + const Vec3fa c = ref.bounds().center(); + return any(((vint4)mapping.bin(c) < vSplitPos) & vSplitMask); }; + + const size_t center = parallel_partitioning( + prims0,begin,end,EmptyTy(),left,right,isLeft, + [] (PrimInfo &pinfo,const PrimRef &ref) { pinfo.add_center2(ref,ref.lower.u >> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); }, + [] (PrimInfo &pinfo0,const PrimInfo &pinfo1) { pinfo0.merge(pinfo1); }, + PARALLEL_PARTITION_BLOCK_SIZE); + + const size_t left_weight = left.end; + const size_t right_weight = right.end; + + left.begin = begin; left.end = center; + right.begin = center; right.end = end; + + new (&lset) PrimInfoExtRange(begin,center,center,left); + new (&rset) PrimInfoExtRange(center,end,end,right); + + assert(area(left.geomBounds) >= 0.0f); + assert(area(right.geomBounds) >= 0.0f); + return std::pair(left_weight,right_weight); + } + + void deterministic_order(const PrimInfoExtRange& set) + { + /* required as parallel partition destroys original primitive order */ + std::sort(&prims0[set.begin()],&prims0[set.end()]); + } + + void splitFallback(const PrimInfoExtRange& set, + PrimInfoExtRange& lset, + PrimInfoExtRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + const size_t center = (begin + end)/2; + + PrimInfo left(empty); + for (size_t i=begin; i> (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); + } + const size_t lweight = left.end; + + PrimInfo right(empty); + for (size_t i=center; i> 
(32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)); + } + const size_t rweight = right.end; + + new (&lset) PrimInfoExtRange(begin,center,center,left); + new (&rset) PrimInfoExtRange(center,end,end,right); + + /* if we have an extended range */ + if (set.has_ext_range()) { + setExtentedRanges(set,lset,rset,lweight,rweight); + moveExtentedRange(set,lset,rset); + } + } + + private: + PrimRef* const prims0; + const PrimitiveSplitterFactory& splitterFactory; + const CentGeomBBox3fa& root_info; + }; + } +} diff --git a/thirdparty/embree/kernels/builders/heuristic_strand_array.h b/thirdparty/embree/kernels/builders/heuristic_strand_array.h new file mode 100644 index 000000000000..ede0d04c78ab --- /dev/null +++ b/thirdparty/embree/kernels/builders/heuristic_strand_array.h @@ -0,0 +1,188 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "priminfo.h" +#include "../../common/algorithms/parallel_reduce.h" +#include "../../common/algorithms/parallel_partition.h" + +namespace embree +{ + namespace isa + { + /*! Performs standard object binning */ + struct HeuristicStrandSplit + { + typedef range Set; + + static const size_t PARALLEL_THRESHOLD = 10000; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 4096; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 64; + + /*! stores all information to perform some split */ + struct Split + { + /*! construct an invalid split by default */ + __forceinline Split() + : sah(inf), axis0(zero), axis1(zero) {} + + /*! constructs specified split */ + __forceinline Split(const float sah, const Vec3fa& axis0, const Vec3fa& axis1) + : sah(sah), axis0(axis0), axis1(axis1) {} + + /*! calculates standard surface area heuristic for the split */ + __forceinline float splitSAH() const { return sah; } + + /*! test if this split is valid */ + __forceinline bool valid() const { return sah != float(inf); } + + public: + float sah; //!< SAH cost of the split + Vec3fa axis0, axis1; //!< axis the two strands are aligned into + }; + + __forceinline HeuristicStrandSplit () // FIXME: required? + : scene(nullptr), prims(nullptr) {} + + /*! remember prim array */ + __forceinline HeuristicStrandSplit (Scene* scene, PrimRef* prims) + : scene(scene), prims(prims) {} + + __forceinline const Vec3fa direction(const PrimRef& prim) { + return scene->get(prim.geomID())->computeDirection(prim.primID()); + } + + __forceinline const BBox3fa bounds(const PrimRef& prim) { + return scene->get(prim.geomID())->vbounds(prim.primID()); + } + + __forceinline const BBox3fa bounds(const LinearSpace3fa& space, const PrimRef& prim) { + return scene->get(prim.geomID())->vbounds(space,prim.primID()); + } + + /*! finds the best split */ + const Split find(const range& set, size_t logBlockSize) + { + Vec3fa axis0(0,0,1); + uint64_t bestGeomPrimID = -1; + + /* curve with minimum ID determines first axis */ + for (size_t i=set.begin(); i= bestGeomPrimID) continue; + const Vec3fa axis = direction(prims[i]); + if (sqr_length(axis) > 1E-18f) { + axis0 = normalize(axis); + bestGeomPrimID = geomprimID; + } + } + + /* find 2nd axis that is most misaligned with first axis and has minimum ID */ + float bestCos = 1.0f; + Vec3fa axis1 = axis0; + bestGeomPrimID = -1; + for (size_t i=set.begin(); i cos1) { lnum++; lbounds.extend(bounds(space0,prim)); } + else { rnum++; rbounds.extend(bounds(space1,prim)); } + } + + /*! return an invalid split if we do not partition */ + if (lnum == 0 || rnum == 0) + return Split(inf,axis0,axis1); + + /*! 
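both strand groups received primitives, so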
calculate sah for the split */ + const size_t lblocks = (lnum+(1ull<> logBlockSize; + const size_t rblocks = (rnum+(1ull<> logBlockSize; + const float sah = madd(float(lblocks),halfArea(lbounds),float(rblocks)*halfArea(rbounds)); + return Split(sah,axis0,axis1); + } + + /*! array partitioning */ + void split(const Split& split, const PrimInfoRange& set, PrimInfoRange& lset, PrimInfoRange& rset) + { + if (!split.valid()) { + deterministic_order(set); + return splitFallback(set,lset,rset); + } + + const size_t begin = set.begin(); + const size_t end = set.end(); + CentGeomBBox3fa local_left(empty); + CentGeomBBox3fa local_right(empty); + + auto primOnLeftSide = [&] (const PrimRef& prim) -> bool { + const Vec3fa axisi = normalize(direction(prim)); + const float cos0 = abs(dot(axisi,split.axis0)); + const float cos1 = abs(dot(axisi,split.axis1)); + return cos0 > cos1; + }; + + auto mergePrimBounds = [this] (CentGeomBBox3fa& pinfo,const PrimRef& ref) { + pinfo.extend(bounds(ref)); + }; + + size_t center = serial_partitioning(prims,begin,end,local_left,local_right,primOnLeftSide,mergePrimBounds); + + new (&lset) PrimInfoRange(begin,center,local_left); + new (&rset) PrimInfoRange(center,end,local_right); + assert(area(lset.geomBounds) >= 0.0f); + assert(area(rset.geomBounds) >= 0.0f); + } + + void deterministic_order(const Set& set) + { + /* required as parallel partition destroys original primitive order */ + std::sort(&prims[set.begin()],&prims[set.end()]); + } + + void splitFallback(const Set& set, PrimInfoRange& lset, PrimInfoRange& rset) + { + const size_t begin = set.begin(); + const size_t end = set.end(); + const size_t center = (begin + end)/2; + + CentGeomBBox3fa left(empty); + for (size_t i=begin; i + struct HeuristicMBlurTemporalSplit + { + typedef BinSplit Split; + typedef mvector* PrimRefVector; + typedef typename PrimRefMB::BBox BBox; + + static const size_t PARALLEL_THRESHOLD = 3 * 1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 1024; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; + + HeuristicMBlurTemporalSplit (MemoryMonitorInterface* device, const RecalculatePrimRef& recalculatePrimRef) + : device(device), recalculatePrimRef(recalculatePrimRef) {} + + struct TemporalBinInfo + { + __forceinline TemporalBinInfo () { + } + + __forceinline TemporalBinInfo (EmptyTy) + { + for (size_t i=0; i= time_range.upper) continue; + const BBox1f dt0(time_range.lower,center_time); + const BBox1f dt1(center_time,time_range.upper); + + /* find linear bounds for both time segments */ + for (size_t i=begin; i& r) -> TemporalBinInfo { + TemporalBinInfo binner(empty); binner.bin(prims, r.begin(), r.end(), time_range, set, recalculatePrimRef); return binner; + }; + *this = parallel_reduce(begin,end,blockSize,TemporalBinInfo(empty),bin,merge2); + } + } + + /*! 
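reduction step for the parallel temporal binning;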
merges in other binning information */ + __forceinline void merge (const TemporalBinInfo& other) + { + for (size_t i=0; i= time_range.upper) continue; + const BBox1f dt0(time_range.lower,center_time); + const BBox1f dt1(center_time,time_range.upper); + + /* calculate sah */ + const size_t lCount = (count0[b]+(size_t(1) << logBlockSize)-1) >> int(logBlockSize); + const size_t rCount = (count1[b]+(size_t(1) << logBlockSize)-1) >> int(logBlockSize); + float sah0 = expectedApproxHalfArea(bounds0[b])*float(lCount)*dt0.size(); + float sah1 = expectedApproxHalfArea(bounds1[b])*float(rCount)*dt1.size(); + if (unlikely(lCount == 0)) sah0 = 0.0f; // happens for initial splits when objects not alive over entire shutter time + if (unlikely(rCount == 0)) sah1 = 0.0f; + const float sah = sah0+sah1; + if (sah < bestSAH) { + bestSAH = sah; + bestPos = center_time; + } + } + return Split(bestSAH*MBLUR_TIME_SPLIT_THRESHOLD,(unsigned)Split::SPLIT_TEMPORAL,0,bestPos); + } + + public: + size_t count0[BINS-1]; + size_t count1[BINS-1]; + BBox bounds0[BINS-1]; + BBox bounds1[BINS-1]; + }; + + /*! finds the best split */ + const Split find(const SetMB& set, const size_t logBlockSize) + { + assert(set.size() > 0); + TemporalBinInfo binner(empty); + binner.bin_parallel(set.prims->data(),set.begin(),set.end(),PARALLEL_FIND_BLOCK_SIZE,PARALLEL_THRESHOLD,set.time_range,set,recalculatePrimRef); + Split tsplit = binner.best((int)logBlockSize,set.time_range,set); + if (!tsplit.valid()) tsplit.data = Split::SPLIT_FALLBACK; // use fallback split + return tsplit; + } + + __forceinline std::unique_ptr> split(const Split& tsplit, const SetMB& set, SetMB& lset, SetMB& rset) + { + assert(tsplit.sah != float(inf)); + assert(tsplit.fpos > set.time_range.lower); + assert(tsplit.fpos < set.time_range.upper); + + float center_time = tsplit.fpos; + const BBox1f time_range0(set.time_range.lower,center_time); + const BBox1f time_range1(center_time,set.time_range.upper); + mvector& prims = *set.prims; + + /* calculate primrefs for first time range */ + std::unique_ptr> new_vector(new mvector(device, set.size())); + PrimRefVector lprims = new_vector.get(); + + auto reduction_func0 = [&] (const range& r) { + PrimInfoMB pinfo = empty; + for (size_t i=r.begin(); idata(), size_t(0), set.size(), size_t(1024), + [&](const PrimRefMB& prim) { return prim.time_range_overlap(time_range0); }); + + lset = SetMB(linfo,lprims,time_range0); + + /* calculate primrefs for second time range */ + auto reduction_func1 = [&] (const range& r) { + PrimInfoMB pinfo = empty; + for (size_t i=r.begin(); i(set.begin(), set.begin() + rinfo.size()); + + /* primrefs for second time range are in prims[set.begin() .. 
set.end()) */ + /* some primitives may need to be filtered out */ + if (rinfo.size() != set.size()) + rinfo.object_range._end = parallel_filter(prims.data(), set.begin(), set.end(), size_t(1024), + [&](const PrimRefMB& prim) { return prim.time_range_overlap(time_range1); }); + + rset = SetMB(rinfo,&prims,time_range1); + + return new_vector; + } + + private: + MemoryMonitorInterface* device; // device to report memory usage to + const RecalculatePrimRef recalculatePrimRef; + }; + } +} diff --git a/thirdparty/embree/kernels/builders/priminfo.h b/thirdparty/embree/kernels/builders/priminfo.h new file mode 100644 index 000000000000..06c1388742a3 --- /dev/null +++ b/thirdparty/embree/kernels/builders/priminfo.h @@ -0,0 +1,362 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "../common/primref.h" +#include "../common/primref_mb.h" + +namespace embree +{ + // FIXME: maybe there's a better place for this util fct + __forceinline float areaProjectedTriangle(const Vec3fa& v0, const Vec3fa& v1, const Vec3fa& v2) + { + const Vec3fa e0 = v1-v0; + const Vec3fa e1 = v2-v0; + const Vec3fa d = cross(e0,e1); + return fabs(d.x) + fabs(d.y) + fabs(d.z); + } + + //namespace isa + //{ + template + class CentGeom + { + public: + __forceinline CentGeom () {} + + __forceinline CentGeom (EmptyTy) + : geomBounds(empty), centBounds(empty) {} + + __forceinline CentGeom (const BBox& geomBounds, const BBox3fa& centBounds) + : geomBounds(geomBounds), centBounds(centBounds) {} + + template + __forceinline void extend_primref(const PrimRef& prim) + { + BBox bounds; Vec3fa center; + prim.binBoundsAndCenter(bounds,center); + geomBounds.extend(bounds); + centBounds.extend(center); + } + + template + __forceinline void extend_center2(const PrimRef& prim) + { + BBox3fa bounds = prim.bounds(); + geomBounds.extend(bounds); + centBounds.extend(bounds.center2()); + } + + __forceinline void extend(const BBox& geomBounds_) { + geomBounds.extend(geomBounds_); + centBounds.extend(center2(geomBounds_)); + } + + __forceinline void merge(const CentGeom& other) + { + geomBounds.extend(other.geomBounds); + centBounds.extend(other.centBounds); + } + + static __forceinline const CentGeom merge2(const CentGeom& a, const CentGeom& b) { + CentGeom r = a; r.merge(b); return r; + } + + public: + BBox geomBounds; //!< geometry bounds of primitives + BBox3fa centBounds; //!< centroid bounds of primitives + }; + + typedef CentGeom CentGeomBBox3fa; + + /*! 
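extends the centroid/geometry bounds with a primitive index range;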
stores bounding information for a set of primitives */ + template + class PrimInfoT : public CentGeom + { + public: + using CentGeom::geomBounds; + using CentGeom::centBounds; + + __forceinline PrimInfoT () {} + + __forceinline PrimInfoT (EmptyTy) + : CentGeom(empty), begin(0), end(0) {} + + __forceinline PrimInfoT (size_t begin, size_t end, const CentGeomBBox3fa& centGeomBounds) + : CentGeom(centGeomBounds), begin(begin), end(end) {} + + template + __forceinline void add_primref(const PrimRef& prim) + { + CentGeom::extend_primref(prim); + end++; + } + + template + __forceinline void add_center2(const PrimRef& prim) { + CentGeom::extend_center2(prim); + end++; + } + + template + __forceinline void add_center2(const PrimRef& prim, const size_t i) { + CentGeom::extend_center2(prim); + end+=i; + } + + /*__forceinline void add(const BBox& geomBounds_) { + CentGeom::extend(geomBounds_); + end++; + } + + __forceinline void add(const BBox& geomBounds_, const size_t i) { + CentGeom::extend(geomBounds_); + end+=i; + }*/ + + __forceinline void merge(const PrimInfoT& other) + { + CentGeom::merge(other); + begin += other.begin; + end += other.end; + } + + static __forceinline const PrimInfoT merge(const PrimInfoT& a, const PrimInfoT& b) { + PrimInfoT r = a; r.merge(b); return r; + } + + /*! returns the number of primitives */ + __forceinline size_t size() const { + return end-begin; + } + + __forceinline float halfArea() { + return expectedApproxHalfArea(geomBounds); + } + + __forceinline float leafSAH() const { + return expectedApproxHalfArea(geomBounds)*float(size()); + //return halfArea(geomBounds)*blocks(num); + } + + __forceinline float leafSAH(size_t block_shift) const { + return expectedApproxHalfArea(geomBounds)*float((size()+(size_t(1)<> block_shift); + //return halfArea(geomBounds)*float((num+3) >> 2); + //return halfArea(geomBounds)*blocks(num); + } + + /*! stream output */ + friend embree_ostream operator<<(embree_ostream cout, const PrimInfoT& pinfo) { + return cout << "PrimInfo { begin = " << pinfo.begin << ", end = " << pinfo.end << ", geomBounds = " << pinfo.geomBounds << ", centBounds = " << pinfo.centBounds << "}"; + } + + public: + size_t begin,end; //!< number of primitives + }; + + typedef PrimInfoT PrimInfo; + //typedef PrimInfoT PrimInfoMB; + + /*! 
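motion-blur variant that additionally tracks the merged time range and time-segment counts of the added primitives;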
stores bounding information for a set of primitives */ + template + class PrimInfoMBT : public CentGeom + { + public: + using CentGeom::geomBounds; + using CentGeom::centBounds; + + __forceinline PrimInfoMBT () { + } + + __forceinline PrimInfoMBT (EmptyTy) + : CentGeom(empty), object_range(0,0), num_time_segments(0), max_num_time_segments(0), max_time_range(0.0f,1.0f), time_range(1.0f,0.0f) {} + + __forceinline PrimInfoMBT (size_t begin, size_t end) + : CentGeom(empty), object_range(begin,end), num_time_segments(0), max_num_time_segments(0), max_time_range(0.0f,1.0f), time_range(1.0f,0.0f) {} + + template + __forceinline void add_primref(const PrimRef& prim) + { + CentGeom::extend_primref(prim); + time_range.extend(prim.time_range); + object_range._end++; + num_time_segments += prim.size(); + if (max_num_time_segments < prim.totalTimeSegments()) { + max_num_time_segments = prim.totalTimeSegments(); + max_time_range = prim.time_range; + } + } + + __forceinline void merge(const PrimInfoMBT& other) + { + CentGeom::merge(other); + time_range.extend(other.time_range); + object_range._begin += other.object_range.begin(); + object_range._end += other.object_range.end(); + num_time_segments += other.num_time_segments; + if (max_num_time_segments < other.max_num_time_segments) { + max_num_time_segments = other.max_num_time_segments; + max_time_range = other.max_time_range; + } + } + + static __forceinline const PrimInfoMBT merge2(const PrimInfoMBT& a, const PrimInfoMBT& b) { + PrimInfoMBT r = a; r.merge(b); return r; + } + + __forceinline size_t begin() const { + return object_range.begin(); + } + + __forceinline size_t end() const { + return object_range.end(); + } + + /*! returns the number of primitives */ + __forceinline size_t size() const { + return object_range.size(); + } + + __forceinline float halfArea() const { + return time_range.size()*expectedApproxHalfArea(geomBounds); + } + + __forceinline float leafSAH() const { + return time_range.size()*expectedApproxHalfArea(geomBounds)*float(num_time_segments); + } + + __forceinline float leafSAH(size_t block_shift) const { + return time_range.size()*expectedApproxHalfArea(geomBounds)*float((num_time_segments+(size_t(1)<> block_shift); + } + + __forceinline float align_time(float ct) const + { + //return roundf(ct * float(numTimeSegments)) / float(numTimeSegments); + float t0 = (ct-max_time_range.lower)/max_time_range.size(); + float t1 = roundf(t0 * float(max_num_time_segments)) / float(max_num_time_segments); + return t1*max_time_range.size()+max_time_range.lower; + } + + /*! 
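debug helper;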
stream output */ + friend embree_ostream operator<<(embree_ostream cout, const PrimInfoMBT& pinfo) + { + return cout << "PrimInfo { " << + "object_range = " << pinfo.object_range << + ", time_range = " << pinfo.time_range << + ", time_segments = " << pinfo.num_time_segments << + ", geomBounds = " << pinfo.geomBounds << + ", centBounds = " << pinfo.centBounds << + "}"; + } + + public: + range object_range; //!< primitive range + size_t num_time_segments; //!< total number of time segments of all added primrefs + size_t max_num_time_segments; //!< maximum number of time segments of a primitive + BBox1f max_time_range; //!< time range of primitive with max_num_time_segments + BBox1f time_range; //!< merged time range of primitives when merging prims, or additionally clipped with build time range when used in SetMB + }; + + typedef PrimInfoMBT PrimInfoMB; + + struct SetMB : public PrimInfoMB + { + static const size_t PARALLEL_THRESHOLD = 3 * 1024; + static const size_t PARALLEL_FIND_BLOCK_SIZE = 1024; + static const size_t PARALLEL_PARTITION_BLOCK_SIZE = 128; + + typedef mvector* PrimRefVector; + + __forceinline SetMB() {} + + __forceinline SetMB(const PrimInfoMB& pinfo_i, PrimRefVector prims) + : PrimInfoMB(pinfo_i), prims(prims) {} + + __forceinline SetMB(const PrimInfoMB& pinfo_i, PrimRefVector prims, range object_range_in, BBox1f time_range_in) + : PrimInfoMB(pinfo_i), prims(prims) + { + object_range = object_range_in; + time_range = intersect(time_range,time_range_in); + } + + __forceinline SetMB(const PrimInfoMB& pinfo_i, PrimRefVector prims, BBox1f time_range_in) + : PrimInfoMB(pinfo_i), prims(prims) + { + time_range = intersect(time_range,time_range_in); + } + + void deterministic_order() const + { + /* required as parallel partition destroys original primitive order */ + PrimRefMB* prim = prims->data(); + std::sort(&prim[object_range.begin()],&prim[object_range.end()]); + } + + template + __forceinline LBBox3fa linearBounds(const RecalculatePrimRef& recalculatePrimRef) const + { + auto reduce = [&](const range& r) -> LBBox3fa + { + LBBox3fa cbounds(empty); + for (size_t j = r.begin(); j < r.end(); j++) + { + PrimRefMB& ref = (*prims)[j]; + const LBBox3fa bn = recalculatePrimRef.linearBounds(ref, time_range); + cbounds.extend(bn); + }; + return cbounds; + }; + + return parallel_reduce(object_range.begin(), object_range.end(), PARALLEL_FIND_BLOCK_SIZE, PARALLEL_THRESHOLD, LBBox3fa(empty), + reduce, + [&](const LBBox3fa& b0, const LBBox3fa& b1) -> LBBox3fa { return embree::merge(b0, b1); }); + } + + template + __forceinline LBBox3fa linearBounds(const RecalculatePrimRef& recalculatePrimRef, const LinearSpace3fa& space) const + { + auto reduce = [&](const range& r) -> LBBox3fa + { + LBBox3fa cbounds(empty); + for (size_t j = r.begin(); j < r.end(); j++) + { + PrimRefMB& ref = (*prims)[j]; + const LBBox3fa bn = recalculatePrimRef.linearBounds(ref, time_range, space); + cbounds.extend(bn); + }; + return cbounds; + }; + + return parallel_reduce(object_range.begin(), object_range.end(), PARALLEL_FIND_BLOCK_SIZE, PARALLEL_THRESHOLD, LBBox3fa(empty), + reduce, + [&](const LBBox3fa& b0, const LBBox3fa& b1) -> LBBox3fa { return embree::merge(b0, b1); }); + } + + template + const SetMB primInfo(const RecalculatePrimRef& recalculatePrimRef, const LinearSpace3fa& space) const + { + auto computePrimInfo = [&](const range& r) -> PrimInfoMB + { + PrimInfoMB pinfo(empty); + for (size_t j=r.begin(); j& prims, BuildProgressMonitor& progressMonitor) + { + ParallelPrefixSumState pstate; + + /* first try */ 
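+      // the first pass writes each primref at its original primitive index; if some primitives get filtered out,
+      // the second pass below re-runs with a prefix sum of the valid counts so the array ends up compacted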
+ progressMonitor(0); + PrimInfo pinfo = parallel_prefix_sum( pstate, size_t(0), geometry->size(), size_t(1024), PrimInfo(empty), [&](const range& r, const PrimInfo& base) -> PrimInfo { + return geometry->createPrimRefArray(prims,r,r.begin(),geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != prims.size()) + { + progressMonitor(0); + pinfo = parallel_prefix_sum( pstate, size_t(0), geometry->size(), size_t(1024), PrimInfo(empty), [&](const range& r, const PrimInfo& base) -> PrimInfo { + return geometry->createPrimRefArray(prims,r,base.size(),geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + } + return pinfo; + } + + PrimInfo createPrimRefArray(Scene* scene, Geometry::GTypeMask types, bool mblur, mvector& prims, BuildProgressMonitor& progressMonitor) + { + ParallelForForPrefixSumState pstate; + Scene::Iterator2 iter(scene,types,mblur); + + /* first try */ + progressMonitor(0); + pstate.init(iter,size_t(1024)); + PrimInfo pinfo = parallel_for_for_prefix_sum0( pstate, iter, PrimInfo(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID) -> PrimInfo { + return mesh->createPrimRefArray(prims,r,k,(unsigned)geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != prims.size()) + { + progressMonitor(0); + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfo(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID, const PrimInfo& base) -> PrimInfo { + return mesh->createPrimRefArray(prims,r,base.size(),(unsigned)geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + } + return pinfo; + } + + PrimInfo createPrimRefArrayMBlur(Scene* scene, Geometry::GTypeMask types, mvector& prims, BuildProgressMonitor& progressMonitor, size_t itime) + { + ParallelForForPrefixSumState pstate; + Scene::Iterator2 iter(scene,types,true); + + /* first try */ + progressMonitor(0); + pstate.init(iter,size_t(1024)); + PrimInfo pinfo = parallel_for_for_prefix_sum0( pstate, iter, PrimInfo(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID) -> PrimInfo { + return mesh->createPrimRefArrayMB(prims,itime,r,k,(unsigned)geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != prims.size()) + { + progressMonitor(0); + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfo(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID, const PrimInfo& base) -> PrimInfo { + return mesh->createPrimRefArrayMB(prims,itime,r,base.size(),(unsigned)geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + } + return pinfo; + } + + PrimInfoMB createPrimRefArrayMSMBlur(Scene* scene, Geometry::GTypeMask types, mvector& prims, BuildProgressMonitor& progressMonitor, BBox1f t0t1) + { + ParallelForForPrefixSumState pstate; + Scene::Iterator2 iter(scene,types,true); + + /* first try */ + progressMonitor(0); + pstate.init(iter,size_t(1024)); + PrimInfoMB pinfo = parallel_for_for_prefix_sum0( pstate, iter, PrimInfoMB(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID) -> PrimInfoMB { + return mesh->createPrimRefMBArray(prims,t0t1,r,k,(unsigned)geomID); + }, 
[](const PrimInfoMB& a, const PrimInfoMB& b) -> PrimInfoMB { return PrimInfoMB::merge2(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != prims.size()) + { + progressMonitor(0); + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfoMB(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID, const PrimInfoMB& base) -> PrimInfoMB { + return mesh->createPrimRefMBArray(prims,t0t1,r,base.size(),(unsigned)geomID); + }, [](const PrimInfoMB& a, const PrimInfoMB& b) -> PrimInfoMB { return PrimInfoMB::merge2(a,b); }); + } + + /* the BVH starts with that time range, even though primitives might have smaller/larger time range */ + pinfo.time_range = t0t1; + return pinfo; + } + + template + size_t createMortonCodeArray(Mesh* mesh, mvector& morton, BuildProgressMonitor& progressMonitor) + { + size_t numPrimitives = morton.size(); + + /* compute scene bounds */ + std::pair cb_empty(0,empty); + auto cb = parallel_reduce + ( size_t(0), numPrimitives, size_t(1024), cb_empty, [&](const range& r) -> std::pair + { + size_t num = 0; + BBox3fa bounds = empty; + + for (size_t j=r.begin(); jbuildBounds(j,&prim_bounds))) continue; + bounds.extend(center2(prim_bounds)); + num++; + } + return std::make_pair(num,bounds); + }, [] (const std::pair& a, const std::pair& b) { + return std::make_pair(a.first + b.first,merge(a.second,b.second)); + }); + + + size_t numPrimitivesGen = cb.first; + const BBox3fa centBounds = cb.second; + + /* compute morton codes */ + if (likely(numPrimitivesGen == numPrimitives)) + { + /* fast path if all primitives were valid */ + BVHBuilderMorton::MortonCodeMapping mapping(centBounds); + parallel_for( size_t(0), numPrimitives, size_t(1024), [&](const range& r) -> void { + BVHBuilderMorton::MortonCodeGenerator generator(mapping,&morton.data()[r.begin()]); + for (size_t j=r.begin(); jbounds(j),unsigned(j)); + }); + } + else + { + /* slow path, fallback in case some primitives were invalid */ + ParallelPrefixSumState pstate; + BVHBuilderMorton::MortonCodeMapping mapping(centBounds); + parallel_prefix_sum( pstate, size_t(0), numPrimitives, size_t(1024), size_t(0), [&](const range& r, const size_t base) -> size_t { + size_t num = 0; + BVHBuilderMorton::MortonCodeGenerator generator(mapping,&morton.data()[r.begin()]); + for (size_t j=r.begin(); jbuildBounds(j,&bounds))) continue; + generator(bounds,unsigned(j)); + num++; + } + return num; + }, std::plus()); + + parallel_prefix_sum( pstate, size_t(0), numPrimitives, size_t(1024), size_t(0), [&](const range& r, const size_t base) -> size_t { + size_t num = 0; + BVHBuilderMorton::MortonCodeGenerator generator(mapping,&morton.data()[base]); + for (size_t j=r.begin(); jbuildBounds(j,&bounds)) continue; + generator(bounds,unsigned(j)); + num++; + } + return num; + }, std::plus()); + } + return numPrimitivesGen; + } + + // ==================================================================================================== + // ==================================================================================================== + // ==================================================================================================== + + // template for grid meshes + +#if 0 + template<> + PrimInfo createPrimRefArray(Scene* scene, mvector& prims, BuildProgressMonitor& progressMonitor) + { + PING; + ParallelForForPrefixSumState pstate; + Scene::Iterator iter(scene); + + /* first try */ + progressMonitor(0); + pstate.init(iter,size_t(1024)); + PrimInfo pinfo = parallel_for_for_prefix_sum0( pstate, iter, 
PrimInfo(empty), [&](GridMesh* mesh, const range& r, size_t k) -> PrimInfo + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jbuildBounds(j,&bounds)) continue; + const PrimRef prim(bounds,mesh->geomID,unsigned(j)); + pinfo.add_center2(prim); + prims[k++] = prim; + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != prims.size()) + { + progressMonitor(0); + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfo(empty), [&](GridMesh* mesh, const range& r, size_t k, const PrimInfo& base) -> PrimInfo + { + k = base.size(); + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jbuildBounds(j,&bounds)) continue; + const PrimRef prim(bounds,mesh->geomID,unsigned(j)); + pinfo.add_center2(prim); + prims[k++] = prim; + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + } + return pinfo; + } +#endif + + // ==================================================================================================== + // ==================================================================================================== + // ==================================================================================================== + + IF_ENABLED_TRIS (template size_t createMortonCodeArray(TriangleMesh* mesh COMMA mvector& morton COMMA BuildProgressMonitor& progressMonitor)); + IF_ENABLED_QUADS(template size_t createMortonCodeArray(QuadMesh* mesh COMMA mvector& morton COMMA BuildProgressMonitor& progressMonitor)); + IF_ENABLED_USER (template size_t createMortonCodeArray(UserGeometry* mesh COMMA mvector& morton COMMA BuildProgressMonitor& progressMonitor)); + IF_ENABLED_INSTANCE (template size_t createMortonCodeArray(Instance* mesh COMMA mvector& morton COMMA BuildProgressMonitor& progressMonitor)); + } +} diff --git a/thirdparty/embree/kernels/builders/primrefgen.h b/thirdparty/embree/kernels/builders/primrefgen.h new file mode 100644 index 000000000000..9919c945c31a --- /dev/null +++ b/thirdparty/embree/kernels/builders/primrefgen.h @@ -0,0 +1,28 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/scene.h" +#include "../common/primref.h" +#include "../common/primref_mb.h" +#include "priminfo.h" +#include "bvh_builder_morton.h" + +namespace embree +{ + namespace isa + { + PrimInfo createPrimRefArray(Geometry* geometry, unsigned int geomID, mvector& prims, BuildProgressMonitor& progressMonitor); + + PrimInfo createPrimRefArray(Scene* scene, Geometry::GTypeMask types, bool mblur, mvector& prims, BuildProgressMonitor& progressMonitor); + + PrimInfo createPrimRefArrayMBlur(Scene* scene, Geometry::GTypeMask types, mvector& prims, BuildProgressMonitor& progressMonitor, size_t itime = 0); + + PrimInfoMB createPrimRefArrayMSMBlur(Scene* scene, Geometry::GTypeMask types, mvector& prims, BuildProgressMonitor& progressMonitor, BBox1f t0t1 = BBox1f(0.0f,1.0f)); + + template + size_t createMortonCodeArray(Mesh* mesh, mvector& morton, BuildProgressMonitor& progressMonitor); + } +} + diff --git a/thirdparty/embree/kernels/builders/primrefgen_presplit.h b/thirdparty/embree/kernels/builders/primrefgen_presplit.h new file mode 100644 index 000000000000..8bdb38b955f7 --- /dev/null +++ b/thirdparty/embree/kernels/builders/primrefgen_presplit.h @@ -0,0 +1,371 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + 
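+// pre-splitting: primitives whose bounds straddle many cells of a coarse grid are split into several
+// sub-references (up to MAX_PRESPLITS_PER_PRIMITIVE) before the BVH build, prioritized by the amount
+// of empty space inside their AABB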
+#include "../builders/primrefgen.h" +#include "../builders/heuristic_spatial.h" +#include "../builders/splitter.h" + +#include "../../common/algorithms/parallel_for_for.h" +#include "../../common/algorithms/parallel_for_for_prefix_sum.h" + +#define DBG_PRESPLIT(x) +#define CHECK_PRESPLIT(x) + +#define GRID_SIZE 1024 +#define MAX_PRESPLITS_PER_PRIMITIVE_LOG 5 +#define MAX_PRESPLITS_PER_PRIMITIVE (1<(priority); + } + __forceinline bool operator < (const PresplitItem& item) const + { + return (priority < item.priority); + } + + template + __forceinline static float compute_priority(const PrimRef &ref, Scene *scene, const Vec2i &mc) + { + const unsigned int geomID = ref.geomID(); + const unsigned int primID = ref.primID(); + const float area_aabb = area(ref.bounds()); + const float area_prim = ((Mesh*)scene->get(geomID))->projectedPrimitiveArea(primID); + const unsigned int diff = 31 - lzcnt(mc.x^mc.y); + assert(area_prim <= area_aabb); + //const float priority = powf((area_aabb - area_prim) * powf(PRIORITY_SPLIT_POS_WEIGHT,(float)diff),1.0f/4.0f); + const float priority = sqrtf(sqrtf( (area_aabb - area_prim) * powf(PRIORITY_SPLIT_POS_WEIGHT,(float)diff) )); + assert(priority >= 0.0f && priority < FLT_LARGE); + return priority; + } + + + }; + + inline std::ostream &operator<<(std::ostream &cout, const PresplitItem& item) { + return cout << "index " << item.index << " priority " << item.priority; + }; + + template + void splitPrimitive(SplitterFactory &Splitter, + const PrimRef &prim, + const unsigned int geomID, + const unsigned int primID, + const unsigned int split_level, + const Vec3fa &grid_base, + const float grid_scale, + const float grid_extend, + PrimRef subPrims[MAX_PRESPLITS_PER_PRIMITIVE], + unsigned int& numSubPrims) + { + assert(split_level <= MAX_PRESPLITS_PER_PRIMITIVE_LOG); + if (split_level == 0) + { + assert(numSubPrims < MAX_PRESPLITS_PER_PRIMITIVE); + subPrims[numSubPrims++] = prim; + } + else + { + const Vec3fa lower = prim.lower; + const Vec3fa upper = prim.upper; + const Vec3fa glower = (lower-grid_base)*Vec3fa(grid_scale)+Vec3fa(0.2f); + const Vec3fa gupper = (upper-grid_base)*Vec3fa(grid_scale)-Vec3fa(0.2f); + Vec3ia ilower(floor(glower)); + Vec3ia iupper(floor(gupper)); + + /* this ignores dimensions that are empty */ + iupper = (Vec3ia)(select(vint4(glower) >= vint4(gupper),vint4(ilower),vint4(iupper))); + + /* compute a morton code for the lower and upper grid coordinates. 
*/ + const unsigned int lower_code = bitInterleave(ilower.x,ilower.y,ilower.z); + const unsigned int upper_code = bitInterleave(iupper.x,iupper.y,iupper.z); + + /* if all bits are equal then we cannot split */ + if(unlikely(lower_code == upper_code)) + { + assert(numSubPrims < MAX_PRESPLITS_PER_PRIMITIVE); + subPrims[numSubPrims++] = prim; + return; + } + + /* compute octree level and dimension to perform the split in */ + const unsigned int diff = 31 - lzcnt(lower_code^upper_code); + const unsigned int level = diff / 3; + const unsigned int dim = diff % 3; + + /* now we compute the grid position of the split */ + const unsigned int isplit = iupper[dim] & ~((1<= fsplit); + + /* split primitive */ + const auto splitter = Splitter(prim); + BBox3fa left,right; + splitter(prim.bounds(),dim,fsplit,left,right); + assert(!left.empty()); + assert(!right.empty()); + + + splitPrimitive(Splitter,PrimRef(left ,geomID,primID),geomID,primID,split_level-1,grid_base,grid_scale,grid_extend,subPrims,numSubPrims); + splitPrimitive(Splitter,PrimRef(right,geomID,primID),geomID,primID,split_level-1,grid_base,grid_scale,grid_extend,subPrims,numSubPrims); + } + } + + + template + PrimInfo createPrimRefArray_presplit(Geometry* geometry, unsigned int geomID, size_t numPrimRefs, mvector& prims, BuildProgressMonitor& progressMonitor) + { + ParallelPrefixSumState pstate; + + /* first try */ + progressMonitor(0); + PrimInfo pinfo = parallel_prefix_sum( pstate, size_t(0), geometry->size(), size_t(1024), PrimInfo(empty), [&](const range& r, const PrimInfo& base) -> PrimInfo { + return geometry->createPrimRefArray(prims,r,r.begin(),geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != numPrimRefs) + { + progressMonitor(0); + pinfo = parallel_prefix_sum( pstate, size_t(0), geometry->size(), size_t(1024), PrimInfo(empty), [&](const range& r, const PrimInfo& base) -> PrimInfo { + return geometry->createPrimRefArray(prims,r,base.size(),geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + } + return pinfo; + } + + __forceinline Vec2i computeMC(const Vec3fa &grid_base, const float grid_scale, const PrimRef &ref) + { + const Vec3fa lower = ref.lower; + const Vec3fa upper = ref.upper; + const Vec3fa glower = (lower-grid_base)*Vec3fa(grid_scale)+Vec3fa(0.2f); + const Vec3fa gupper = (upper-grid_base)*Vec3fa(grid_scale)-Vec3fa(0.2f); + Vec3ia ilower(floor(glower)); + Vec3ia iupper(floor(gupper)); + + /* this ignores dimensions that are empty */ + iupper = (Vec3ia)select(vint4(glower) >= vint4(gupper),vint4(ilower),vint4(iupper)); + + /* compute a morton code for the lower and upper grid coordinates. 
*/ + const unsigned int lower_code = bitInterleave(ilower.x,ilower.y,ilower.z); + const unsigned int upper_code = bitInterleave(iupper.x,iupper.y,iupper.z); + return Vec2i(lower_code,upper_code); + } + + template + PrimInfo createPrimRefArray_presplit(Scene* scene, Geometry::GTypeMask types, bool mblur, size_t numPrimRefs, mvector& prims, BuildProgressMonitor& progressMonitor) + { + static const size_t MIN_STEP_SIZE = 128; + + ParallelForForPrefixSumState pstate; + Scene::Iterator2 iter(scene,types,mblur); + + /* first try */ + progressMonitor(0); + pstate.init(iter,size_t(1024)); + PrimInfo pinfo = parallel_for_for_prefix_sum0( pstate, iter, PrimInfo(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID) -> PrimInfo { + return mesh->createPrimRefArray(prims,r,k,(unsigned)geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + /* if we need to filter out geometry, run again */ + if (pinfo.size() != numPrimRefs) + { + progressMonitor(0); + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfo(empty), [&](Geometry* mesh, const range& r, size_t k, size_t geomID, const PrimInfo& base) -> PrimInfo { + return mesh->createPrimRefArray(prims,r,base.size(),(unsigned)geomID); + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + } + + /* use correct number of primitives */ + size_t numPrimitives = pinfo.size(); + const size_t alloc_numPrimitives = prims.size(); + const size_t numSplitPrimitivesBudget = alloc_numPrimitives - numPrimitives; + + /* set up primitive splitter */ + SplitterFactory Splitter(scene); + + + DBG_PRESPLIT( + const size_t org_numPrimitives = pinfo.size(); + PRINT(numPrimitives); + PRINT(alloc_numPrimitives); + PRINT(numSplitPrimitivesBudget); + ); + + /* allocate double buffer presplit items */ + const size_t presplit_allocation_size = sizeof(PresplitItem)*alloc_numPrimitives; + PresplitItem *presplitItem = (PresplitItem*)alignedMalloc(presplit_allocation_size,64); + PresplitItem *tmp_presplitItem = (PresplitItem*)alignedMalloc(presplit_allocation_size,64); + + /* compute grid */ + const Vec3fa grid_base = pinfo.geomBounds.lower; + const Vec3fa grid_diag = pinfo.geomBounds.size(); + const float grid_extend = max(grid_diag.x,max(grid_diag.y,grid_diag.z)); + const float grid_scale = grid_extend == 0.0f ? 
0.0f : GRID_SIZE / grid_extend; + + /* init presplit items and get total sum */ + const float psum = parallel_reduce( size_t(0), numPrimitives, size_t(MIN_STEP_SIZE), 0.0f, [&](const range& r) -> float { + float sum = 0.0f; + for (size_t i=r.begin(); i(prims[i],scene,mc) : 0.0f; + /* FIXME: sum undeterministic */ + sum += presplitItem[i].priority; + } + return sum; + },[](const float& a, const float& b) -> float { return a+b; }); + + /* compute number of splits per primitive */ + const float inv_psum = 1.0f / psum; + parallel_for( size_t(0), numPrimitives, size_t(MIN_STEP_SIZE), [&](const range& r) -> void { + for (size_t i=r.begin(); i 0.0f) + { + const float rel_p = (float)numSplitPrimitivesBudget * presplitItem[i].priority * inv_psum; + if (rel_p >= PRIORITY_CUTOFF_THRESHOLD) // need at least a split budget that generates two sub-prims + { + presplitItem[i].priority = max(min(ceilf(logf(rel_p)/logf(2.0f)),(float)MAX_PRESPLITS_PER_PRIMITIVE_LOG),1.0f); + //presplitItem[i].priority = min(floorf(logf(rel_p)/logf(2.0f)),(float)MAX_PRESPLITS_PER_PRIMITIVE_LOG); + assert(presplitItem[i].priority >= 0.0f && presplitItem[i].priority <= (float)MAX_PRESPLITS_PER_PRIMITIVE_LOG); + } + else + presplitItem[i].priority = 0.0f; + } + } + }); + + auto isLeft = [&] (const PresplitItem &ref) { return ref.priority < PRIORITY_CUTOFF_THRESHOLD; }; + size_t center = parallel_partitioning(presplitItem,0,numPrimitives,isLeft,1024); + + /* anything to split ? */ + if (center < numPrimitives) + { + const size_t numPrimitivesToSplit = numPrimitives - center; + assert(presplitItem[center].priority >= 1.0f); + + /* sort presplit items in ascending order */ + radix_sort_u32(presplitItem + center,tmp_presplitItem + center,numPrimitivesToSplit,1024); + + CHECK_PRESPLIT( + parallel_for( size_t(center+1), numPrimitives, size_t(MIN_STEP_SIZE), [&](const range& r) -> void { + for (size_t i=r.begin(); i& t) -> size_t { + size_t sum = 0; + for (size_t i=t.begin(); i= 1.0f); + const unsigned int primrefID = presplitItem[i].index; + const float prio = presplitItem[i].priority; + const unsigned int geomID = prims[primrefID].geomID(); + const unsigned int primID = prims[primrefID].primID(); + const unsigned int split_levels = (unsigned int)prio; + unsigned int numSubPrims = 0; + splitPrimitive(Splitter,prims[primrefID],geomID,primID,split_levels,grid_base,grid_scale,grid_extend,subPrims,numSubPrims); + assert(numSubPrims); + numSubPrims--; // can reuse slot + sum+=numSubPrims; + presplitItem[i].data = (numSubPrims << MAX_PRESPLITS_PER_PRIMITIVE_LOG) | split_levels; + primOffset0[i-center] = numSubPrims; + } + return sum; + },[](const size_t& a, const size_t& b) -> size_t { return a+b; }); + + /* if we are over budget, need to shrink the range */ + if (totalNumSubPrims > numSplitPrimitivesBudget) + { + size_t new_center = numPrimitives-1; + size_t sum = 0; + for (;new_center>=center;new_center--) + { + const unsigned int numSubPrims = presplitItem[new_center].data >> MAX_PRESPLITS_PER_PRIMITIVE_LOG; + if (unlikely(sum + numSubPrims >= numSplitPrimitivesBudget)) break; + sum += numSubPrims; + } + new_center++; + center = new_center; + } + + /* parallel prefix sum to compute offsets for storing sub-primitives */ + const unsigned int offset = parallel_prefix_sum(primOffset0,primOffset1,numPrimitivesToSplit,(unsigned int)0,std::plus()); + + /* iterate over range, and split primitives into sub primitives and append them to prims array */ + parallel_for( size_t(center), numPrimitives, size_t(MIN_STEP_SIZE), [&](const range& rn) -> 
void { + for (size_t j=rn.begin(); j& r) -> PrimInfo { + PrimInfo p(empty); + for (size_t j=r.begin(); j PrimInfo { return PrimInfo::merge(a,b); }); + + assert(pinfo.size() == numPrimitives); + + /* free double buffer presplit items */ + alignedFree(tmp_presplitItem); + alignedFree(presplitItem); + return pinfo; + } + } +} diff --git a/thirdparty/embree/kernels/builders/splitter.h b/thirdparty/embree/kernels/builders/splitter.h new file mode 100644 index 000000000000..dbd6cf07c7b2 --- /dev/null +++ b/thirdparty/embree/kernels/builders/splitter.h @@ -0,0 +1,169 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/scene.h" +#include "../common/primref.h" + +namespace embree +{ + namespace isa + { + template + __forceinline void splitPolygon(const BBox3fa& bounds, + const size_t dim, + const float pos, + const Vec3fa (&v)[N+1], + const Vec3fa (&inv_length)[N], + BBox3fa& left_o, + BBox3fa& right_o) + { + BBox3fa left = empty, right = empty; + /* clip triangle to left and right box by processing all edges */ + for (size_t i=0; i= pos) right.extend(v0); // this point is on right side + + if ((v0d < pos && pos < v1d) || (v1d < pos && pos < v0d)) // the edge crosses the splitting location + { + assert((v1d-v0d) != 0.0f); + const Vec3fa c = madd(Vec3fa((pos-v0d)*inv_length[i][dim]),v1-v0,v0); + left.extend(c); + right.extend(c); + } + } + + /* clip against current bounds */ + left_o = intersect(left,bounds); + right_o = intersect(right,bounds); + } + + template + __forceinline void splitPolygon(const PrimRef& prim, + const size_t dim, + const float pos, + const Vec3fa (&v)[N+1], + PrimRef& left_o, + PrimRef& right_o) + { + BBox3fa left = empty, right = empty; + for (size_t i=0; i= pos) right.extend(v0); // this point is on right side + + if ((v0d < pos && pos < v1d) || (v1d < pos && pos < v0d)) // the edge crosses the splitting location + { + assert((v1d-v0d) != 0.0f); + const float inv_length = 1.0f/(v1d-v0d); + const Vec3fa c = madd(Vec3fa((pos-v0d)*inv_length),v1-v0,v0); + left.extend(c); + right.extend(c); + } + } + + /* clip against current bounds */ + new (&left_o ) PrimRef(intersect(left ,prim.bounds()),prim.geomID(), prim.primID()); + new (&right_o) PrimRef(intersect(right,prim.bounds()),prim.geomID(), prim.primID()); + } + + struct TriangleSplitter + { + __forceinline TriangleSplitter(const Scene* scene, const PrimRef& prim) + { + const unsigned int mask = 0xFFFFFFFF >> RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS; + const TriangleMesh* mesh = (const TriangleMesh*) scene->get(prim.geomID() & mask ); + TriangleMesh::Triangle tri = mesh->triangle(prim.primID()); + v[0] = mesh->vertex(tri.v[0]); + v[1] = mesh->vertex(tri.v[1]); + v[2] = mesh->vertex(tri.v[2]); + v[3] = mesh->vertex(tri.v[0]); + inv_length[0] = Vec3fa(1.0f) / (v[1]-v[0]); + inv_length[1] = Vec3fa(1.0f) / (v[2]-v[1]); + inv_length[2] = Vec3fa(1.0f) / (v[0]-v[2]); + } + + __forceinline void operator() (const PrimRef& prim, const size_t dim, const float pos, PrimRef& left_o, PrimRef& right_o) const { + splitPolygon<3>(prim,dim,pos,v,left_o,right_o); + } + + __forceinline void operator() (const BBox3fa& prim, const size_t dim, const float pos, BBox3fa& left_o, BBox3fa& right_o) const { + splitPolygon<3>(prim,dim,pos,v,inv_length,left_o,right_o); + } + + private: + Vec3fa v[4]; + Vec3fa inv_length[3]; + }; + + struct TriangleSplitterFactory + { + __forceinline TriangleSplitterFactory(const Scene* scene) + : scene(scene) {} + + __forceinline TriangleSplitter 
operator() (const PrimRef& prim) const { + return TriangleSplitter(scene,prim); + } + + private: + const Scene* scene; + }; + + struct QuadSplitter + { + __forceinline QuadSplitter(const Scene* scene, const PrimRef& prim) + { + const unsigned int mask = 0xFFFFFFFF >> RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS; + const QuadMesh* mesh = (const QuadMesh*) scene->get(prim.geomID() & mask ); + QuadMesh::Quad quad = mesh->quad(prim.primID()); + v[0] = mesh->vertex(quad.v[0]); + v[1] = mesh->vertex(quad.v[1]); + v[2] = mesh->vertex(quad.v[2]); + v[3] = mesh->vertex(quad.v[3]); + v[4] = mesh->vertex(quad.v[0]); + inv_length[0] = Vec3fa(1.0f) / (v[1]-v[0]); + inv_length[1] = Vec3fa(1.0f) / (v[2]-v[1]); + inv_length[2] = Vec3fa(1.0f) / (v[3]-v[2]); + inv_length[3] = Vec3fa(1.0f) / (v[0]-v[3]); + } + + __forceinline void operator() (const PrimRef& prim, const size_t dim, const float pos, PrimRef& left_o, PrimRef& right_o) const { + splitPolygon<4>(prim,dim,pos,v,left_o,right_o); + } + + __forceinline void operator() (const BBox3fa& prim, const size_t dim, const float pos, BBox3fa& left_o, BBox3fa& right_o) const { + splitPolygon<4>(prim,dim,pos,v,inv_length,left_o,right_o); + } + + private: + Vec3fa v[5]; + Vec3fa inv_length[4]; + }; + + struct QuadSplitterFactory + { + __forceinline QuadSplitterFactory(const Scene* scene) + : scene(scene) {} + + __forceinline QuadSplitter operator() (const PrimRef& prim) const { + return QuadSplitter(scene,prim); + } + + private: + const Scene* scene; + }; + } +} + diff --git a/thirdparty/embree/kernels/bvh/bvh.cpp b/thirdparty/embree/kernels/bvh/bvh.cpp new file mode 100644 index 000000000000..9dbb3bcd732e --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh.cpp @@ -0,0 +1,190 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh.h" +#include "bvh_statistics.h" + +namespace embree +{ + template + BVHN::BVHN (const PrimitiveType& primTy, Scene* scene) + : AccelData((N==4) ? AccelData::TY_BVH4 : (N==8) ? AccelData::TY_BVH8 : AccelData::TY_UNKNOWN), + primTy(&primTy), device(scene->device), scene(scene), + root(emptyNode), alloc(scene->device,scene->isStaticAccel()), numPrimitives(0), numVertices(0) + { + } + + template + BVHN::~BVHN () + { + for (size_t i=0; i + void BVHN::clear() + { + set(BVHN::emptyNode,empty,0); + alloc.clear(); + } + + template + void BVHN::set (NodeRef root, const LBBox3fa& bounds, size_t numPrimitives) + { + this->root = root; + this->bounds = bounds; + this->numPrimitives = numPrimitives; + } + + template + void BVHN::clearBarrier(NodeRef& node) + { + if (node.isBarrier()) + node.clearBarrier(); + else if (!node.isLeaf()) { + BaseNode* n = node.baseNode(); // FIXME: flags should be stored in BVH + for (size_t c=0; cchild(c)); + } + } + + template + void BVHN::layoutLargeNodes(size_t num) + { +#if defined(__X86_64__) // do not use tree rotations on 32 bit platforms, barrier bit in NodeRef will cause issues + struct NodeArea + { + __forceinline NodeArea() {} + + __forceinline NodeArea(NodeRef& node, const BBox3fa& bounds) + : node(&node), A(node.isLeaf() ? 
float(neg_inf) : area(bounds)) {} + + __forceinline bool operator< (const NodeArea& other) const { + return this->A < other.A; + } + + NodeRef* node; + float A; + }; + std::vector lst; + lst.reserve(num); + lst.push_back(NodeArea(root,empty)); + + while (lst.size() < num) + { + std::pop_heap(lst.begin(), lst.end()); + NodeArea n = lst.back(); lst.pop_back(); + if (!n.node->isAABBNode()) break; + AABBNode* node = n.node->getAABBNode(); + for (size_t i=0; ichild(i) == BVHN::emptyNode) continue; + lst.push_back(NodeArea(node->child(i),node->bounds(i))); + std::push_heap(lst.begin(), lst.end()); + } + } + + for (size_t i=0; isetBarrier(); + + root = layoutLargeNodesRecursion(root,alloc.getCachedAllocator()); +#endif + } + + template + typename BVHN::NodeRef BVHN::layoutLargeNodesRecursion(NodeRef& node, const FastAllocator::CachedAllocator& allocator) + { + if (node.isBarrier()) { + node.clearBarrier(); + return node; + } + else if (node.isAABBNode()) + { + AABBNode* oldnode = node.getAABBNode(); + AABBNode* newnode = (BVHN::AABBNode*) allocator.malloc0(sizeof(BVHN::AABBNode),byteNodeAlignment); + *newnode = *oldnode; + for (size_t c=0; cchild(c) = layoutLargeNodesRecursion(oldnode->child(c),allocator); + return encodeNode(newnode); + } + else return node; + } + + template + double BVHN::preBuild(const std::string& builderName) + { + if (builderName == "") + return inf; + + if (device->verbosity(2)) + { + Lock lock(g_printMutex); + std::cout << "building BVH" << N << (builderName.find("MBlur") != std::string::npos ? "MB" : "") << "<" << primTy->name() << "> using " << builderName << " ..." << std::endl << std::flush; + } + + double t0 = 0.0; + if (device->benchmark || device->verbosity(2)) t0 = getSeconds(); + return t0; + } + + template + void BVHN::postBuild(double t0) + { + if (t0 == double(inf)) + return; + + double dt = 0.0; + if (device->benchmark || device->verbosity(2)) + dt = getSeconds()-t0; + + std::unique_ptr> stat; + + /* print statistics */ + if (device->verbosity(2)) + { + if (!stat) stat.reset(new BVHNStatistics(this)); + const size_t usedBytes = alloc.getUsedBytes(); + Lock lock(g_printMutex); + std::cout << "finished BVH" << N << "<" << primTy->name() << "> : " << 1000.0f*dt << "ms, " << 1E-6*double(numPrimitives)/dt << " Mprim/s, " << 1E-9*double(usedBytes)/dt << " GB/s" << std::endl; + + if (device->verbosity(2)) + std::cout << stat->str(); + + if (device->verbosity(2)) + { + FastAllocator::AllStatistics stat(&alloc); + for (size_t i=0; ialloc); + + stat.print(numPrimitives); + } + + if (device->verbosity(3)) + { + alloc.print_blocks(); + for (size_t i=0; ialloc.print_blocks(); + } + + std::cout << std::flush; + } + + /* benchmark mode */ + if (device->benchmark) + { + if (!stat) stat.reset(new BVHNStatistics(this)); + Lock lock(g_printMutex); + std::cout << "BENCHMARK_BUILD " << dt << " " << double(numPrimitives)/dt << " " << stat->sah() << " " << stat->bytesUsed() << " BVH" << N << "<" << primTy->name() << ">" << std::endl << std::flush; + } + } + +#if defined(__AVX__) + template class BVHN<8>; +#endif + +#if !defined(__AVX__) || !defined(EMBREE_TARGET_SSE2) && !defined(EMBREE_TARGET_SSE42) + template class BVHN<4>; +#endif +} + diff --git a/thirdparty/embree/kernels/bvh/bvh.h b/thirdparty/embree/kernels/bvh/bvh.h new file mode 100644 index 000000000000..7c1a45b63288 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh.h @@ -0,0 +1,235 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* include all node types */ 
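// --- Begin illustrative sketch (editor's addition, not part of the Embree patch) ---
// BVHN<N>::layoutLargeNodes() in the bvh.cpp hunk above repeatedly pops the
// largest-area node from a heap and pushes its children, so after `num`
// iterations the heap holds the largest interior nodes near the root; those are
// then marked and re-laid out contiguously for cache locality. This toy,
// self-contained sketch mirrors only that selection loop on a binary tree.
// All names here (ToyNode, selectLargeNodes) are hypothetical, not Embree API.
#include <algorithm>
#include <cstdio>
#include <vector>

struct ToyNode {
  float area;                                   // surface area of the node's bounds
  ToyNode* child[2] = {nullptr, nullptr};
  bool isLeaf() const { return child[0] == nullptr; }
};

struct NodeArea {
  ToyNode* node;
  float A;                                      // heap key (leaves get -1 so they pop last)
  bool operator<(const NodeArea& o) const { return A < o.A; } // max-heap on area
};

// Collect up to `num` of the largest interior nodes reachable from the root,
// always expanding the currently largest one first (as in layoutLargeNodes).
static std::vector<ToyNode*> selectLargeNodes(ToyNode* root, size_t num)
{
  std::vector<NodeArea> heap;
  heap.push_back({root, root->area});
  while (heap.size() < num) {
    std::pop_heap(heap.begin(), heap.end());
    NodeArea n = heap.back(); heap.pop_back();
    if (n.node->isLeaf()) break;                // only leaves left to expand
    for (ToyNode* c : n.node->child) {
      heap.push_back({c, c->isLeaf() ? -1.0f : c->area});
      std::push_heap(heap.begin(), heap.end());
    }
  }
  std::vector<ToyNode*> out;
  for (const NodeArea& n : heap) out.push_back(n.node);
  return out;
}

int main()
{
  ToyNode leaf[4] = {{1.f}, {1.f}, {1.f}, {1.f}};
  ToyNode a{8.f, {&leaf[0], &leaf[1]}}, b{4.f, {&leaf[2], &leaf[3]}};
  ToyNode root{16.f, {&a, &b}};
  for (ToyNode* n : selectLargeNodes(&root, 3))
    std::printf("selected node with area %.1f\n", n->area);
  return 0;
}
// --- End illustrative sketch ---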
+#include "bvh_node_aabb.h" +#include "bvh_node_aabb_mb.h" +#include "bvh_node_aabb_mb4d.h" +#include "bvh_node_obb.h" +#include "bvh_node_obb_mb.h" +#include "bvh_node_qaabb.h" + +namespace embree +{ + /*! flags used to enable specific node types in intersectors */ + enum BVHNodeFlags + { + BVH_FLAG_ALIGNED_NODE = 0x00001, + BVH_FLAG_ALIGNED_NODE_MB = 0x00010, + BVH_FLAG_UNALIGNED_NODE = 0x00100, + BVH_FLAG_UNALIGNED_NODE_MB = 0x01000, + BVH_FLAG_QUANTIZED_NODE = 0x100000, + BVH_FLAG_ALIGNED_NODE_MB4D = 0x1000000, + + /* short versions */ + BVH_AN1 = BVH_FLAG_ALIGNED_NODE, + BVH_AN2 = BVH_FLAG_ALIGNED_NODE_MB, + BVH_AN2_AN4D = BVH_FLAG_ALIGNED_NODE_MB | BVH_FLAG_ALIGNED_NODE_MB4D, + BVH_UN1 = BVH_FLAG_UNALIGNED_NODE, + BVH_UN2 = BVH_FLAG_UNALIGNED_NODE_MB, + BVH_MB = BVH_FLAG_ALIGNED_NODE_MB | BVH_FLAG_UNALIGNED_NODE_MB | BVH_FLAG_ALIGNED_NODE_MB4D, + BVH_AN1_UN1 = BVH_FLAG_ALIGNED_NODE | BVH_FLAG_UNALIGNED_NODE, + BVH_AN2_UN2 = BVH_FLAG_ALIGNED_NODE_MB | BVH_FLAG_UNALIGNED_NODE_MB, + BVH_AN2_AN4D_UN2 = BVH_FLAG_ALIGNED_NODE_MB | BVH_FLAG_ALIGNED_NODE_MB4D | BVH_FLAG_UNALIGNED_NODE_MB, + BVH_QN1 = BVH_FLAG_QUANTIZED_NODE + }; + + /*! Multi BVH with N children. Each node stores the bounding box of + * it's N children as well as N child references. */ + template + class BVHN : public AccelData + { + ALIGNED_CLASS_(16); + public: + + /*! forward declaration of node ref type */ + typedef NodeRefPtr NodeRef; + typedef BaseNode_t BaseNode; + typedef AABBNode_t AABBNode; + typedef AABBNodeMB_t AABBNodeMB; + typedef AABBNodeMB4D_t AABBNodeMB4D; + typedef OBBNode_t OBBNode; + typedef OBBNodeMB_t OBBNodeMB; + typedef QuantizedBaseNode_t QuantizedBaseNode; + typedef QuantizedBaseNodeMB_t QuantizedBaseNodeMB; + typedef QuantizedNode_t QuantizedNode; + + /*! Number of bytes the nodes and primitives are minimally aligned to.*/ + static const size_t byteAlignment = 16; + static const size_t byteNodeAlignment = 4*N; + + /*! Empty node */ + static const size_t emptyNode = NodeRef::emptyNode; + + /*! Invalid node, used as marker in traversal */ + static const size_t invalidNode = NodeRef::invalidNode; + static const size_t popRay = NodeRef::popRay; + + /*! Maximum depth of the BVH. */ + static const size_t maxBuildDepth = 32; + static const size_t maxBuildDepthLeaf = maxBuildDepth+8; + static const size_t maxDepth = 2*maxBuildDepthLeaf; // 2x because of two level builder + + /*! Maximum number of primitive blocks in a leaf. */ + static const size_t maxLeafBlocks = NodeRef::maxLeafBlocks; + + public: + + /*! Builder interface to create allocator */ + struct CreateAlloc : public FastAllocator::Create { + __forceinline CreateAlloc (BVHN* bvh) : FastAllocator::Create(&bvh->alloc) {} + }; + + typedef BVHNodeRecord NodeRecord; + typedef BVHNodeRecordMB NodeRecordMB; + typedef BVHNodeRecordMB4D NodeRecordMB4D; + + public: + + /*! BVHN default constructor. */ + BVHN (const PrimitiveType& primTy, Scene* scene); + + /*! BVHN destruction */ + ~BVHN (); + + /*! clears the acceleration structure */ + void clear(); + + /*! sets BVH members after build */ + void set (NodeRef root, const LBBox3fa& bounds, size_t numPrimitives); + + /*! Clears the barrier bits of a subtree. */ + void clearBarrier(NodeRef& node); + + /*! lays out num large nodes of the BVH */ + void layoutLargeNodes(size_t num); + NodeRef layoutLargeNodesRecursion(NodeRef& node, const FastAllocator::CachedAllocator& allocator); + + /*! called by all builders before build starts */ + double preBuild(const std::string& builderName); + + /*! 
called by all builders after build ended */ + void postBuild(double t0); + + /*! allocator class */ + struct Allocator { + BVHN* bvh; + Allocator (BVHN* bvh) : bvh(bvh) {} + __forceinline void* operator() (size_t bytes) const { + return bvh->alloc._threadLocal()->malloc(&bvh->alloc,bytes); + } + }; + + /*! post build cleanup */ + void cleanup() { + alloc.cleanup(); + } + + public: + + /*! Encodes a node */ + static __forceinline NodeRef encodeNode(AABBNode* node) { return NodeRef::encodeNode(node); } + static __forceinline NodeRef encodeNode(AABBNodeMB* node) { return NodeRef::encodeNode(node); } + static __forceinline NodeRef encodeNode(AABBNodeMB4D* node) { return NodeRef::encodeNode(node); } + static __forceinline NodeRef encodeNode(OBBNode* node) { return NodeRef::encodeNode(node); } + static __forceinline NodeRef encodeNode(OBBNodeMB* node) { return NodeRef::encodeNode(node); } + static __forceinline NodeRef encodeLeaf(void* tri, size_t num) { return NodeRef::encodeLeaf(tri,num); } + static __forceinline NodeRef encodeTypedLeaf(void* ptr, size_t ty) { return NodeRef::encodeTypedLeaf(ptr,ty); } + + public: + + /*! Prefetches the node this reference points to */ + __forceinline static void prefetch(const NodeRef ref, int types=0) + { +#if defined(__AVX512PF__) // MIC + if (types != BVH_FLAG_QUANTIZED_NODE) { + prefetchL2(((char*)ref.ptr)+0*64); + prefetchL2(((char*)ref.ptr)+1*64); + if ((N >= 8) || (types > BVH_FLAG_ALIGNED_NODE)) { + prefetchL2(((char*)ref.ptr)+2*64); + prefetchL2(((char*)ref.ptr)+3*64); + } + if ((N >= 8) && (types > BVH_FLAG_ALIGNED_NODE)) { + /* KNL still needs L2 prefetches for large nodes */ + prefetchL2(((char*)ref.ptr)+4*64); + prefetchL2(((char*)ref.ptr)+5*64); + prefetchL2(((char*)ref.ptr)+6*64); + prefetchL2(((char*)ref.ptr)+7*64); + } + } + else + { + /* todo: reduce if 32bit offsets are enabled */ + prefetchL2(((char*)ref.ptr)+0*64); + prefetchL2(((char*)ref.ptr)+1*64); + prefetchL2(((char*)ref.ptr)+2*64); + } +#else + if (types != BVH_FLAG_QUANTIZED_NODE) { + prefetchL1(((char*)ref.ptr)+0*64); + prefetchL1(((char*)ref.ptr)+1*64); + if ((N >= 8) || (types > BVH_FLAG_ALIGNED_NODE)) { + prefetchL1(((char*)ref.ptr)+2*64); + prefetchL1(((char*)ref.ptr)+3*64); + } + if ((N >= 8) && (types > BVH_FLAG_ALIGNED_NODE)) { + /* deactivate for large nodes on Xeon, as it introduces regressions */ + //prefetchL1(((char*)ref.ptr)+4*64); + //prefetchL1(((char*)ref.ptr)+5*64); + //prefetchL1(((char*)ref.ptr)+6*64); + //prefetchL1(((char*)ref.ptr)+7*64); + } + } + else + { + /* todo: reduce if 32bit offsets are enabled */ + prefetchL1(((char*)ref.ptr)+0*64); + prefetchL1(((char*)ref.ptr)+1*64); + prefetchL1(((char*)ref.ptr)+2*64); + } +#endif + } + + __forceinline static void prefetchW(const NodeRef ref, int types=0) + { + embree::prefetchEX(((char*)ref.ptr)+0*64); + embree::prefetchEX(((char*)ref.ptr)+1*64); + if ((N >= 8) || (types > BVH_FLAG_ALIGNED_NODE)) { + embree::prefetchEX(((char*)ref.ptr)+2*64); + embree::prefetchEX(((char*)ref.ptr)+3*64); + } + if ((N >= 8) && (types > BVH_FLAG_ALIGNED_NODE)) { + embree::prefetchEX(((char*)ref.ptr)+4*64); + embree::prefetchEX(((char*)ref.ptr)+5*64); + embree::prefetchEX(((char*)ref.ptr)+6*64); + embree::prefetchEX(((char*)ref.ptr)+7*64); + } + } + + /*! bvh type information */ + public: + const PrimitiveType* primTy; //!< primitive type stored in the BVH + + /*! 
bvh data */ + public: + Device* device; //!< device pointer + Scene* scene; //!< scene pointer + NodeRef root; //!< root node + FastAllocator alloc; //!< allocator used to allocate nodes + + /*! statistics data */ + public: + size_t numPrimitives; //!< number of primitives the BVH is build over + size_t numVertices; //!< number of vertices the BVH references + + /*! data arrays for special builders */ + public: + std::vector objects; + vector_t> subdiv_patches; + }; + + typedef BVHN<4> BVH4; + typedef BVHN<8> BVH8; +} diff --git a/thirdparty/embree/kernels/bvh/bvh4_factory.cpp b/thirdparty/embree/kernels/bvh/bvh4_factory.cpp new file mode 100644 index 000000000000..23f4f63d45f0 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh4_factory.cpp @@ -0,0 +1,1325 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh4_factory.h" +#include "../bvh/bvh.h" + +#include "../geometry/curveNv.h" +#include "../geometry/curveNi.h" +#include "../geometry/curveNi_mb.h" +#include "../geometry/linei.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglev_mb.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/subdivpatch1.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" +#include "../geometry/subgrid.h" +#include "../common/accelinstance.h" + +namespace embree +{ + DECLARE_SYMBOL2(Accel::Collider,BVH4ColliderUserGeom); + + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector4i,void); + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector8i,void); + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector4v,void); + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector8v,void); + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector4iMB,void); + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector8iMB,void); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersector1MB); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersectorRobust1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersectorRobust1MB); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4Intersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4vIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4vMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4vMBIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iMBIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Quad4vIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Quad4iIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Quad4vIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Quad4iIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Quad4iMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4Quad4iMBIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,QBVH4Triangle4iIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,QBVH4Quad4iIntersector1Pluecker); + + 
DECLARE_SYMBOL2(Accel::Intersector1,BVH4SubdivPatch1Intersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4SubdivPatch1MBIntersector1); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4VirtualIntersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4VirtualMBIntersector1); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4InstanceIntersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4InstanceMBIntersector1); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH4GridIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4GridMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH4GridIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersector4Hybrid); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersector4HybridMB); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersectorRobust4Hybrid); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersectorRobust4HybridMB); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4Intersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4Intersector4HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4vIntersector4HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4vMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4vMBIntersector4HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iMBIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4vIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4vIntersector4HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4iIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4vIntersector4HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4iIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4iMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4Quad4iMBIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4SubdivPatch1Intersector4); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4SubdivPatch1MBIntersector4); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4VirtualIntersector4Chunk); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4VirtualMBIntersector4Chunk); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4InstanceIntersector4Chunk); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4InstanceMBIntersector4Chunk); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH4GridIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4GridMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH4GridIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersector8Hybrid); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersector8HybridMB); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersectorRobust8Hybrid); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersectorRobust8HybridMB); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4Intersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4Intersector8HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4vIntersector8HybridPluecker); + 
DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4vMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4vMBIntersector8HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iMBIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4vIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4vIntersector8HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4iIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4vIntersector8HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4iIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4iMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4Quad4iMBIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4SubdivPatch1Intersector8); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4SubdivPatch1MBIntersector8); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4VirtualIntersector8Chunk); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4VirtualMBIntersector8Chunk); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4InstanceIntersector8Chunk); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4InstanceMBIntersector8Chunk); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH4GridIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4GridMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH4GridIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersector16Hybrid); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersector16HybridMB); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersectorRobust16Hybrid); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersectorRobust16HybridMB); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4Intersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4Intersector16HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4vIntersector16HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4vMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4vMBIntersector16HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iMBIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4vIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4vIntersector16HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4iIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4vIntersector16HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4iIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4iMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4Quad4iMBIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4SubdivPatch1Intersector16); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4SubdivPatch1MBIntersector16); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4VirtualIntersector16Chunk); + 
DECLARE_SYMBOL2(Accel::Intersector16,BVH4VirtualMBIntersector16Chunk); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4InstanceIntersector16Chunk); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4InstanceMBIntersector16Chunk); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH4GridIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4GridMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH4GridIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4IntersectorStreamPacketFallback); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Triangle4IntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Triangle4IntersectorStreamMoellerNoFilter); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Triangle4iIntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Triangle4vIntersectorStreamPluecker); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Triangle4iIntersectorStreamPluecker); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Quad4vIntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Quad4vIntersectorStreamMoellerNoFilter); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Quad4iIntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Quad4vIntersectorStreamPluecker); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4Quad4iIntersectorStreamPluecker); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4VirtualIntersectorStream); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH4InstanceIntersectorStream); + + DECLARE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelTriangle4MeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelTriangle4vMeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelTriangle4iMeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelQuadMeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelVirtualSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelInstanceSAH,void* COMMA Scene* COMMA Geometry::GTypeMask COMMA bool); + + DECLARE_ISA_FUNCTION(Builder*,BVH4Curve4vBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Curve4iBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4OBBCurve4iMBBuilder_OBB,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Curve8iBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4OBBCurve8iMBBuilder_OBB,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4SceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4QuantizedTriangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4QuantizedQuad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + 
DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4SceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4VirtualSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4VirtualMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4InstanceSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + DECLARE_ISA_FUNCTION(Builder*,BVH4InstanceMBSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + + DECLARE_ISA_FUNCTION(Builder*,BVH4GridSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4GridMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4SubdivPatch1BuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4SubdivPatch1MBBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4MeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vMeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iMeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4vMeshRefitSAH,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4VirtualMeshRefitSAH,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + + BVH4Factory::BVH4Factory(int bfeatures, int ifeatures) + { + SELECT_SYMBOL_DEFAULT_AVX_AVX2(ifeatures,BVH4ColliderUserGeom); + + selectBuilders(bfeatures); + selectIntersectors(ifeatures); + } + + void BVH4Factory::selectBuilders(int features) + { + IF_ENABLED_TRIS (SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4BuilderTwoLevelTriangle4MeshSAH)); + IF_ENABLED_TRIS (SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4BuilderTwoLevelTriangle4iMeshSAH)); + IF_ENABLED_TRIS (SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4BuilderTwoLevelTriangle4vMeshSAH)); + IF_ENABLED_QUADS (SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4BuilderTwoLevelQuadMeshSAH)); + IF_ENABLED_USER (SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4BuilderTwoLevelVirtualSAH)); + IF_ENABLED_INSTANCE (SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4BuilderTwoLevelInstanceSAH)); + + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Curve4vBuilder_OBB_New)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Curve4iBuilder_OBB_New)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4OBBCurve4iMBBuilder_OBB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX(features,BVH4Curve8iBuilder_OBB_New)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX(features,BVH4OBBCurve8iMBBuilder_OBB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4Triangle4SceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4Triangle4vSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4Triangle4iSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Triangle4iMBSceneBuilderSAH)); + 
IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Triangle4vMBSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4QuantizedTriangle4iSceneBuilderSAH)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4Quad4vSceneBuilderSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4Quad4iSceneBuilderSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Quad4iMBSceneBuilderSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4QuantizedQuad4iSceneBuilderSAH)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Triangle4SceneBuilderFastSpatialSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Triangle4vSceneBuilderFastSpatialSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Triangle4iSceneBuilderFastSpatialSAH)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4Quad4vSceneBuilderFastSpatialSAH)); + + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4VirtualSceneBuilderSAH)); + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4VirtualMBSceneBuilderSAH)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4InstanceSceneBuilderSAH)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4InstanceMBSceneBuilderSAH)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4GridSceneBuilderSAH)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_AVX(features,BVH4GridMBSceneBuilderSAH)); + + IF_ENABLED_SUBDIV(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4SubdivPatch1BuilderSAH)); + IF_ENABLED_SUBDIV(SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,BVH4SubdivPatch1MBBuilderSAH)); + } + + void BVH4Factory::selectIntersectors(int features) + { + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector4i)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector8i)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector4v)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector8v)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector4iMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector8iMB)); + + /* select intersectors1 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersector1)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersector1MB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust1)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust1MB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Triangle4Intersector1Moeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX512SKX(features,BVH4Triangle4iIntersector1Moeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX512SKX(features,BVH4Triangle4vIntersector1Pluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX512SKX(features,BVH4Triangle4iIntersector1Pluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vMBIntersector1Moeller)); + 
IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iMBIntersector1Moeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vMBIntersector1Pluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iMBIntersector1Pluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector1Moeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iIntersector1Moeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector1Pluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iIntersector1Pluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iMBIntersector1Pluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iMBIntersector1Moeller)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX512SKX(features,QBVH4Triangle4iIntersector1Pluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX512SKX(features,QBVH4Quad4iIntersector1Pluecker)); + + IF_ENABLED_SUBDIV(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4SubdivPatch1Intersector1)); + IF_ENABLED_SUBDIV(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4SubdivPatch1MBIntersector1)); + + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4VirtualIntersector1)); + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4VirtualMBIntersector1)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4InstanceIntersector1)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4InstanceMBIntersector1)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4GridIntersector1Moeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4GridMBIntersector1Moeller)) + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4GridIntersector1Pluecker)); + +#if defined (EMBREE_RAY_PACKETS) + + /* select intersectors4 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersector4Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersector4HybridMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust4Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust4HybridMB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4Intersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4Intersector4HybridMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iIntersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vIntersector4HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iIntersector4HybridPluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vMBIntersector4HybridMoeller)); + 
IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iMBIntersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vMBIntersector4HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iMBIntersector4HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector4HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector4HybridMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iIntersector4HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector4HybridPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iIntersector4HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iMBIntersector4HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4iMBIntersector4HybridPluecker)); + + IF_ENABLED_SUBDIV(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4SubdivPatch1Intersector4)); + IF_ENABLED_SUBDIV(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4SubdivPatch1MBIntersector4)); + + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4VirtualIntersector4Chunk)); + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4VirtualMBIntersector4Chunk)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4InstanceIntersector4Chunk)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4InstanceMBIntersector4Chunk)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector4HybridMoeller)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4GridIntersector4HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4GridMBIntersector4HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,BVH4GridIntersector4HybridPluecker)); + + /* select intersectors8 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersector8Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersector8HybridMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust8Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust8HybridMB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4Intersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4Intersector8HybridMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iIntersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vIntersector8HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iIntersector8HybridPluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vMBIntersector8HybridMoeller)); + 
IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iMBIntersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4vMBIntersector8HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Triangle4iMBIntersector8HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector8HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector8HybridMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4iIntersector8HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4vIntersector8HybridPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4iIntersector8HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4iMBIntersector8HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4Quad4iMBIntersector8HybridPluecker)); + + IF_ENABLED_SUBDIV(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4SubdivPatch1Intersector8)); + IF_ENABLED_SUBDIV(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4SubdivPatch1MBIntersector8)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4VirtualIntersector8Chunk)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4VirtualMBIntersector8Chunk)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4InstanceIntersector8Chunk)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4InstanceMBIntersector8Chunk)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4GridIntersector8HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4GridMBIntersector8HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH4GridIntersector8HybridPluecker)); + + /* select intersectors16 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4OBBVirtualCurveIntersector16Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4OBBVirtualCurveIntersector16HybridMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust16Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4OBBVirtualCurveIntersectorRobust16HybridMB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4Intersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4Intersector16HybridMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4iIntersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4vIntersector16HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4iIntersector16HybridPluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4vMBIntersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4iMBIntersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4vMBIntersector16HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Triangle4iMBIntersector16HybridPluecker)); + + 
IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4vIntersector16HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4vIntersector16HybridMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4iIntersector16HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4vIntersector16HybridPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4iIntersector16HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4iMBIntersector16HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4Quad4iMBIntersector16HybridPluecker)); + + IF_ENABLED_SUBDIV(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4SubdivPatch1Intersector16)); + IF_ENABLED_SUBDIV(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4SubdivPatch1MBIntersector16)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4VirtualIntersector16Chunk)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4VirtualMBIntersector16Chunk)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4InstanceIntersector16Chunk)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4InstanceMBIntersector16Chunk)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4GridIntersector16HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4GridMBIntersector16HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH4GridIntersector16HybridPluecker)); + + /* select stream intersectors */ + SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4IntersectorStreamPacketFallback); + + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Triangle4IntersectorStreamMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Triangle4IntersectorStreamMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Triangle4iIntersectorStreamMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Triangle4vIntersectorStreamPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Triangle4iIntersectorStreamPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Quad4vIntersectorStreamMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Quad4vIntersectorStreamMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Quad4iIntersectorStreamMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Quad4vIntersectorStreamPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4Quad4iIntersectorStreamPluecker)); + + IF_ENABLED_USER(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4VirtualIntersectorStream)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH4InstanceIntersectorStream)); + +#endif + } + + Accel::Intersectors BVH4Factory::BVH4OBBVirtualCurveIntersectors(BVH4* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant) + { + switch (ivariant) { + case 
IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH4OBBVirtualCurveIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4OBBVirtualCurveIntersector4Hybrid(); + intersectors.intersector8 = BVH4OBBVirtualCurveIntersector8Hybrid(); + intersectors.intersector16 = BVH4OBBVirtualCurveIntersector16Hybrid(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH4OBBVirtualCurveIntersectorRobust1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4OBBVirtualCurveIntersectorRobust4Hybrid(); + intersectors.intersector8 = BVH4OBBVirtualCurveIntersectorRobust8Hybrid(); + intersectors.intersector16 = BVH4OBBVirtualCurveIntersectorRobust16Hybrid(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + default: assert(false); + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4OBBVirtualCurveIntersectorsMB(BVH4* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH4OBBVirtualCurveIntersector1MB(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4OBBVirtualCurveIntersector4HybridMB(); + intersectors.intersector8 = BVH4OBBVirtualCurveIntersector8HybridMB(); + intersectors.intersector16 = BVH4OBBVirtualCurveIntersector16HybridMB(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH4OBBVirtualCurveIntersectorRobust1MB(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4OBBVirtualCurveIntersectorRobust4HybridMB(); + intersectors.intersector8 = BVH4OBBVirtualCurveIntersectorRobust8HybridMB(); + intersectors.intersector16 = BVH4OBBVirtualCurveIntersectorRobust16HybridMB(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + default: assert(false); + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4Triangle4Intersectors(BVH4* bvh, IntersectVariant ivariant) + { + assert(ivariant == IntersectVariant::FAST); + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4Intersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4_filter = BVH4Triangle4Intersector4HybridMoeller(); + intersectors.intersector4_nofilter = BVH4Triangle4Intersector4HybridMoellerNoFilter(); + intersectors.intersector8_filter = BVH4Triangle4Intersector8HybridMoeller(); + intersectors.intersector8_nofilter = BVH4Triangle4Intersector8HybridMoellerNoFilter(); + intersectors.intersector16_filter = BVH4Triangle4Intersector16HybridMoeller(); + intersectors.intersector16_nofilter = BVH4Triangle4Intersector16HybridMoellerNoFilter(); + intersectors.intersectorN_filter = BVH4Triangle4IntersectorStreamMoeller(); + 
intersectors.intersectorN_nofilter = BVH4Triangle4IntersectorStreamMoellerNoFilter(); +#endif + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4Triangle4vIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + assert(ivariant == IntersectVariant::ROBUST); + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4vIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4vIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Triangle4vIntersector8HybridPluecker(); + intersectors.intersector16 = BVH4Triangle4vIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4Triangle4vIntersectorStreamPluecker(); +#endif + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4Triangle4iIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4iIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4iIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4Triangle4iIntersector8HybridMoeller(); + intersectors.intersector16 = BVH4Triangle4iIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4Triangle4iIntersectorStreamMoeller(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4iIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4iIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Triangle4iIntersector8HybridPluecker(); + intersectors.intersector16 = BVH4Triangle4iIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4Triangle4iIntersectorStreamPluecker(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4Triangle4vMBIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4vMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4vMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4Triangle4vMBIntersector8HybridMoeller(); + intersectors.intersector16 = BVH4Triangle4vMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4vMBIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4vMBIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Triangle4vMBIntersector8HybridPluecker(); + intersectors.intersector16 = BVH4Triangle4vMBIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4Triangle4iMBIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4iMBIntersector1Moeller(); +#if defined 
(EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4iMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4Triangle4iMBIntersector8HybridMoeller(); + intersectors.intersector16 = BVH4Triangle4iMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Triangle4iMBIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Triangle4iMBIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Triangle4iMBIntersector8HybridPluecker(); + intersectors.intersector16 = BVH4Triangle4iMBIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4Quad4vIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Quad4vIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4_filter = BVH4Quad4vIntersector4HybridMoeller(); + intersectors.intersector4_nofilter = BVH4Quad4vIntersector4HybridMoellerNoFilter(); + intersectors.intersector8_filter = BVH4Quad4vIntersector8HybridMoeller(); + intersectors.intersector8_nofilter = BVH4Quad4vIntersector8HybridMoellerNoFilter(); + intersectors.intersector16_filter = BVH4Quad4vIntersector16HybridMoeller(); + intersectors.intersector16_nofilter = BVH4Quad4vIntersector16HybridMoellerNoFilter(); + intersectors.intersectorN_filter = BVH4Quad4vIntersectorStreamMoeller(); + intersectors.intersectorN_nofilter = BVH4Quad4vIntersectorStreamMoellerNoFilter(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Quad4vIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Quad4vIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Quad4vIntersector8HybridPluecker(); + intersectors.intersector16 = BVH4Quad4vIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4Quad4vIntersectorStreamPluecker(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4Quad4iIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Quad4iIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Quad4iIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4Quad4iIntersector8HybridMoeller(); + intersectors.intersector16= BVH4Quad4iIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4Quad4iIntersectorStreamMoeller(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Quad4iIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Quad4iIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Quad4iIntersector8HybridPluecker(); + intersectors.intersector16= BVH4Quad4iIntersector16HybridPluecker(); + intersectors.intersectorN = 
BVH4Quad4iIntersectorStreamPluecker(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::BVH4Quad4iMBIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Quad4iMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Quad4iMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4Quad4iMBIntersector8HybridMoeller(); + intersectors.intersector16= BVH4Quad4iMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4Quad4iMBIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4Quad4iMBIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4Quad4iMBIntersector8HybridPluecker(); + intersectors.intersector16= BVH4Quad4iMBIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH4Factory::QBVH4Triangle4iIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = QBVH4Triangle4iIntersector1Pluecker(); + return intersectors; + } + + Accel::Intersectors BVH4Factory::QBVH4Quad4iIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = QBVH4Quad4iIntersector1Pluecker(); + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4UserGeometryIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4VirtualIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4VirtualIntersector4Chunk(); + intersectors.intersector8 = BVH4VirtualIntersector8Chunk(); + intersectors.intersector16 = BVH4VirtualIntersector16Chunk(); + intersectors.intersectorN = BVH4VirtualIntersectorStream(); +#endif + intersectors.collider = BVH4ColliderUserGeom(); + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4UserGeometryMBIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4VirtualMBIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4VirtualMBIntersector4Chunk(); + intersectors.intersector8 = BVH4VirtualMBIntersector8Chunk(); + intersectors.intersector16 = BVH4VirtualMBIntersector16Chunk(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4InstanceIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4InstanceIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4InstanceIntersector4Chunk(); + intersectors.intersector8 = BVH4InstanceIntersector8Chunk(); + intersectors.intersector16 = BVH4InstanceIntersector16Chunk(); + intersectors.intersectorN = BVH4InstanceIntersectorStream(); +#endif + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4InstanceMBIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + 
intersectors.intersector1 = BVH4InstanceMBIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4InstanceMBIntersector4Chunk(); + intersectors.intersector8 = BVH4InstanceMBIntersector8Chunk(); + intersectors.intersector16 = BVH4InstanceMBIntersector16Chunk(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4SubdivPatch1Intersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4SubdivPatch1Intersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4SubdivPatch1Intersector4(); + intersectors.intersector8 = BVH4SubdivPatch1Intersector8(); + intersectors.intersector16 = BVH4SubdivPatch1Intersector16(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4SubdivPatch1MBIntersectors(BVH4* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4SubdivPatch1MBIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4SubdivPatch1MBIntersector4(); + intersectors.intersector8 = BVH4SubdivPatch1MBIntersector8(); + intersectors.intersector16 = BVH4SubdivPatch1MBIntersector16(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel* BVH4Factory::BVH4OBBVirtualCurve4i(Scene* scene, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Curve4i::type,scene); + Accel::Intersectors intersectors = BVH4OBBVirtualCurveIntersectors(accel,VirtualCurveIntersector4i(),ivariant); + + Builder* builder = nullptr; + if (scene->device->hair_builder == "default" ) builder = BVH4Curve4iBuilder_OBB_New(accel,scene,0); + else if (scene->device->hair_builder == "sah" ) builder = BVH4Curve4iBuilder_OBB_New(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->hair_builder+" for BVH4OBB"); + + return new AccelInstance(accel,builder,intersectors); + } + +#if defined(EMBREE_TARGET_SIMD8) + Accel* BVH4Factory::BVH4OBBVirtualCurve8i(Scene* scene, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Curve8i::type,scene); + Accel::Intersectors intersectors = BVH4OBBVirtualCurveIntersectors(accel,VirtualCurveIntersector8i(),ivariant); + + Builder* builder = nullptr; + if (scene->device->hair_builder == "default" ) builder = BVH4Curve8iBuilder_OBB_New(accel,scene,0); + else if (scene->device->hair_builder == "sah" ) builder = BVH4Curve8iBuilder_OBB_New(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->hair_builder+" for BVH4OBB"); + + return new AccelInstance(accel,builder,intersectors); + } +#endif + + Accel* BVH4Factory::BVH4OBBVirtualCurve4v(Scene* scene, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Curve4v::type,scene); + Accel::Intersectors intersectors = BVH4OBBVirtualCurveIntersectors(accel,VirtualCurveIntersector4v(),ivariant); + + Builder* builder = nullptr; + if (scene->device->hair_builder == "default" ) builder = BVH4Curve4vBuilder_OBB_New(accel,scene,0); + else if (scene->device->hair_builder == "sah" ) builder = BVH4Curve4vBuilder_OBB_New(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->hair_builder+" for BVH4OBB"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4OBBVirtualCurve4iMB(Scene* 
scene, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Curve4iMB::type,scene); + Accel::Intersectors intersectors = BVH4OBBVirtualCurveIntersectorsMB(accel,VirtualCurveIntersector4iMB(),ivariant); + + Builder* builder = nullptr; + if (scene->device->hair_builder == "default" ) builder = BVH4OBBCurve4iMBBuilder_OBB(accel,scene,0); + else if (scene->device->hair_builder == "sah" ) builder = BVH4OBBCurve4iMBBuilder_OBB(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->hair_builder+" for BVH4OBB"); + + return new AccelInstance(accel,builder,intersectors); + } + +#if defined(EMBREE_TARGET_SIMD8) + Accel* BVH4Factory::BVH4OBBVirtualCurve8iMB(Scene* scene, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Curve8iMB::type,scene); + Accel::Intersectors intersectors = BVH4OBBVirtualCurveIntersectorsMB(accel,VirtualCurveIntersector8iMB(), ivariant); + + Builder* builder = nullptr; + if (scene->device->hair_builder == "default" ) builder = BVH4OBBCurve8iMBBuilder_OBB(accel,scene,0); + else if (scene->device->hair_builder == "sah" ) builder = BVH4OBBCurve8iMBBuilder_OBB(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->hair_builder+" for BVH4OBB"); + + return new AccelInstance(accel,builder,intersectors); + } +#endif + + Accel* BVH4Factory::BVH4Triangle4(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Triangle4::type,scene); + + Accel::Intersectors intersectors; + if (scene->device->tri_traverser == "default") intersectors = BVH4Triangle4Intersectors(accel,ivariant); + else if (scene->device->tri_traverser == "fast" ) intersectors = BVH4Triangle4Intersectors(accel,IntersectVariant::FAST); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown traverser "+scene->device->tri_traverser+" for BVH4"); + + Builder* builder = nullptr; + if (scene->device->tri_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Triangle4SceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH4BuilderTwoLevelTriangle4MeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH4Triangle4SceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->tri_builder == "sah" ) builder = BVH4Triangle4SceneBuilderSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_fast_spatial" ) builder = BVH4Triangle4SceneBuilderFastSpatialSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_presplit") builder = BVH4Triangle4SceneBuilderSAH(accel,scene,MODE_HIGH_QUALITY); + else if (scene->device->tri_builder == "dynamic" ) builder = BVH4BuilderTwoLevelTriangle4MeshSAH(accel,scene,false); + else if (scene->device->tri_builder == "morton" ) builder = BVH4BuilderTwoLevelTriangle4MeshSAH(accel,scene,true); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Triangle4v(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Triangle4v::type,scene); + + Accel::Intersectors intersectors; + if (scene->device->tri_traverser == "default") intersectors = BVH4Triangle4vIntersectors(accel,ivariant); + else if (scene->device->tri_traverser == "fast" ) intersectors = BVH4Triangle4vIntersectors(accel,IntersectVariant::FAST); + else if (scene->device->tri_traverser == "robust" ) intersectors = 
BVH4Triangle4vIntersectors(accel,IntersectVariant::ROBUST); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown traverser "+scene->device->tri_traverser+" for BVH4"); + + Builder* builder = nullptr; + if (scene->device->tri_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Triangle4vSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH4BuilderTwoLevelTriangle4vMeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH4Triangle4vSceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->tri_builder == "sah" ) builder = BVH4Triangle4vSceneBuilderSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_fast_spatial" ) builder = BVH4Triangle4vSceneBuilderFastSpatialSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_presplit") builder = BVH4Triangle4vSceneBuilderSAH(accel,scene,MODE_HIGH_QUALITY); + else if (scene->device->tri_builder == "dynamic" ) builder = BVH4BuilderTwoLevelTriangle4vMeshSAH(accel,scene,false); + else if (scene->device->tri_builder == "morton" ) builder = BVH4BuilderTwoLevelTriangle4vMeshSAH(accel,scene,true); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Triangle4i(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Triangle4i::type,scene); + + Accel::Intersectors intersectors; + if (scene->device->tri_traverser == "default") intersectors = BVH4Triangle4iIntersectors(accel,ivariant); + else if (scene->device->tri_traverser == "fast" ) intersectors = BVH4Triangle4iIntersectors(accel,IntersectVariant::FAST); + else if (scene->device->tri_traverser == "robust" ) intersectors = BVH4Triangle4iIntersectors(accel,IntersectVariant::ROBUST); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown traverser "+scene->device->tri_traverser+" for BVH4"); + + Builder* builder = nullptr; + if (scene->device->tri_builder == "default" ) { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Triangle4iSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH4BuilderTwoLevelTriangle4iMeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH4Triangle4iSceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->tri_builder == "sah" ) builder = BVH4Triangle4iSceneBuilderSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_fast_spatial" ) builder = BVH4Triangle4iSceneBuilderFastSpatialSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_presplit") builder = BVH4Triangle4iSceneBuilderSAH(accel,scene,MODE_HIGH_QUALITY); + else if (scene->device->tri_builder == "dynamic" ) builder = BVH4BuilderTwoLevelTriangle4iMeshSAH(accel,scene,false); + else if (scene->device->tri_builder == "morton" ) builder = BVH4BuilderTwoLevelTriangle4iMeshSAH(accel,scene,true); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Triangle4iMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Triangle4i::type,scene); + + Accel::Intersectors intersectors; + if (scene->device->tri_traverser_mb == "default") intersectors = BVH4Triangle4iMBIntersectors(accel,ivariant); + else if 
(scene->device->tri_traverser_mb == "fast" ) intersectors = BVH4Triangle4iMBIntersectors(accel,IntersectVariant::FAST); + else if (scene->device->tri_traverser_mb == "robust" ) intersectors = BVH4Triangle4iMBIntersectors(accel,IntersectVariant::ROBUST); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown traverser "+scene->device->tri_traverser_mb+" for BVH4"); + + Builder* builder = nullptr; + if (scene->device->tri_builder_mb == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Triangle4iMBSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->tri_builder_mb == "internal_time_splits") builder = BVH4Triangle4iMBSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder_mb+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Triangle4vMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Triangle4vMB::type,scene); + + Accel::Intersectors intersectors; + if (scene->device->tri_traverser_mb == "default") intersectors = BVH4Triangle4vMBIntersectors(accel,ivariant); + else if (scene->device->tri_traverser_mb == "fast" ) intersectors = BVH4Triangle4vMBIntersectors(accel,IntersectVariant::FAST); + else if (scene->device->tri_traverser_mb == "robust" ) intersectors = BVH4Triangle4vMBIntersectors(accel,IntersectVariant::ROBUST); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown traverser "+scene->device->tri_traverser_mb+" for BVH4"); + + Builder* builder = nullptr; + if (scene->device->tri_builder_mb == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Triangle4vMBSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->tri_builder_mb == "internal_time_splits") builder = BVH4Triangle4vMBSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder_mb+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Quad4v(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Quad4v::type,scene); + Accel::Intersectors intersectors = BVH4Quad4vIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->quad_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Quad4vSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH4BuilderTwoLevelQuadMeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH4Quad4vSceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->quad_builder == "sah" ) builder = BVH4Quad4vSceneBuilderSAH(accel,scene,0); + else if (scene->device->quad_builder == "sah_fast_spatial" ) builder = BVH4Quad4vSceneBuilderFastSpatialSAH(accel,scene,0); + else if (scene->device->quad_builder == "dynamic" ) builder = BVH4BuilderTwoLevelQuadMeshSAH(accel,scene,false); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Quad4i(Scene* scene, BuildVariant 
bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Quad4i::type,scene); + Accel::Intersectors intersectors = BVH4Quad4iIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->quad_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Quad4iSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; // FIXME: implement + } + } + else if (scene->device->quad_builder == "sah") builder = BVH4Quad4iSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Quad4iMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(Quad4i::type,scene); + Accel::Intersectors intersectors = BVH4Quad4iMBIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->quad_builder_mb == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4Quad4iMBSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->quad_builder_mb == "sah") builder = BVH4Quad4iMBSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder_mb+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4QuantizedQuad4i(Scene* scene) + { + BVH4* accel = new BVH4(Quad4i::type,scene); + Builder* builder = BVH4QuantizedQuad4iSceneBuilderSAH(accel,scene,0); + Accel::Intersectors intersectors = QBVH4Quad4iIntersectors(accel); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4QuantizedTriangle4i(Scene* scene) + { + BVH4* accel = new BVH4(Triangle4i::type,scene); + Builder* builder = BVH4QuantizedTriangle4iSceneBuilderSAH(accel,scene,0); + Accel::Intersectors intersectors = QBVH4Triangle4iIntersectors(accel); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4SubdivPatch1(Scene* scene) + { + BVH4* accel = new BVH4(SubdivPatch1::type,scene); + Accel::Intersectors intersectors = BVH4SubdivPatch1Intersectors(accel); + Builder* builder = BVH4SubdivPatch1BuilderSAH(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4SubdivPatch1MB(Scene* scene) + { + BVH4* accel = new BVH4(SubdivPatch1::type,scene); + Accel::Intersectors intersectors = BVH4SubdivPatch1MBIntersectors(accel); + Builder* builder = BVH4SubdivPatch1MBBuilderSAH(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4UserGeometry(Scene* scene, BuildVariant bvariant) + { + BVH4* accel = new BVH4(Object::type,scene); + Accel::Intersectors intersectors = BVH4UserGeometryIntersectors(accel); + + Builder* builder = nullptr; + if (scene->device->object_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4VirtualSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH4BuilderTwoLevelVirtualSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->object_builder == "sah") builder = 
BVH4VirtualSceneBuilderSAH(accel,scene,0); + else if (scene->device->object_builder == "dynamic") builder = BVH4BuilderTwoLevelVirtualSAH(accel,scene,false); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->object_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4UserGeometryMB(Scene* scene) + { + BVH4* accel = new BVH4(Object::type,scene); + Accel::Intersectors intersectors = BVH4UserGeometryMBIntersectors(accel); + Builder* builder = BVH4VirtualMBSceneBuilderSAH(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4Instance(Scene* scene, bool isExpensive, BuildVariant bvariant) + { + BVH4* accel = new BVH4(InstancePrimitive::type,scene); + Accel::Intersectors intersectors = BVH4InstanceIntersectors(accel); + auto gtype = isExpensive ? Geometry::MTY_INSTANCE_EXPENSIVE : Geometry::MTY_INSTANCE_CHEAP; + // Builder* builder = BVH4InstanceSceneBuilderSAH(accel,scene,gtype); + + Builder* builder = nullptr; + if (scene->device->object_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH4InstanceSceneBuilderSAH(accel,scene,gtype); break; + case BuildVariant::DYNAMIC : builder = BVH4BuilderTwoLevelInstanceSAH(accel,scene,gtype,false); break; + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->object_builder == "sah") builder = BVH4InstanceSceneBuilderSAH(accel,scene,gtype); + else if (scene->device->object_builder == "dynamic") builder = BVH4BuilderTwoLevelInstanceSAH(accel,scene,gtype,false); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->object_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4InstanceMB(Scene* scene, bool isExpensive) + { + BVH4* accel = new BVH4(InstancePrimitive::type,scene); + Accel::Intersectors intersectors = BVH4InstanceMBIntersectors(accel); + auto gtype = isExpensive ? 
Geometry::MTY_INSTANCE_EXPENSIVE : Geometry::MTY_INSTANCE_CHEAP; + Builder* builder = BVH4InstanceMBSceneBuilderSAH(accel,scene,gtype); + return new AccelInstance(accel,builder,intersectors); + } + + Accel::Intersectors BVH4Factory::BVH4GridIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + if (ivariant == IntersectVariant::FAST) + { + intersectors.intersector1 = BVH4GridIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4GridIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4GridIntersector8HybridMoeller(); + intersectors.intersector16 = BVH4GridIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + } + else /* if (ivariant == IntersectVariant::ROBUST) */ + { + intersectors.intersector1 = BVH4GridIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4GridIntersector4HybridPluecker(); + intersectors.intersector8 = BVH4GridIntersector8HybridPluecker(); + intersectors.intersector16 = BVH4GridIntersector16HybridPluecker(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + } + return intersectors; + } + + Accel::Intersectors BVH4Factory::BVH4GridMBIntersectors(BVH4* bvh, IntersectVariant ivariant) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH4GridMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH4GridMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH4GridMBIntersector8HybridMoeller(); + intersectors.intersector16 = BVH4GridMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH4IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel* BVH4Factory::BVH4Grid(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(SubGridQBVH4::type,scene); + Accel::Intersectors intersectors = BVH4GridIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->object_builder == "default") { + builder = BVH4GridSceneBuilderSAH(accel,scene,0); + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->grid_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH4Factory::BVH4GridMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH4* accel = new BVH4(SubGridQBVH4::type,scene); + Accel::Intersectors intersectors = BVH4GridMBIntersectors(accel,ivariant); + Builder* builder = nullptr; + if (scene->device->object_builder == "default") { + builder = BVH4GridMBSceneBuilderSAH(accel,scene,0); + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->grid_builder+" for BVH4MB"); + return new AccelInstance(accel,builder,intersectors); + } + +} diff --git a/thirdparty/embree/kernels/bvh/bvh4_factory.h b/thirdparty/embree/kernels/bvh/bvh4_factory.h new file mode 100644 index 000000000000..a68227b41f80 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh4_factory.h @@ -0,0 +1,316 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_factory.h" + +namespace embree +{ + /*! 
BVH4 instantiations */ + class BVH4Factory : public BVHFactory + { + public: + BVH4Factory(int bfeatures, int ifeatures); + + public: + Accel* BVH4OBBVirtualCurve4i(Scene* scene, IntersectVariant ivariant); + Accel* BVH4OBBVirtualCurve4v(Scene* scene, IntersectVariant ivariant); + Accel* BVH4OBBVirtualCurve8i(Scene* scene, IntersectVariant ivariant); + Accel* BVH4OBBVirtualCurve4iMB(Scene* scene, IntersectVariant ivariant); + Accel* BVH4OBBVirtualCurve8iMB(Scene* scene, IntersectVariant ivariant); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector4i); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector8i); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector4v); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector8v); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector4iMB); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector8iMB); + + Accel* BVH4Triangle4 (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH4Triangle4v (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::ROBUST); + Accel* BVH4Triangle4i (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH4Triangle4vMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH4Triangle4iMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + + Accel* BVH4Quad4v (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH4Quad4i (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH4Quad4iMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + + Accel* BVH4QuantizedTriangle4i(Scene* scene); + Accel* BVH4QuantizedQuad4i(Scene* scene); + + Accel* BVH4SubdivPatch1(Scene* scene); + Accel* BVH4SubdivPatch1MB(Scene* scene); + + Accel* BVH4UserGeometry(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC); + Accel* BVH4UserGeometryMB(Scene* scene); + + Accel* BVH4Instance(Scene* scene, bool isExpensive, BuildVariant bvariant = BuildVariant::STATIC); + Accel* BVH4InstanceMB(Scene* scene, bool isExpensive); + + Accel* BVH4Grid(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH4GridMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + + private: + void selectBuilders(int features); + void selectIntersectors(int features); + + private: + Accel::Intersectors BVH4OBBVirtualCurveIntersectors(BVH4* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant); + Accel::Intersectors BVH4OBBVirtualCurveIntersectorsMB(BVH4* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant); + + Accel::Intersectors BVH4Triangle4Intersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4Triangle4vIntersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4Triangle4iIntersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4Triangle4iMBIntersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4Triangle4vMBIntersectors(BVH4* 
bvh, IntersectVariant ivariant); + + Accel::Intersectors BVH4Quad4vIntersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4Quad4iIntersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4Quad4iMBIntersectors(BVH4* bvh, IntersectVariant ivariant); + + Accel::Intersectors QBVH4Quad4iIntersectors(BVH4* bvh); + Accel::Intersectors QBVH4Triangle4iIntersectors(BVH4* bvh); + + Accel::Intersectors BVH4UserGeometryIntersectors(BVH4* bvh); + Accel::Intersectors BVH4UserGeometryMBIntersectors(BVH4* bvh); + + Accel::Intersectors BVH4InstanceIntersectors(BVH4* bvh); + Accel::Intersectors BVH4InstanceMBIntersectors(BVH4* bvh); + + Accel::Intersectors BVH4SubdivPatch1Intersectors(BVH4* bvh); + Accel::Intersectors BVH4SubdivPatch1MBIntersectors(BVH4* bvh); + + Accel::Intersectors BVH4GridIntersectors(BVH4* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH4GridMBIntersectors(BVH4* bvh, IntersectVariant ivariant); + + private: + + DEFINE_SYMBOL2(Accel::Collider,BVH4ColliderUserGeom); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersector1MB); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersectorRobust1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4OBBVirtualCurveIntersectorRobust1MB); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4Intersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4vIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4vMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4vMBIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Triangle4iMBIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Quad4vIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Quad4iIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Quad4vIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Quad4iIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Quad4iMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4Quad4iMBIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,QBVH4Triangle4iIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,QBVH4Quad4iIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4SubdivPatch1Intersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4SubdivPatch1MBIntersector1); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4VirtualIntersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4VirtualMBIntersector1); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4InstanceIntersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4InstanceMBIntersector1); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH4GridIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4GridMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH4GridIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersector4Hybrid); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersector4HybridMB); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersectorRobust4Hybrid); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4OBBVirtualCurveIntersectorRobust4HybridMB); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4Intersector4HybridMoeller); + 
DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4Intersector4HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4vIntersector4HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4vMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4vMBIntersector4HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Triangle4iMBIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4vIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4vIntersector4HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4iIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4vIntersector4HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4iIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4iMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4Quad4iMBIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4SubdivPatch1Intersector4); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4SubdivPatch1MBIntersector4); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4VirtualIntersector4Chunk); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4VirtualMBIntersector4Chunk); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4InstanceIntersector4Chunk); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4InstanceMBIntersector4Chunk); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH4GridIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4GridMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH4GridIntersector4HybridPluecker); + + // ============== + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersector8Hybrid); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersector8HybridMB); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersectorRobust8Hybrid); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4OBBVirtualCurveIntersectorRobust8HybridMB); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4Intersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4Intersector8HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4vIntersector8HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4vMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4vMBIntersector8HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Triangle4iMBIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4vIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4vIntersector8HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4iIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4vIntersector8HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4iIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4iMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4Quad4iMBIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4SubdivPatch1Intersector8); + 
DEFINE_SYMBOL2(Accel::Intersector8,BVH4SubdivPatch1MBIntersector8); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4VirtualIntersector8Chunk); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4VirtualMBIntersector8Chunk); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4InstanceIntersector8Chunk); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4InstanceMBIntersector8Chunk); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH4GridIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4GridMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH4GridIntersector8HybridPluecker); + + // ============== + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersector16Hybrid); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersector16HybridMB); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersectorRobust16Hybrid); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4OBBVirtualCurveIntersectorRobust16HybridMB); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4Intersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4Intersector16HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4vIntersector16HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4vMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4vMBIntersector16HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Triangle4iMBIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4vIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4vIntersector16HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4iIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4vIntersector16HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4iIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4iMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4Quad4iMBIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4SubdivPatch1Intersector16); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4SubdivPatch1MBIntersector16); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4VirtualIntersector16Chunk); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4VirtualMBIntersector16Chunk); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4InstanceIntersector16Chunk); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4InstanceMBIntersector16Chunk); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH4GridIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4GridMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH4GridIntersector16HybridPluecker); + + // ============== + + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4IntersectorStreamPacketFallback); + + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Triangle4IntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Triangle4IntersectorStreamMoellerNoFilter); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Triangle4iIntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Triangle4vIntersectorStreamPluecker); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Triangle4iIntersectorStreamPluecker); + + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Quad4vIntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN, 
BVH4Quad4vIntersectorStreamMoellerNoFilter); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Quad4iIntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Quad4vIntersectorStreamPluecker); + DEFINE_SYMBOL2(Accel::IntersectorN, BVH4Quad4iIntersectorStreamPluecker); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH4VirtualIntersectorStream); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH4InstanceIntersectorStream); + + // SAH scene builders + private: + DEFINE_ISA_FUNCTION(Builder*,BVH4Curve4vBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Curve4iBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4OBBCurve4iMBBuilder_OBB,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Curve8iBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4OBBCurve8iMBBuilder_OBB,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4SceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4vMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4QuantizedTriangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH4Quad4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Quad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Quad4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4QuantizedQuad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH4SubdivPatch1BuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4SubdivPatch1MBBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH4VirtualSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4VirtualMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH4InstanceSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + DEFINE_ISA_FUNCTION(Builder*,BVH4InstanceMBSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + + DEFINE_ISA_FUNCTION(Builder*,BVH4GridSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4GridMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + // spatial scene builder + private: + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4SceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Triangle4iSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH4Quad4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + + // twolevel scene builders + private: + DEFINE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelTriangle4MeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelTriangle4vMeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelTriangle4iMeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelQuadMeshSAH,void* COMMA Scene* COMMA bool); + 
DEFINE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelVirtualSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH4BuilderTwoLevelInstanceSAH,void* COMMA Scene* COMMA Geometry::GTypeMask COMMA bool); + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh8_factory.cpp b/thirdparty/embree/kernels/bvh/bvh8_factory.cpp new file mode 100644 index 000000000000..9fe057c39231 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh8_factory.cpp @@ -0,0 +1,1165 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "../common/isa.h" // to define EMBREE_TARGET_SIMD8 + +#if defined (EMBREE_TARGET_SIMD8) + +#include "bvh8_factory.h" +#include "../bvh/bvh.h" + +#include "../geometry/curveNv.h" +#include "../geometry/curveNi.h" +#include "../geometry/curveNi_mb.h" +#include "../geometry/linei.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglev_mb.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/subdivpatch1.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" +#include "../geometry/subgrid.h" +#include "../common/accelinstance.h" + +namespace embree +{ + DECLARE_SYMBOL2(Accel::Collider,BVH8ColliderUserGeom); + + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector8v,void); + DECLARE_ISA_FUNCTION(VirtualCurveIntersector*,VirtualCurveIntersector8iMB,void); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersector1MB); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersectorRobust1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersectorRobust1MB); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4Intersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vIntersector1Woop); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vMBIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iMBIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Quad4vIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Quad4iIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Quad4vIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Quad4iIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Quad4iMBIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8Quad4iMBIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,QBVH8Triangle4iIntersector1Pluecker); + DECLARE_SYMBOL2(Accel::Intersector1,QBVH8Triangle4Intersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,QBVH8Quad4iIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8VirtualIntersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8VirtualMBIntersector1); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8InstanceIntersector1); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8InstanceMBIntersector1); + + DECLARE_SYMBOL2(Accel::Intersector1,BVH8GridIntersector1Moeller); + DECLARE_SYMBOL2(Accel::Intersector1,BVH8GridMBIntersector1Moeller); + 
DECLARE_SYMBOL2(Accel::Intersector1,BVH8GridIntersector1Pluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersector4Hybrid); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersector4HybridMB); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersectorRobust4Hybrid); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersectorRobust4HybridMB); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4Intersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4Intersector4HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4vIntersector4HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4vMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4vMBIntersector4HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iMBIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4vIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4vIntersector4HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4iIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4vIntersector4HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4iIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4iMBIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8Quad4iMBIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8VirtualIntersector4Chunk); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8VirtualMBIntersector4Chunk); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8InstanceIntersector4Chunk); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8InstanceMBIntersector4Chunk); + + DECLARE_SYMBOL2(Accel::Intersector4,BVH8GridIntersector4HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector4,BVH8GridIntersector4HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersector8Hybrid); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersector8HybridMB); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersectorRobust8Hybrid); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersectorRobust8HybridMB); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4Intersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4Intersector8HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4vIntersector8HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4vMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4vMBIntersector8HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iMBIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4vIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4vIntersector8HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4iIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4vIntersector8HybridPluecker); + 
DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4iIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4iMBIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8Quad4iMBIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8VirtualIntersector8Chunk); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8VirtualMBIntersector8Chunk); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8InstanceIntersector8Chunk); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8InstanceMBIntersector8Chunk); + + DECLARE_SYMBOL2(Accel::Intersector8,BVH8GridIntersector8HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector8,BVH8GridIntersector8HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersector16Hybrid); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersector16HybridMB); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersectorRobust16Hybrid); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersectorRobust16HybridMB); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4Intersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4Intersector16HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4vIntersector16HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4vMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4vMBIntersector16HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iMBIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4vIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4vIntersector16HybridMoellerNoFilter); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4iIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4vIntersector16HybridPluecker); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4iIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4iMBIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8Quad4iMBIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8VirtualIntersector16Chunk); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8VirtualMBIntersector16Chunk); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8InstanceIntersector16Chunk); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8InstanceMBIntersector16Chunk); + + DECLARE_SYMBOL2(Accel::Intersector16,BVH8GridIntersector16HybridMoeller); + DECLARE_SYMBOL2(Accel::Intersector16,BVH8GridIntersector16HybridPluecker); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8IntersectorStreamPacketFallback); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4IntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4IntersectorStreamMoellerNoFilter); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4iIntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4vIntersectorStreamPluecker); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4iIntersectorStreamPluecker); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Quad4vIntersectorStreamMoeller); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Quad4vIntersectorStreamMoellerNoFilter); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Quad4iIntersectorStreamMoeller); + 
DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Quad4vIntersectorStreamPluecker); + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8Quad4iIntersectorStreamPluecker); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8VirtualIntersectorStream); + + DECLARE_SYMBOL2(Accel::IntersectorN,BVH8InstanceIntersectorStream); + + DECLARE_ISA_FUNCTION(Builder*,BVH8Curve8vBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8OBBCurve8iMBBuilder_OBB,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4SceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4vMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8QuantizedTriangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8QuantizedTriangle4SceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8QuantizedQuad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH8VirtualSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8VirtualMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH8InstanceSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + DECLARE_ISA_FUNCTION(Builder*,BVH8InstanceMBSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4SceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8GridSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8GridMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DECLARE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelTriangle4MeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelTriangle4vMeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelTriangle4iMeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelQuadMeshSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelVirtualSAH,void* COMMA Scene* COMMA bool); + DECLARE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelInstanceSAH,void* COMMA Scene* COMMA Geometry::GTypeMask COMMA bool); + + BVH8Factory::BVH8Factory(int bfeatures, int ifeatures) + { + SELECT_SYMBOL_INIT_AVX(ifeatures,BVH8ColliderUserGeom); + + selectBuilders(bfeatures); + selectIntersectors(ifeatures); + } + + void BVH8Factory::selectBuilders(int features) + { + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX(features,BVH8Curve8vBuilder_OBB_New)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX(features,BVH8OBBCurve8iMBBuilder_OBB)); + + 
IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4SceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4vSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4iSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4iMBSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4vMBSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX(features,BVH8QuantizedTriangle4iSceneBuilderSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX(features,BVH8QuantizedTriangle4SceneBuilderSAH)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Quad4vSceneBuilderSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Quad4iSceneBuilderSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Quad4iMBSceneBuilderSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX(features,BVH8QuantizedQuad4iSceneBuilderSAH)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX(features,BVH8VirtualSceneBuilderSAH)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX(features,BVH8VirtualMBSceneBuilderSAH)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX(features,BVH8InstanceSceneBuilderSAH)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX(features,BVH8InstanceMBSceneBuilderSAH)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX(features,BVH8GridSceneBuilderSAH)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX(features,BVH8GridMBSceneBuilderSAH)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4SceneBuilderFastSpatialSAH)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Triangle4vSceneBuilderFastSpatialSAH)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8Quad4vSceneBuilderFastSpatialSAH)); + + IF_ENABLED_TRIS (SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8BuilderTwoLevelTriangle4MeshSAH)); + IF_ENABLED_TRIS (SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8BuilderTwoLevelTriangle4vMeshSAH)); + IF_ENABLED_TRIS (SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8BuilderTwoLevelTriangle4iMeshSAH)); + IF_ENABLED_QUADS (SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8BuilderTwoLevelQuadMeshSAH)); + IF_ENABLED_USER (SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8BuilderTwoLevelVirtualSAH)); + IF_ENABLED_INSTANCE (SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,BVH8BuilderTwoLevelInstanceSAH)); + } + + void BVH8Factory::selectIntersectors(int features) + { + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector8v)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,VirtualCurveIntersector8iMB)); + + /* select intersectors1 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersector1)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersector1MB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust1)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust1MB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4Intersector1Moeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4iIntersector1Moeller)); + 
IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4vIntersector1Pluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4iIntersector1Pluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4vIntersector1Woop)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4vMBIntersector1Moeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4iMBIntersector1Moeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4vMBIntersector1Pluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4iMBIntersector1Pluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersector1Moeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4iIntersector1Moeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersector1Pluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4iIntersector1Pluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4iMBIntersector1Moeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4iMBIntersector1Pluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,QBVH8Triangle4iIntersector1Pluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,QBVH8Triangle4Intersector1Moeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,QBVH8Quad4iIntersector1Pluecker)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8VirtualIntersector1)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8VirtualMBIntersector1)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8InstanceIntersector1)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8InstanceMBIntersector1)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridIntersector1Moeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridMBIntersector1Moeller)) + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridIntersector1Pluecker)); + +#if defined (EMBREE_RAY_PACKETS) + + /* select intersectors4 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersector4Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersector4HybridMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust4Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust4HybridMB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4Intersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4Intersector4HybridMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iIntersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4vIntersector4HybridPluecker)); + 
IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iIntersector4HybridPluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4vMBIntersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iMBIntersector4HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4vMBIntersector4HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iMBIntersector4HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4vIntersector4HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4vIntersector4HybridMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4iIntersector4HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4vIntersector4HybridPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4iIntersector4HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2(features,BVH8Quad4iMBIntersector4HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2(features,BVH8Quad4iMBIntersector4HybridPluecker)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8VirtualIntersector4Chunk)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8VirtualMBIntersector4Chunk)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8InstanceIntersector4Chunk)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8InstanceMBIntersector4Chunk)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridIntersector4HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridIntersector4HybridPluecker)); + + /* select intersectors8 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersector8Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersector8HybridMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust8Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust8HybridMB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4Intersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4Intersector8HybridMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iIntersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4vIntersector8HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iIntersector8HybridPluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4vMBIntersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iMBIntersector8HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4vMBIntersector8HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Triangle4iMBIntersector8HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4vIntersector8HybridMoeller)); + 
IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4vIntersector8HybridMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4iIntersector8HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4vIntersector8HybridPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8Quad4iIntersector8HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2(features,BVH8Quad4iMBIntersector8HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2(features,BVH8Quad4iMBIntersector8HybridPluecker)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8VirtualIntersector8Chunk)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8VirtualMBIntersector8Chunk)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8InstanceIntersector8Chunk)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8InstanceMBIntersector8Chunk)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridIntersector8HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,BVH8GridIntersector8HybridPluecker)); + + /* select intersectors16 */ + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersector16Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersector16HybridMB)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust16Hybrid)); + IF_ENABLED_CURVES_OR_POINTS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8OBBVirtualCurveIntersectorRobust16HybridMB)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4Intersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4Intersector16HybridMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4iIntersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4vIntersector16HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4iIntersector16HybridPluecker)); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4vMBIntersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4iMBIntersector16HybridMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4vMBIntersector16HybridPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Triangle4iMBIntersector16HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersector16HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersector16HybridMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4iIntersector16HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersector16HybridPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4iIntersector16HybridPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4iMBIntersector16HybridMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8Quad4iMBIntersector16HybridPluecker)); + + 
IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8VirtualIntersector16Chunk)); + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8VirtualMBIntersector16Chunk)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8InstanceIntersector16Chunk)); + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8InstanceMBIntersector16Chunk)); + + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8GridIntersector16HybridMoeller)); + IF_ENABLED_GRIDS(SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,BVH8GridIntersector16HybridPluecker)); + + /* select stream intersectors */ + + SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8IntersectorStreamPacketFallback); + + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4IntersectorStreamMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4IntersectorStreamMoellerNoFilter)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4iIntersectorStreamMoeller)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4vIntersectorStreamPluecker)); + IF_ENABLED_TRIS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Triangle4iIntersectorStreamPluecker)); + + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersectorStreamMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersectorStreamMoellerNoFilter)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4iIntersectorStreamMoeller)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4vIntersectorStreamPluecker)); + IF_ENABLED_QUADS(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8Quad4iIntersectorStreamPluecker)); + + IF_ENABLED_USER(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8VirtualIntersectorStream)); + + IF_ENABLED_INSTANCE(SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,BVH8InstanceIntersectorStream)); + +#endif + } + + Accel::Intersectors BVH8Factory::BVH8OBBVirtualCurveIntersectors(BVH8* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH8OBBVirtualCurveIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8OBBVirtualCurveIntersector4Hybrid(); + intersectors.intersector8 = BVH8OBBVirtualCurveIntersector8Hybrid(); + intersectors.intersector16 = BVH8OBBVirtualCurveIntersector16Hybrid(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH8OBBVirtualCurveIntersectorRobust1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8OBBVirtualCurveIntersectorRobust4Hybrid(); + intersectors.intersector8 = BVH8OBBVirtualCurveIntersectorRobust8Hybrid(); + intersectors.intersector16 = BVH8OBBVirtualCurveIntersectorRobust16Hybrid(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + default: assert(false); + } + return 
Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8OBBVirtualCurveIntersectorsMB(BVH8* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH8OBBVirtualCurveIntersector1MB(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8OBBVirtualCurveIntersector4HybridMB(); + intersectors.intersector8 = BVH8OBBVirtualCurveIntersector8HybridMB(); + intersectors.intersector16 = BVH8OBBVirtualCurveIntersector16HybridMB(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.leafIntersector = leafIntersector; + intersectors.intersector1 = BVH8OBBVirtualCurveIntersectorRobust1MB(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8OBBVirtualCurveIntersectorRobust4HybridMB(); + intersectors.intersector8 = BVH8OBBVirtualCurveIntersectorRobust8HybridMB(); + intersectors.intersector16 = BVH8OBBVirtualCurveIntersectorRobust16HybridMB(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + default: assert(false); + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8Triangle4Intersectors(BVH8* bvh, IntersectVariant ivariant) + { + assert(ivariant == IntersectVariant::FAST); + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4Intersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4_filter = BVH8Triangle4Intersector4HybridMoeller(); + intersectors.intersector4_nofilter = BVH8Triangle4Intersector4HybridMoellerNoFilter(); + intersectors.intersector8_filter = BVH8Triangle4Intersector8HybridMoeller(); + intersectors.intersector8_nofilter = BVH8Triangle4Intersector8HybridMoellerNoFilter(); + intersectors.intersector16_filter = BVH8Triangle4Intersector16HybridMoeller(); + intersectors.intersector16_nofilter = BVH8Triangle4Intersector16HybridMoellerNoFilter(); + intersectors.intersectorN_filter = BVH8Triangle4IntersectorStreamMoeller(); + intersectors.intersectorN_nofilter = BVH8Triangle4IntersectorStreamMoellerNoFilter(); +#endif + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8Triangle4vIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; +#define ENABLE_WOOP_TEST 0 +#if ENABLE_WOOP_TEST == 0 + //assert(ivariant == IntersectVariant::ROBUST); + intersectors.intersector1 = BVH8Triangle4vIntersector1Pluecker(); +#else + intersectors.intersector1 = BVH8Triangle4vIntersector1Woop(); +#endif + +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4vIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8Triangle4vIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Triangle4vIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8Triangle4vIntersectorStreamPluecker(); +#endif + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8Triangle4iIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4iIntersector1Moeller(); +#if 
defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4iIntersector4HybridMoeller(); + intersectors.intersector8 = BVH8Triangle4iIntersector8HybridMoeller(); + intersectors.intersector16 = BVH8Triangle4iIntersector16HybridMoeller(); + intersectors.intersectorN = BVH8Triangle4iIntersectorStreamMoeller(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4iIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4iIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8Triangle4iIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Triangle4iIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8Triangle4iIntersectorStreamPluecker(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8Triangle4vMBIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4vMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4vMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH8Triangle4vMBIntersector8HybridMoeller(); + intersectors.intersector16 = BVH8Triangle4vMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4vMBIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4vMBIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8Triangle4vMBIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Triangle4vMBIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8Triangle4iMBIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4iMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4iMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH8Triangle4iMBIntersector8HybridMoeller(); + intersectors.intersector16 = BVH8Triangle4iMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Triangle4iMBIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Triangle4iMBIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8Triangle4iMBIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Triangle4iMBIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8Quad4vIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case 
IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Quad4vIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4_filter = BVH8Quad4vIntersector4HybridMoeller(); + intersectors.intersector4_nofilter = BVH8Quad4vIntersector4HybridMoellerNoFilter(); + intersectors.intersector8_filter = BVH8Quad4vIntersector8HybridMoeller(); + intersectors.intersector8_nofilter = BVH8Quad4vIntersector8HybridMoellerNoFilter(); + intersectors.intersector16_filter = BVH8Quad4vIntersector16HybridMoeller(); + intersectors.intersector16_nofilter = BVH8Quad4vIntersector16HybridMoellerNoFilter(); + intersectors.intersectorN_filter = BVH8Quad4vIntersectorStreamMoeller(); + intersectors.intersectorN_nofilter = BVH8Quad4vIntersectorStreamMoellerNoFilter(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Quad4vIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Quad4vIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8Quad4vIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Quad4vIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8Quad4vIntersectorStreamPluecker(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8Quad4iIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Quad4iIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Quad4iIntersector4HybridMoeller(); + intersectors.intersector8 = BVH8Quad4iIntersector8HybridMoeller(); + intersectors.intersector16 = BVH8Quad4iIntersector16HybridMoeller(); + intersectors.intersectorN = BVH8Quad4iIntersectorStreamMoeller(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Quad4iIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Quad4iIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8Quad4iIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Quad4iIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8Quad4iIntersectorStreamPluecker(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::BVH8Quad4iMBIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + switch (ivariant) { + case IntersectVariant::FAST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Quad4iMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Quad4iMBIntersector4HybridMoeller(); + intersectors.intersector8 = BVH8Quad4iMBIntersector8HybridMoeller(); + intersectors.intersector16 = BVH8Quad4iMBIntersector16HybridMoeller(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + case IntersectVariant::ROBUST: + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8Quad4iMBIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8Quad4iMBIntersector4HybridPluecker(); + intersectors.intersector8 = 
BVH8Quad4iMBIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8Quad4iMBIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + } + return Accel::Intersectors(); + } + + Accel::Intersectors BVH8Factory::QBVH8Triangle4iIntersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = QBVH8Triangle4iIntersector1Pluecker(); + return intersectors; + } + + Accel::Intersectors BVH8Factory::QBVH8Triangle4Intersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = QBVH8Triangle4Intersector1Moeller(); + return intersectors; + } + + Accel::Intersectors BVH8Factory::QBVH8Quad4iIntersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = QBVH8Quad4iIntersector1Pluecker(); + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8UserGeometryIntersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8VirtualIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8VirtualIntersector4Chunk(); + intersectors.intersector8 = BVH8VirtualIntersector8Chunk(); + intersectors.intersector16 = BVH8VirtualIntersector16Chunk(); + intersectors.intersectorN = BVH8VirtualIntersectorStream(); +#endif + intersectors.collider = BVH8ColliderUserGeom(); + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8UserGeometryMBIntersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8VirtualMBIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8VirtualMBIntersector4Chunk(); + intersectors.intersector8 = BVH8VirtualMBIntersector8Chunk(); + intersectors.intersector16 = BVH8VirtualMBIntersector16Chunk(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8InstanceIntersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8InstanceIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8InstanceIntersector4Chunk(); + intersectors.intersector8 = BVH8InstanceIntersector8Chunk(); + intersectors.intersector16 = BVH8InstanceIntersector16Chunk(); + intersectors.intersectorN = BVH8InstanceIntersectorStream(); +#endif + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8InstanceMBIntersectors(BVH8* bvh) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8InstanceMBIntersector1(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8InstanceMBIntersector4Chunk(); + intersectors.intersector8 = BVH8InstanceMBIntersector8Chunk(); + intersectors.intersector16 = BVH8InstanceMBIntersector16Chunk(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + return intersectors; + } + + Accel* BVH8Factory::BVH8OBBVirtualCurve8v(Scene* scene, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Curve8v::type,scene); + Accel::Intersectors intersectors = BVH8OBBVirtualCurveIntersectors(accel,VirtualCurveIntersector8v(),ivariant); + Builder* builder = BVH8Curve8vBuilder_OBB_New(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8OBBVirtualCurve8iMB(Scene* 
scene, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Curve8iMB::type,scene); + Accel::Intersectors intersectors = BVH8OBBVirtualCurveIntersectorsMB(accel,VirtualCurveIntersector8iMB(),ivariant); + Builder* builder = BVH8OBBCurve8iMBBuilder_OBB(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Triangle4(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Triangle4::type,scene); + Accel::Intersectors intersectors= BVH8Triangle4Intersectors(accel,ivariant); + Builder* builder = nullptr; + if (scene->device->tri_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Triangle4SceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH8BuilderTwoLevelTriangle4MeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH8Triangle4SceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->tri_builder == "sah" ) builder = BVH8Triangle4SceneBuilderSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_fast_spatial") builder = BVH8Triangle4SceneBuilderFastSpatialSAH(accel,scene,0); + else if (scene->device->tri_builder == "sah_presplit") builder = BVH8Triangle4SceneBuilderSAH(accel,scene,MODE_HIGH_QUALITY); + else if (scene->device->tri_builder == "dynamic" ) builder = BVH8BuilderTwoLevelTriangle4MeshSAH(accel,scene,false); + else if (scene->device->tri_builder == "morton" ) builder = BVH8BuilderTwoLevelTriangle4MeshSAH(accel,scene,true); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Triangle4v(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Triangle4v::type,scene); + Accel::Intersectors intersectors= BVH8Triangle4vIntersectors(accel,ivariant); + Builder* builder = nullptr; + if (scene->device->tri_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Triangle4vSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH8BuilderTwoLevelTriangle4vMeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH8Triangle4vSceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->tri_builder == "sah_fast_spatial") builder = BVH8Triangle4SceneBuilderFastSpatialSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder+" for BVH8"); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Triangle4i(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Triangle4i::type,scene); + Accel::Intersectors intersectors = BVH8Triangle4iIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->tri_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Triangle4iSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH8BuilderTwoLevelTriangle4iMeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: assert(false); break; // FIXME: implement + } + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Triangle4iMB(Scene* scene, BuildVariant 
bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Triangle4i::type,scene); + Accel::Intersectors intersectors = BVH8Triangle4iMBIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->tri_builder_mb == "default") { // FIXME: implement + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Triangle4iMBSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->tri_builder_mb == "internal_time_splits") builder = BVH8Triangle4iMBSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder_mb+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Triangle4vMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Triangle4vMB::type,scene); + Accel::Intersectors intersectors= BVH8Triangle4vMBIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->tri_builder_mb == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Triangle4vMBSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->tri_builder_mb == "internal_time_splits") builder = BVH8Triangle4vMBSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->tri_builder_mb+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8QuantizedTriangle4i(Scene* scene) + { + BVH8* accel = new BVH8(Triangle4i::type,scene); + Accel::Intersectors intersectors = QBVH8Triangle4iIntersectors(accel); + Builder* builder = BVH8QuantizedTriangle4iSceneBuilderSAH(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8QuantizedTriangle4(Scene* scene) + { + BVH8* accel = new BVH8(Triangle4::type,scene); + Accel::Intersectors intersectors = QBVH8Triangle4Intersectors(accel); + Builder* builder = BVH8QuantizedTriangle4SceneBuilderSAH(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Quad4v(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Quad4v::type,scene); + Accel::Intersectors intersectors = BVH8Quad4vIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->quad_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Quad4vSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH8BuilderTwoLevelQuadMeshSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: builder = BVH8Quad4vSceneBuilderFastSpatialSAH(accel,scene,0); break; + } + } + else if (scene->device->quad_builder == "dynamic" ) builder = BVH8BuilderTwoLevelQuadMeshSAH(accel,scene,false); + else if (scene->device->quad_builder == "morton" ) builder = BVH8BuilderTwoLevelQuadMeshSAH(accel,scene,true); + else if (scene->device->quad_builder == "sah_fast_spatial" ) builder = BVH8Quad4vSceneBuilderFastSpatialSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* 
BVH8Factory::BVH8Quad4i(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Quad4i::type,scene); + Accel::Intersectors intersectors = BVH8Quad4iIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->quad_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Quad4iSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; // FIXME: implement + } + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Quad4iMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(Quad4i::type,scene); + Accel::Intersectors intersectors = BVH8Quad4iMBIntersectors(accel,ivariant); + + Builder* builder = nullptr; + if (scene->device->quad_builder_mb == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8Quad4iMBSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : assert(false); break; // FIXME: implement + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder_mb+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8QuantizedQuad4i(Scene* scene) + { + BVH8* accel = new BVH8(Quad4i::type,scene); + Accel::Intersectors intersectors = QBVH8Quad4iIntersectors(accel); + Builder* builder = nullptr; + if (scene->device->quad_builder == "default" ) builder = BVH8QuantizedQuad4iSceneBuilderSAH(accel,scene,0); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->quad_builder+" for QBVH8"); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8UserGeometry(Scene* scene, BuildVariant bvariant) + { + BVH8* accel = new BVH8(Object::type,scene); + Accel::Intersectors intersectors = BVH8UserGeometryIntersectors(accel); + + Builder* builder = nullptr; + if (scene->device->object_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8VirtualSceneBuilderSAH(accel,scene,0); break; + case BuildVariant::DYNAMIC : builder = BVH8BuilderTwoLevelVirtualSAH(accel,scene,false); break; + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->object_builder == "sah") builder = BVH8VirtualSceneBuilderSAH(accel,scene,0); + else if (scene->device->object_builder == "dynamic") builder = BVH8BuilderTwoLevelVirtualSAH(accel,scene,false); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->object_builder+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8UserGeometryMB(Scene* scene) + { + BVH8* accel = new BVH8(Object::type,scene); + Accel::Intersectors intersectors = BVH8UserGeometryMBIntersectors(accel); + Builder* builder = BVH8VirtualMBSceneBuilderSAH(accel,scene,0); + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8Instance(Scene* scene, bool isExpensive, BuildVariant bvariant) + { + BVH8* accel = new BVH8(InstancePrimitive::type,scene); + Accel::Intersectors intersectors = BVH8InstanceIntersectors(accel); + auto gtype = isExpensive ? 
Geometry::MTY_INSTANCE_EXPENSIVE : Geometry::MTY_INSTANCE; + // Builder* builder = BVH8InstanceSceneBuilderSAH(accel,scene,gtype); + + Builder* builder = nullptr; + if (scene->device->object_builder == "default") { + switch (bvariant) { + case BuildVariant::STATIC : builder = BVH8InstanceSceneBuilderSAH(accel,scene,gtype);; break; + case BuildVariant::DYNAMIC : builder = BVH8BuilderTwoLevelInstanceSAH(accel,scene,gtype,false); break; + case BuildVariant::HIGH_QUALITY: assert(false); break; + } + } + else if (scene->device->object_builder == "sah") builder = BVH8InstanceSceneBuilderSAH(accel,scene,gtype); + else if (scene->device->object_builder == "dynamic") builder = BVH8BuilderTwoLevelInstanceSAH(accel,scene,gtype,false); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->object_builder+" for BVH8"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8InstanceMB(Scene* scene, bool isExpensive) + { + BVH8* accel = new BVH8(InstancePrimitive::type,scene); + Accel::Intersectors intersectors = BVH8InstanceMBIntersectors(accel); + auto gtype = isExpensive ? Geometry::MTY_INSTANCE_EXPENSIVE : Geometry::MTY_INSTANCE; + Builder* builder = BVH8InstanceMBSceneBuilderSAH(accel,scene,gtype); + return new AccelInstance(accel,builder,intersectors); + } + + Accel::Intersectors BVH8Factory::BVH8GridIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + if (ivariant == IntersectVariant::FAST) + { + intersectors.intersector1 = BVH8GridIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8GridIntersector4HybridMoeller(); + intersectors.intersector8 = BVH8GridIntersector8HybridMoeller(); + intersectors.intersector16 = BVH8GridIntersector16HybridMoeller(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + } + else /* if (ivariant == IntersectVariant::ROBUST) */ + { + intersectors.intersector1 = BVH8GridIntersector1Pluecker(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = BVH8GridIntersector4HybridPluecker(); + intersectors.intersector8 = BVH8GridIntersector8HybridPluecker(); + intersectors.intersector16 = BVH8GridIntersector16HybridPluecker(); + intersectors.intersectorN = BVH8IntersectorStreamPacketFallback(); +#endif + } + return intersectors; + } + + Accel::Intersectors BVH8Factory::BVH8GridMBIntersectors(BVH8* bvh, IntersectVariant ivariant) + { + Accel::Intersectors intersectors; + intersectors.ptr = bvh; + intersectors.intersector1 = BVH8GridMBIntersector1Moeller(); +#if defined (EMBREE_RAY_PACKETS) + intersectors.intersector4 = nullptr; + intersectors.intersector8 = nullptr; + intersectors.intersector16 = nullptr; + intersectors.intersectorN = nullptr; +#endif + return intersectors; + } + + Accel* BVH8Factory::BVH8Grid(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(SubGridQBVH8::type,scene); + Accel::Intersectors intersectors = BVH8GridIntersectors(accel,ivariant); + Builder* builder = nullptr; + if (scene->device->grid_builder == "default") { + builder = BVH8GridSceneBuilderSAH(accel,scene,0); + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->object_builder+" for BVH4"); + + return new AccelInstance(accel,builder,intersectors); + } + + Accel* BVH8Factory::BVH8GridMB(Scene* scene, BuildVariant bvariant, IntersectVariant ivariant) + { + BVH8* accel = new BVH8(SubGridQBVH8::type,scene); + Accel::Intersectors 
intersectors = BVH8GridMBIntersectors(accel,ivariant); + Builder* builder = nullptr; + if (scene->device->grid_builder_mb == "default") { + builder = BVH8GridMBSceneBuilderSAH(accel,scene,0); + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown builder "+scene->device->object_builder+" for BVH8MB"); + return new AccelInstance(accel,builder,intersectors); + } +} + +#endif diff --git a/thirdparty/embree/kernels/bvh/bvh8_factory.h b/thirdparty/embree/kernels/bvh/bvh8_factory.h new file mode 100644 index 000000000000..b92188e7d356 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh8_factory.h @@ -0,0 +1,280 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_factory.h" + +namespace embree +{ + /*! BVH8 instantiations */ + class BVH8Factory : public BVHFactory + { + public: + BVH8Factory(int bfeatures, int ifeatures); + + public: + Accel* BVH8OBBVirtualCurve8v(Scene* scene, IntersectVariant ivariant); + Accel* BVH8OBBVirtualCurve8iMB(Scene* scene, IntersectVariant ivariant); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector8v); + DEFINE_SYMBOL2(VirtualCurveIntersector*,VirtualCurveIntersector8iMB); + + Accel* BVH8Triangle4 (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8Triangle4v (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8Triangle4i (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8Triangle4vMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8Triangle4iMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + + Accel* BVH8Quad4v (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8Quad4i (Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8Quad4iMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + + Accel* BVH8QuantizedTriangle4i(Scene* scene); + Accel* BVH8QuantizedTriangle4(Scene* scene); + Accel* BVH8QuantizedQuad4i(Scene* scene); + + Accel* BVH8UserGeometry(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC); + Accel* BVH8UserGeometryMB(Scene* scene); + + Accel* BVH8Instance(Scene* scene, bool isExpensive, BuildVariant bvariant = BuildVariant::STATIC); + Accel* BVH8InstanceMB(Scene* scene, bool isExpensive); + + Accel* BVH8Grid(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + Accel* BVH8GridMB(Scene* scene, BuildVariant bvariant = BuildVariant::STATIC, IntersectVariant ivariant = IntersectVariant::FAST); + + private: + void selectBuilders(int features); + void selectIntersectors(int features); + + private: + Accel::Intersectors BVH8OBBVirtualCurveIntersectors(BVH8* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant); + Accel::Intersectors BVH8OBBVirtualCurveIntersectorsMB(BVH8* bvh, VirtualCurveIntersector* leafIntersector, IntersectVariant ivariant); + + Accel::Intersectors BVH8Triangle4Intersectors(BVH8* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH8Triangle4vIntersectors(BVH8* bvh, IntersectVariant ivariant); 
+ Accel::Intersectors BVH8Triangle4iIntersectors(BVH8* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH8Triangle4iMBIntersectors(BVH8* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH8Triangle4vMBIntersectors(BVH8* bvh, IntersectVariant ivariant); + + Accel::Intersectors BVH8Quad4vIntersectors(BVH8* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH8Quad4iIntersectors(BVH8* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH8Quad4iMBIntersectors(BVH8* bvh, IntersectVariant ivariant); + + Accel::Intersectors QBVH8Triangle4iIntersectors(BVH8* bvh); + Accel::Intersectors QBVH8Triangle4Intersectors(BVH8* bvh); + Accel::Intersectors QBVH8Quad4iIntersectors(BVH8* bvh); + + Accel::Intersectors BVH8UserGeometryIntersectors(BVH8* bvh); + Accel::Intersectors BVH8UserGeometryMBIntersectors(BVH8* bvh); + + Accel::Intersectors BVH8InstanceIntersectors(BVH8* bvh); + Accel::Intersectors BVH8InstanceMBIntersectors(BVH8* bvh); + + Accel::Intersectors BVH8GridIntersectors(BVH8* bvh, IntersectVariant ivariant); + Accel::Intersectors BVH8GridMBIntersectors(BVH8* bvh, IntersectVariant ivariant); + + private: + DEFINE_SYMBOL2(Accel::Collider,BVH8ColliderUserGeom); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersector1MB); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersectorRobust1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8OBBVirtualCurveIntersectorRobust1MB); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4Intersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vMBIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4iMBIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Triangle4vIntersector1Woop); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Quad4vIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Quad4iIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Quad4vIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Quad4iIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Quad4iMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8Quad4iMBIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,QBVH8Triangle4iIntersector1Pluecker); + DEFINE_SYMBOL2(Accel::Intersector1,QBVH8Triangle4Intersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,QBVH8Quad4iIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8VirtualIntersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8VirtualMBIntersector1); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8InstanceIntersector1); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8InstanceMBIntersector1); + + DEFINE_SYMBOL2(Accel::Intersector1,BVH8GridIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8GridMBIntersector1Moeller); + DEFINE_SYMBOL2(Accel::Intersector1,BVH8GridIntersector1Pluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersector4Hybrid); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersector4HybridMB); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersectorRobust4Hybrid); + 
DEFINE_SYMBOL2(Accel::Intersector4,BVH8OBBVirtualCurveIntersectorRobust4HybridMB); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4Intersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4Intersector4HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4vIntersector4HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4vMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4vMBIntersector4HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Triangle4iMBIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4vIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4vIntersector4HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4iIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4vIntersector4HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4iIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4iMBIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8Quad4iMBIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8VirtualIntersector4Chunk); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8VirtualMBIntersector4Chunk); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8InstanceIntersector4Chunk); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8InstanceMBIntersector4Chunk); + + DEFINE_SYMBOL2(Accel::Intersector4,BVH8GridIntersector4HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector4,BVH8GridIntersector4HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersector8Hybrid); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersector8HybridMB); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersectorRobust8Hybrid); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8OBBVirtualCurveIntersectorRobust8HybridMB); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4Intersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4Intersector8HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4vIntersector8HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4vMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4vMBIntersector8HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Triangle4iMBIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4vIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4vIntersector8HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4iIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4vIntersector8HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4iIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4iMBIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8Quad4iMBIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8VirtualIntersector8Chunk); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8VirtualMBIntersector8Chunk); + + 
DEFINE_SYMBOL2(Accel::Intersector8,BVH8InstanceIntersector8Chunk); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8InstanceMBIntersector8Chunk); + + DEFINE_SYMBOL2(Accel::Intersector8,BVH8GridIntersector8HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector8,BVH8GridIntersector8HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersector16Hybrid); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersector16HybridMB); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersectorRobust16Hybrid); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8OBBVirtualCurveIntersectorRobust16HybridMB); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4Intersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4Intersector16HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4vIntersector16HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4vMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4vMBIntersector16HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Triangle4iMBIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4vIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4vIntersector16HybridMoellerNoFilter); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4iIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4vIntersector16HybridPluecker); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4iIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4iMBIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8Quad4iMBIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8VirtualIntersector16Chunk); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8VirtualMBIntersector16Chunk); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8InstanceIntersector16Chunk); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8InstanceMBIntersector16Chunk); + + DEFINE_SYMBOL2(Accel::Intersector16,BVH8GridIntersector16HybridMoeller); + DEFINE_SYMBOL2(Accel::Intersector16,BVH8GridIntersector16HybridPluecker); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8IntersectorStreamPacketFallback); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4IntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4IntersectorStreamMoellerNoFilter); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4iIntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4vIntersectorStreamPluecker); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Triangle4iIntersectorStreamPluecker); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Quad4vIntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Quad4vIntersectorStreamMoellerNoFilter); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Quad4iIntersectorStreamMoeller); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Quad4vIntersectorStreamPluecker); + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8Quad4iIntersectorStreamPluecker); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8VirtualIntersectorStream); + + DEFINE_SYMBOL2(Accel::IntersectorN,BVH8InstanceIntersectorStream); + + // SAH scene builders + private: + DEFINE_ISA_FUNCTION(Builder*,BVH8Curve8vBuilder_OBB_New,void* COMMA Scene* COMMA size_t); + 
DEFINE_ISA_FUNCTION(Builder*,BVH8OBBCurve8iMBBuilder_OBB,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4SceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4vMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8QuantizedTriangle4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8QuantizedTriangle4SceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH8Quad4vSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Quad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Quad4iMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8QuantizedQuad4iSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH8VirtualSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8VirtualMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + DEFINE_ISA_FUNCTION(Builder*,BVH8InstanceSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + DEFINE_ISA_FUNCTION(Builder*,BVH8InstanceMBSceneBuilderSAH,void* COMMA Scene* COMMA Geometry::GTypeMask); + + DEFINE_ISA_FUNCTION(Builder*,BVH8GridSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8GridMBSceneBuilderSAH,void* COMMA Scene* COMMA size_t); + + // SAH spatial scene builders + private: + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4SceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Triangle4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + DEFINE_ISA_FUNCTION(Builder*,BVH8Quad4vSceneBuilderFastSpatialSAH,void* COMMA Scene* COMMA size_t); + + // twolevel scene builders + private: + DEFINE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelTriangle4MeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelTriangle4vMeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelTriangle4iMeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelQuadMeshSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelVirtualSAH,void* COMMA Scene* COMMA bool); + DEFINE_ISA_FUNCTION(Builder*,BVH8BuilderTwoLevelInstanceSAH,void* COMMA Scene* COMMA Geometry::GTypeMask COMMA bool); + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder.cpp b/thirdparty/embree/kernels/bvh/bvh_builder.cpp new file mode 100644 index 000000000000..e832537ec570 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder.cpp @@ -0,0 +1,60 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_builder.h" + +namespace embree +{ + namespace isa + { + template + typename BVHN::NodeRef BVHNBuilderVirtual::BVHNBuilderV::build(FastAllocator* allocator, BuildProgressMonitor& progressFunc, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings) + { + auto createLeafFunc = [&] (const PrimRef* prims, const range& set, const Allocator& alloc) -> NodeRef { + return createLeaf(prims,set,alloc); + }; + + settings.branchingFactor = N; 
+ settings.maxDepth = BVH::maxBuildDepthLeaf; + return BVHBuilderBinnedSAH::build + (FastAllocator::Create(allocator),typename BVH::AABBNode::Create2(),typename BVH::AABBNode::Set3(allocator,prims),createLeafFunc,progressFunc,prims,pinfo,settings); + } + + + template + typename BVHN::NodeRef BVHNBuilderQuantizedVirtual::BVHNBuilderV::build(FastAllocator* allocator, BuildProgressMonitor& progressFunc, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings) + { + auto createLeafFunc = [&] (const PrimRef* prims, const range& set, const Allocator& alloc) -> NodeRef { + return createLeaf(prims,set,alloc); + }; + + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + return BVHBuilderBinnedSAH::build + (FastAllocator::Create(allocator),typename BVH::QuantizedNode::Create2(),typename BVH::QuantizedNode::Set2(),createLeafFunc,progressFunc,prims,pinfo,settings); + } + + template + typename BVHN::NodeRecordMB BVHNBuilderMblurVirtual::BVHNBuilderV::build(FastAllocator* allocator, BuildProgressMonitor& progressFunc, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings, const BBox1f& timeRange) + { + auto createLeafFunc = [&] (const PrimRef* prims, const range& set, const Allocator& alloc) -> NodeRecordMB { + return createLeaf(prims,set,alloc); + }; + + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + return BVHBuilderBinnedSAH::build + (FastAllocator::Create(allocator),typename BVH::AABBNodeMB::Create(),typename BVH::AABBNodeMB::SetTimeRange(timeRange),createLeafFunc,progressFunc,prims,pinfo,settings); + } + + template struct BVHNBuilderVirtual<4>; + template struct BVHNBuilderQuantizedVirtual<4>; + template struct BVHNBuilderMblurVirtual<4>; + +#if defined(__AVX__) + template struct BVHNBuilderVirtual<8>; + template struct BVHNBuilderQuantizedVirtual<8>; + template struct BVHNBuilderMblurVirtual<8>; +#endif + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder.h b/thirdparty/embree/kernels/bvh/bvh_builder.h new file mode 100644 index 000000000000..1b86bb45adac --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder.h @@ -0,0 +1,114 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh.h" +#include "../builders/bvh_builder_sah.h" + +namespace embree +{ + namespace isa + { + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + template + struct BVHNBuilderVirtual + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef FastAllocator::CachedAllocator Allocator; + + struct BVHNBuilderV { + NodeRef build(FastAllocator* allocator, BuildProgressMonitor& progress, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings); + virtual NodeRef createLeaf (const PrimRef* prims, const range& set, const Allocator& alloc) = 0; + }; + + template + struct BVHNBuilderT : public BVHNBuilderV + { + BVHNBuilderT (CreateLeafFunc createLeafFunc) + : createLeafFunc(createLeafFunc) {} + + NodeRef createLeaf (const PrimRef* prims, const range& set, const Allocator& alloc) { + return createLeafFunc(prims,set,alloc); + } + + private: + CreateLeafFunc createLeafFunc; + }; + + template + static NodeRef build(FastAllocator* allocator, 
CreateLeafFunc createLeaf, BuildProgressMonitor& progress, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings) { + return BVHNBuilderT(createLeaf).build(allocator,progress,prims,pinfo,settings); + } + }; + + template + struct BVHNBuilderQuantizedVirtual + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef FastAllocator::CachedAllocator Allocator; + + struct BVHNBuilderV { + NodeRef build(FastAllocator* allocator, BuildProgressMonitor& progress, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings); + virtual NodeRef createLeaf (const PrimRef* prims, const range& set, const Allocator& alloc) = 0; + }; + + template + struct BVHNBuilderT : public BVHNBuilderV + { + BVHNBuilderT (CreateLeafFunc createLeafFunc) + : createLeafFunc(createLeafFunc) {} + + NodeRef createLeaf (const PrimRef* prims, const range& set, const Allocator& alloc) { + return createLeafFunc(prims,set,alloc); + } + + private: + CreateLeafFunc createLeafFunc; + }; + + template + static NodeRef build(FastAllocator* allocator, CreateLeafFunc createLeaf, BuildProgressMonitor& progress, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings) { + return BVHNBuilderT(createLeaf).build(allocator,progress,prims,pinfo,settings); + } + }; + + template + struct BVHNBuilderMblurVirtual + { + typedef BVHN BVH; + typedef typename BVH::AABBNodeMB AABBNodeMB; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecordMB NodeRecordMB; + typedef FastAllocator::CachedAllocator Allocator; + + struct BVHNBuilderV { + NodeRecordMB build(FastAllocator* allocator, BuildProgressMonitor& progress, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings, const BBox1f& timeRange); + virtual NodeRecordMB createLeaf (const PrimRef* prims, const range& set, const Allocator& alloc) = 0; + }; + + template + struct BVHNBuilderT : public BVHNBuilderV + { + BVHNBuilderT (CreateLeafFunc createLeafFunc) + : createLeafFunc(createLeafFunc) {} + + NodeRecordMB createLeaf (const PrimRef* prims, const range& set, const Allocator& alloc) { + return createLeafFunc(prims,set,alloc); + } + + private: + CreateLeafFunc createLeafFunc; + }; + + template + static NodeRecordMB build(FastAllocator* allocator, CreateLeafFunc createLeaf, BuildProgressMonitor& progress, PrimRef* prims, const PrimInfo& pinfo, GeneralBVHBuilder::Settings settings, const BBox1f& timeRange) { + return BVHNBuilderT(createLeaf).build(allocator,progress,prims,pinfo,settings,timeRange); + } + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_morton.cpp b/thirdparty/embree/kernels/bvh/bvh_builder_morton.cpp new file mode 100644 index 000000000000..fa1710428378 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder_morton.cpp @@ -0,0 +1,531 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh.h" +#include "bvh_statistics.h" +#include "bvh_rotate.h" +#include "../common/profile.h" +#include "../../common/algorithms/parallel_prefix_sum.h" + +#include "../builders/primrefgen.h" +#include "../builders/bvh_builder_morton.h" + +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" + +#if defined(__X86_64__) +# define ROTATE_TREE 1 // specifies number of tree rotation rounds to perform +#else +# define ROTATE_TREE 0 // do not use tree 
rotations on 32 bit platforms, barrier bit in NodeRef will cause issues +#endif + +namespace embree +{ + namespace isa + { + template + struct SetBVHNBounds + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + typedef typename BVH::AABBNode AABBNode; + + BVH* bvh; + __forceinline SetBVHNBounds (BVH* bvh) : bvh(bvh) {} + + __forceinline NodeRecord operator() (NodeRef ref, const NodeRecord* children, size_t num) + { + AABBNode* node = ref.getAABBNode(); + + BBox3fa res = empty; + for (size_t i=0; isetRef(i,children[i].ref); + node->setBounds(i,b); + } + + BBox3fx result = (BBox3fx&)res; +#if ROTATE_TREE + if (N == 4) + { + size_t n = 0; + for (size_t i=0; i= 4096) { + for (size_t i=0; i::rotate(node->child(i)); + node->child(i).setBarrier(); + } + } + } + result.lower.a = unsigned(n); + } +#endif + + return NodeRecord(ref,result); + } + }; + + template + struct CreateMortonLeaf; + + template + struct CreateMortonLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + __forceinline CreateMortonLeaf (TriangleMesh* mesh, unsigned int geomID, BVHBuilderMorton::BuildPrim* morton) + : mesh(mesh), morton(morton), geomID_(geomID) {} + + __noinline NodeRecord operator() (const range& current, const FastAllocator::CachedAllocator& alloc) + { + vfloat4 lower(pos_inf); + vfloat4 upper(neg_inf); + size_t items = current.size(); + size_t start = current.begin(); + assert(items<=4); + + /* allocate leaf node */ + Triangle4* accel = (Triangle4*) alloc.malloc1(sizeof(Triangle4),BVH::byteAlignment); + NodeRef ref = BVH::encodeLeaf((char*)accel,1); + vuint4 vgeomID = -1, vprimID = -1; + Vec3vf4 v0 = zero, v1 = zero, v2 = zero; + const TriangleMesh* __restrict__ const mesh = this->mesh; + + for (size_t i=0; itriangle(primID); + const Vec3fa& p0 = mesh->vertex(tri.v[0]); + const Vec3fa& p1 = mesh->vertex(tri.v[1]); + const Vec3fa& p2 = mesh->vertex(tri.v[2]); + lower = min(lower,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2); + upper = max(upper,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2); + vgeomID [i] = geomID_; + vprimID [i] = primID; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + } + + Triangle4::store_nt(accel,Triangle4(v0,v1,v2,vgeomID,vprimID)); + BBox3fx box_o = BBox3fx((Vec3fx)lower,(Vec3fx)upper); +#if ROTATE_TREE + if (N == 4) + box_o.lower.a = unsigned(current.size()); +#endif + return NodeRecord(ref,box_o); + } + + private: + TriangleMesh* mesh; + BVHBuilderMorton::BuildPrim* morton; + unsigned int geomID_ = std::numeric_limits::max(); + }; + + template + struct CreateMortonLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + __forceinline CreateMortonLeaf (TriangleMesh* mesh, unsigned int geomID, BVHBuilderMorton::BuildPrim* morton) + : mesh(mesh), morton(morton), geomID_(geomID) {} + + __noinline NodeRecord operator() (const range& current, const FastAllocator::CachedAllocator& alloc) + { + vfloat4 lower(pos_inf); + vfloat4 upper(neg_inf); + size_t items = current.size(); + size_t start = current.begin(); + assert(items<=4); + + /* allocate leaf node */ + Triangle4v* accel = (Triangle4v*) alloc.malloc1(sizeof(Triangle4v),BVH::byteAlignment); + NodeRef ref = BVH::encodeLeaf((char*)accel,1); + vuint4 vgeomID = -1, vprimID = -1; + Vec3vf4 v0 = zero, v1 = zero, v2 = zero; + const TriangleMesh* __restrict__ mesh = this->mesh; 
+ + for (size_t i=0; itriangle(primID); + const Vec3fa& p0 = mesh->vertex(tri.v[0]); + const Vec3fa& p1 = mesh->vertex(tri.v[1]); + const Vec3fa& p2 = mesh->vertex(tri.v[2]); + lower = min(lower,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2); + upper = max(upper,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2); + vgeomID [i] = geomID_; + vprimID [i] = primID; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + } + Triangle4v::store_nt(accel,Triangle4v(v0,v1,v2,vgeomID,vprimID)); + BBox3fx box_o = BBox3fx((Vec3fx)lower,(Vec3fx)upper); +#if ROTATE_TREE + if (N == 4) + box_o.lower.a = current.size(); +#endif + return NodeRecord(ref,box_o); + } + private: + TriangleMesh* mesh; + BVHBuilderMorton::BuildPrim* morton; + unsigned int geomID_ = std::numeric_limits::max(); + }; + + template + struct CreateMortonLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + __forceinline CreateMortonLeaf (TriangleMesh* mesh, unsigned int geomID, BVHBuilderMorton::BuildPrim* morton) + : mesh(mesh), morton(morton), geomID_(geomID) {} + + __noinline NodeRecord operator() (const range& current, const FastAllocator::CachedAllocator& alloc) + { + vfloat4 lower(pos_inf); + vfloat4 upper(neg_inf); + size_t items = current.size(); + size_t start = current.begin(); + assert(items<=4); + + /* allocate leaf node */ + Triangle4i* accel = (Triangle4i*) alloc.malloc1(sizeof(Triangle4i),BVH::byteAlignment); + NodeRef ref = BVH::encodeLeaf((char*)accel,1); + + vuint4 v0 = zero, v1 = zero, v2 = zero; + vuint4 vgeomID = -1, vprimID = -1; + const TriangleMesh* __restrict__ const mesh = this->mesh; + + for (size_t i=0; itriangle(primID); + const Vec3fa& p0 = mesh->vertex(tri.v[0]); + const Vec3fa& p1 = mesh->vertex(tri.v[1]); + const Vec3fa& p2 = mesh->vertex(tri.v[2]); + lower = min(lower,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2); + upper = max(upper,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2); + vgeomID[i] = geomID_; + vprimID[i] = primID; + unsigned int int_stride = mesh->vertices0.getStride()/4; + v0[i] = tri.v[0] * int_stride; + v1[i] = tri.v[1] * int_stride; + v2[i] = tri.v[2] * int_stride; + } + + for (size_t i=items; i<4; i++) + { + vgeomID[i] = vgeomID[0]; + vprimID[i] = -1; + v0[i] = 0; + v1[i] = 0; + v2[i] = 0; + } + Triangle4i::store_nt(accel,Triangle4i(v0,v1,v2,vgeomID,vprimID)); + BBox3fx box_o = BBox3fx((Vec3fx)lower,(Vec3fx)upper); +#if ROTATE_TREE + if (N == 4) + box_o.lower.a = current.size(); +#endif + return NodeRecord(ref,box_o); + } + private: + TriangleMesh* mesh; + BVHBuilderMorton::BuildPrim* morton; + unsigned int geomID_ = std::numeric_limits::max(); + }; + + template + struct CreateMortonLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + __forceinline CreateMortonLeaf (QuadMesh* mesh, unsigned int geomID, BVHBuilderMorton::BuildPrim* morton) + : mesh(mesh), morton(morton), geomID_(geomID) {} + + __noinline NodeRecord operator() (const range& current, const FastAllocator::CachedAllocator& alloc) + { + vfloat4 lower(pos_inf); + vfloat4 upper(neg_inf); + size_t items = current.size(); + size_t start = current.begin(); + assert(items<=4); + + /* allocate leaf node */ + Quad4v* accel = (Quad4v*) alloc.malloc1(sizeof(Quad4v),BVH::byteAlignment); + NodeRef ref = BVH::encodeLeaf((char*)accel,1); + + vuint4 vgeomID = -1, vprimID = -1; + Vec3vf4 v0 = zero, v1 = zero, v2 = zero, v3 = zero; + const QuadMesh* 
__restrict__ mesh = this->mesh; + + for (size_t i=0; iquad(primID); + const Vec3fa& p0 = mesh->vertex(tri.v[0]); + const Vec3fa& p1 = mesh->vertex(tri.v[1]); + const Vec3fa& p2 = mesh->vertex(tri.v[2]); + const Vec3fa& p3 = mesh->vertex(tri.v[3]); + lower = min(lower,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2,(vfloat4)p3); + upper = max(upper,(vfloat4)p0,(vfloat4)p1,(vfloat4)p2,(vfloat4)p3); + vgeomID [i] = geomID_; + vprimID [i] = primID; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + v3.x[i] = p3.x; v3.y[i] = p3.y; v3.z[i] = p3.z; + } + Quad4v::store_nt(accel,Quad4v(v0,v1,v2,v3,vgeomID,vprimID)); + BBox3fx box_o = BBox3fx((Vec3fx)lower,(Vec3fx)upper); +#if ROTATE_TREE + if (N == 4) + box_o.lower.a = current.size(); +#endif + return NodeRecord(ref,box_o); + } + private: + QuadMesh* mesh; + BVHBuilderMorton::BuildPrim* morton; + unsigned int geomID_ = std::numeric_limits::max(); + }; + + template + struct CreateMortonLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + __forceinline CreateMortonLeaf (UserGeometry* mesh, unsigned int geomID, BVHBuilderMorton::BuildPrim* morton) + : mesh(mesh), morton(morton), geomID_(geomID) {} + + __noinline NodeRecord operator() (const range& current, const FastAllocator::CachedAllocator& alloc) + { + vfloat4 lower(pos_inf); + vfloat4 upper(neg_inf); + size_t items = current.size(); + size_t start = current.begin(); + + /* allocate leaf node */ + Object* accel = (Object*) alloc.malloc1(items*sizeof(Object),BVH::byteAlignment); + NodeRef ref = BVH::encodeLeaf((char*)accel,items); + const UserGeometry* mesh = this->mesh; + + BBox3fa bounds = empty; + for (size_t i=0; ibounds(primID)); + new (&accel[i]) Object(geomID_,primID); + } + + BBox3fx box_o = (BBox3fx&)bounds; +#if ROTATE_TREE + if (N == 4) + box_o.lower.a = current.size(); +#endif + return NodeRecord(ref,box_o); + } + private: + UserGeometry* mesh; + BVHBuilderMorton::BuildPrim* morton; + unsigned int geomID_ = std::numeric_limits::max(); + }; + + template + struct CreateMortonLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + __forceinline CreateMortonLeaf (Instance* mesh, unsigned int geomID, BVHBuilderMorton::BuildPrim* morton) + : mesh(mesh), morton(morton), geomID_(geomID) {} + + __noinline NodeRecord operator() (const range& current, const FastAllocator::CachedAllocator& alloc) + { + vfloat4 lower(pos_inf); + vfloat4 upper(neg_inf); + size_t items = current.size(); + size_t start = current.begin(); + assert(items <= 1); + + /* allocate leaf node */ + InstancePrimitive* accel = (InstancePrimitive*) alloc.malloc1(items*sizeof(InstancePrimitive),BVH::byteAlignment); + NodeRef ref = BVH::encodeLeaf((char*)accel,items); + const Instance* instance = this->mesh; + + BBox3fa bounds = empty; + for (size_t i=0; ibounds(primID)); + new (&accel[i]) InstancePrimitive(instance, geomID_); + } + + BBox3fx box_o = (BBox3fx&)bounds; +#if ROTATE_TREE + if (N == 4) + box_o.lower.a = current.size(); +#endif + return NodeRecord(ref,box_o); + } + private: + Instance* mesh; + BVHBuilderMorton::BuildPrim* morton; + unsigned int geomID_ = std::numeric_limits::max(); + }; + + template + struct CalculateMeshBounds + { + __forceinline CalculateMeshBounds (Mesh* mesh) + : mesh(mesh) {} + + __forceinline const BBox3fa operator() (const BVHBuilderMorton::BuildPrim& morton) { + return 
mesh->bounds(morton.index); + } + + private: + Mesh* mesh; + }; + + template + class BVHNMeshBuilderMorton : public Builder + { + typedef BVHN BVH; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecord NodeRecord; + + public: + + BVHNMeshBuilderMorton (BVH* bvh, Mesh* mesh, unsigned int geomID, const size_t minLeafSize, const size_t maxLeafSize, const size_t singleThreadThreshold = DEFAULT_SINGLE_THREAD_THRESHOLD) + : bvh(bvh), mesh(mesh), morton(bvh->device,0), settings(N,BVH::maxBuildDepth,minLeafSize,min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks),singleThreadThreshold), geomID_(geomID) {} + + /* build function */ + void build() + { + /* we reset the allocator when the mesh size changed */ + if (mesh->numPrimitives != numPreviousPrimitives) { + bvh->alloc.clear(); + morton.clear(); + } + size_t numPrimitives = mesh->size(); + numPreviousPrimitives = numPrimitives; + + /* skip build for empty scene */ + if (numPrimitives == 0) { + bvh->set(BVH::emptyNode,empty,0); + return; + } + + /* preallocate arrays */ + morton.resize(numPrimitives); + size_t bytesEstimated = numPrimitives*sizeof(AABBNode)/(4*N) + size_t(1.2f*Primitive::blocks(numPrimitives)*sizeof(Primitive)); + size_t bytesMortonCodes = numPrimitives*sizeof(BVHBuilderMorton::BuildPrim); + bytesEstimated = max(bytesEstimated,bytesMortonCodes); // the first allocation block is reused to sort the morton codes + bvh->alloc.init(bytesMortonCodes,bytesMortonCodes,bytesEstimated); + + /* create morton code array */ + BVHBuilderMorton::BuildPrim* dest = (BVHBuilderMorton::BuildPrim*) bvh->alloc.specialAlloc(bytesMortonCodes); + size_t numPrimitivesGen = createMortonCodeArray(mesh,morton,bvh->scene->progressInterface); + + /* create BVH */ + SetBVHNBounds setBounds(bvh); + CreateMortonLeaf createLeaf(mesh,geomID_,morton.data()); + CalculateMeshBounds calculateBounds(mesh); + auto root = BVHBuilderMorton::build( + typename BVH::CreateAlloc(bvh), + typename BVH::AABBNode::Create(), + setBounds,createLeaf,calculateBounds,bvh->scene->progressInterface, + morton.data(),dest,numPrimitivesGen,settings); + + bvh->set(root.ref,LBBox3fa(root.bounds),numPrimitives); + +#if ROTATE_TREE + if (N == 4) + { + for (int i=0; i::rotate(bvh->root); + bvh->clearBarrier(bvh->root); + } +#endif + + /* clear temporary data for static geometry */ + if (bvh->scene->isStaticAccel()) { + morton.clear(); + } + bvh->cleanup(); + } + + void clear() { + morton.clear(); + } + + private: + BVH* bvh; + Mesh* mesh; + mvector morton; + BVHBuilderMorton::Settings settings; + unsigned int geomID_ = std::numeric_limits::max(); + unsigned int numPreviousPrimitives = 0; + }; + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + Builder* BVH4Triangle4MeshBuilderMortonGeneral (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<4,TriangleMesh,Triangle4> ((BVH4*)bvh,mesh,geomID,4,4); } + Builder* BVH4Triangle4vMeshBuilderMortonGeneral (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<4,TriangleMesh,Triangle4v>((BVH4*)bvh,mesh,geomID,4,4); } + Builder* BVH4Triangle4iMeshBuilderMortonGeneral (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<4,TriangleMesh,Triangle4i>((BVH4*)bvh,mesh,geomID,4,4); } +#if defined(__AVX__) + Builder* BVH8Triangle4MeshBuilderMortonGeneral (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new class 
BVHNMeshBuilderMorton<8,TriangleMesh,Triangle4> ((BVH8*)bvh,mesh,geomID,4,4); } + Builder* BVH8Triangle4vMeshBuilderMortonGeneral (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<8,TriangleMesh,Triangle4v>((BVH8*)bvh,mesh,geomID,4,4); } + Builder* BVH8Triangle4iMeshBuilderMortonGeneral (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<8,TriangleMesh,Triangle4i>((BVH8*)bvh,mesh,geomID,4,4); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + Builder* BVH4Quad4vMeshBuilderMortonGeneral (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<4,QuadMesh,Quad4v>((BVH4*)bvh,mesh,geomID,4,4); } +#if defined(__AVX__) + Builder* BVH8Quad4vMeshBuilderMortonGeneral (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<8,QuadMesh,Quad4v>((BVH8*)bvh,mesh,geomID,4,4); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_USER) + Builder* BVH4VirtualMeshBuilderMortonGeneral (void* bvh, UserGeometry* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<4,UserGeometry,Object>((BVH4*)bvh,mesh,geomID,1,BVH4::maxLeafBlocks); } +#if defined(__AVX__) + Builder* BVH8VirtualMeshBuilderMortonGeneral (void* bvh, UserGeometry* mesh, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<8,UserGeometry,Object>((BVH8*)bvh,mesh,geomID,1,BVH4::maxLeafBlocks); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + Builder* BVH4InstanceMeshBuilderMortonGeneral (void* bvh, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<4,Instance,InstancePrimitive>((BVH4*)bvh,mesh,gtype,geomID,1,BVH4::maxLeafBlocks); } +#if defined(__AVX__) + Builder* BVH8InstanceMeshBuilderMortonGeneral (void* bvh, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode) { return new class BVHNMeshBuilderMorton<8,Instance,InstancePrimitive>((BVH8*)bvh,mesh,gtype,geomID,1,BVH4::maxLeafBlocks); } +#endif +#endif + + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_sah.cpp b/thirdparty/embree/kernels/bvh/bvh_builder_sah.cpp new file mode 100644 index 000000000000..cf5b2eb47f83 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder_sah.cpp @@ -0,0 +1,640 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh.h" +#include "bvh_builder.h" +#include "../builders/primrefgen.h" +#include "../builders/splitter.h" + +#include "../geometry/linei.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglev_mb.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" +#include "../geometry/subgrid.h" + +#include "../common/state.h" +#include "../../common/algorithms/parallel_for_for.h" +#include "../../common/algorithms/parallel_for_for_prefix_sum.h" + +#define PROFILE 0 +#define PROFILE_RUNS 20 + +namespace embree +{ + namespace isa + { + template + struct CreateLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + + __forceinline CreateLeaf (BVH* bvh) : bvh(bvh) {} + + __forceinline NodeRef operator() (const PrimRef* prims, const range& set, const FastAllocator::CachedAllocator& alloc) const + { + size_t n = set.size(); + size_t items = Primitive::blocks(n); + size_t start = 
set.begin(); + Primitive* accel = (Primitive*) alloc.malloc1(items*sizeof(Primitive),BVH::byteAlignment); + typename BVH::NodeRef node = BVH::encodeLeaf((char*)accel,items); + for (size_t i=0; iscene); + } + return node; + } + + BVH* bvh; + }; + + + template + struct CreateLeafQuantized + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + + __forceinline CreateLeafQuantized (BVH* bvh) : bvh(bvh) {} + + __forceinline NodeRef operator() (const PrimRef* prims, const range& set, const FastAllocator::CachedAllocator& alloc) const + { + size_t n = set.size(); + size_t items = Primitive::blocks(n); + size_t start = set.begin(); + Primitive* accel = (Primitive*) alloc.malloc1(items*sizeof(Primitive),BVH::byteAlignment); + typename BVH::NodeRef node = BVH::encodeLeaf((char*)accel,items); + for (size_t i=0; iscene); + } + return node; + } + + BVH* bvh; + }; + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + template + struct BVHNBuilderSAH : public Builder + { + typedef BVHN BVH; + typedef typename BVHN::NodeRef NodeRef; + + BVH* bvh; + Scene* scene; + Geometry* mesh; + mvector prims; + GeneralBVHBuilder::Settings settings; + Geometry::GTypeMask gtype_; + unsigned int geomID_ = std::numeric_limits::max (); + bool primrefarrayalloc; + unsigned int numPreviousPrimitives = 0; + + BVHNBuilderSAH (BVH* bvh, Scene* scene, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, + const Geometry::GTypeMask gtype, bool primrefarrayalloc = false) + : bvh(bvh), scene(scene), mesh(nullptr), prims(scene->device,0), + settings(sahBlockSize, minLeafSize, min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), gtype_(gtype), primrefarrayalloc(primrefarrayalloc) {} + + BVHNBuilderSAH (BVH* bvh, Geometry* mesh, unsigned int geomID, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, const Geometry::GTypeMask gtype) + : bvh(bvh), scene(nullptr), mesh(mesh), prims(bvh->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), gtype_(gtype), geomID_(geomID), primrefarrayalloc(false) {} + + // FIXME: shrink bvh->alloc in destructor here and in other builders too + + void build() + { + /* we reset the allocator when the mesh size changed */ + if (mesh && mesh->numPrimitives != numPreviousPrimitives) { + bvh->alloc.clear(); + } + + /* if we use the primrefarray for allocations we have to take it back from the BVH */ + if (settings.primrefarrayalloc != size_t(inf)) + bvh->alloc.unshare(prims); + + /* skip build for empty scene */ + const size_t numPrimitives = mesh ? mesh->size() : scene->getNumPrimitives(gtype_,false); + numPreviousPrimitives = numPrimitives; + if (numPrimitives == 0) { + bvh->clear(); + prims.clear(); + return; + } + + double t0 = bvh->preBuild(mesh ? 
"" : TOSTRING(isa) "::BVH" + toString(N) + "BuilderSAH"); + +#if PROFILE + profile(2,PROFILE_RUNS,numPrimitives,[&] (ProfileTimer& timer) { +#endif + + /* create primref array */ + if (primrefarrayalloc) { + settings.primrefarrayalloc = numPrimitives/1000; + if (settings.primrefarrayalloc < 1000) + settings.primrefarrayalloc = inf; + } + + /* enable os_malloc for two level build */ + if (mesh) + bvh->alloc.setOSallocation(true); + + /* initialize allocator */ + const size_t node_bytes = numPrimitives*sizeof(typename BVH::AABBNodeMB)/(4*N); + const size_t leaf_bytes = size_t(1.2*Primitive::blocks(numPrimitives)*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,numPrimitives,node_bytes+leaf_bytes); + prims.resize(numPrimitives); + + PrimInfo pinfo = mesh ? + createPrimRefArray(mesh,geomID_,prims,bvh->scene->progressInterface) : + createPrimRefArray(scene,gtype_,false,prims,bvh->scene->progressInterface); + + /* pinfo might has zero size due to invalid geometry */ + if (unlikely(pinfo.size() == 0)) + { + bvh->clear(); + prims.clear(); + return; + } + + /* call BVH builder */ + NodeRef root = BVHNBuilderVirtual::build(&bvh->alloc,CreateLeaf(bvh),bvh->scene->progressInterface,prims.data(),pinfo,settings); + bvh->set(root,LBBox3fa(pinfo.geomBounds),pinfo.size()); + bvh->layoutLargeNodes(size_t(pinfo.size()*0.005f)); + +#if PROFILE + }); +#endif + + /* if we allocated using the primrefarray we have to keep it alive */ + if (settings.primrefarrayalloc != size_t(inf)) + bvh->alloc.share(prims); + + /* for static geometries we can do some cleanups */ + else if (scene && scene->isStaticAccel()) { + prims.clear(); + } + bvh->cleanup(); + bvh->postBuild(t0); + } + + void clear() { + prims.clear(); + } + }; + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + template + struct BVHNBuilderSAHQuantized : public Builder + { + typedef BVHN BVH; + typedef typename BVHN::NodeRef NodeRef; + + BVH* bvh; + Scene* scene; + Geometry* mesh; + mvector prims; + GeneralBVHBuilder::Settings settings; + Geometry::GTypeMask gtype_; + unsigned int geomID_ = std::numeric_limits::max(); + unsigned int numPreviousPrimitives = 0; + + BVHNBuilderSAHQuantized (BVH* bvh, Scene* scene, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, const Geometry::GTypeMask gtype) + : bvh(bvh), scene(scene), mesh(nullptr), prims(scene->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), gtype_(gtype) {} + + BVHNBuilderSAHQuantized (BVH* bvh, Geometry* mesh, unsigned int geomID, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, const Geometry::GTypeMask gtype) + : bvh(bvh), scene(nullptr), mesh(mesh), prims(bvh->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), gtype_(gtype), geomID_(geomID) {} + + // FIXME: shrink bvh->alloc in destructor here and in other builders too + + void build() + { + /* we reset the 
allocator when the mesh size changed */ + if (mesh && mesh->numPrimitives != numPreviousPrimitives) { + bvh->alloc.clear(); + } + + /* skip build for empty scene */ + const size_t numPrimitives = mesh ? mesh->size() : scene->getNumPrimitives(gtype_,false); + numPreviousPrimitives = numPrimitives; + if (numPrimitives == 0) { + prims.clear(); + bvh->clear(); + return; + } + + double t0 = bvh->preBuild(mesh ? "" : TOSTRING(isa) "::QBVH" + toString(N) + "BuilderSAH"); + +#if PROFILE + profile(2,PROFILE_RUNS,numPrimitives,[&] (ProfileTimer& timer) { +#endif + /* create primref array */ + prims.resize(numPrimitives); + PrimInfo pinfo = mesh ? + createPrimRefArray(mesh,geomID_,prims,bvh->scene->progressInterface) : + createPrimRefArray(scene,gtype_,false,prims,bvh->scene->progressInterface); + + /* enable os_malloc for two level build */ + if (mesh) + bvh->alloc.setOSallocation(true); + + /* call BVH builder */ + const size_t node_bytes = numPrimitives*sizeof(typename BVH::QuantizedNode)/(4*N); + const size_t leaf_bytes = size_t(1.2*Primitive::blocks(numPrimitives)*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,numPrimitives,node_bytes+leaf_bytes); + NodeRef root = BVHNBuilderQuantizedVirtual::build(&bvh->alloc,CreateLeafQuantized(bvh),bvh->scene->progressInterface,prims.data(),pinfo,settings); + bvh->set(root,LBBox3fa(pinfo.geomBounds),pinfo.size()); + //bvh->layoutLargeNodes(pinfo.size()*0.005f); // FIXME: COPY LAYOUT FOR LARGE NODES !!! +#if PROFILE + }); +#endif + + /* clear temporary data for static geometry */ + if (scene && scene->isStaticAccel()) { + prims.clear(); + } + bvh->cleanup(); + bvh->postBuild(t0); + } + + void clear() { + prims.clear(); + } + }; + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + + template + struct CreateLeafGrid + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + + __forceinline CreateLeafGrid (BVH* bvh, const SubGridBuildData * const sgrids) : bvh(bvh),sgrids(sgrids) {} + + __forceinline NodeRef operator() (const PrimRef* prims, const range& set, const FastAllocator::CachedAllocator& alloc) const + { + const size_t items = set.size(); //Primitive::blocks(n); + const size_t start = set.begin(); + + /* collect all subsets with unique geomIDs */ + assert(items <= N); + unsigned int geomIDs[N]; + unsigned int num_geomIDs = 1; + geomIDs[0] = prims[start].geomID(); + + for (size_t i=1;i* accel = (SubGridQBVHN*) alloc.malloc1(num_geomIDs*sizeof(SubGridQBVHN),BVH::byteAlignment); + typename BVH::NodeRef node = BVH::encodeLeaf((char*)accel,num_geomIDs); + + for (size_t g=0;g(x,y,primID,bounds,geomIDs[g],pos); + } + + return node; + } + + BVH* bvh; + const SubGridBuildData * const sgrids; + }; + + + template + struct BVHNBuilderSAHGrid : public Builder + { + typedef BVHN BVH; + typedef typename BVHN::NodeRef NodeRef; + + BVH* bvh; + Scene* scene; + GridMesh* mesh; + mvector prims; + mvector sgrids; + GeneralBVHBuilder::Settings settings; + unsigned int geomID_ = std::numeric_limits::max(); + unsigned int numPreviousPrimitives = 0; + + BVHNBuilderSAHGrid (BVH* bvh, Scene* scene, const size_t sahBlockSize, const float 
intCost, const size_t minLeafSize, const size_t maxLeafSize, const size_t mode) + : bvh(bvh), scene(scene), mesh(nullptr), prims(scene->device,0), sgrids(scene->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD) {} + + BVHNBuilderSAHGrid (BVH* bvh, GridMesh* mesh, unsigned int geomID, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, const size_t mode) + : bvh(bvh), scene(nullptr), mesh(mesh), prims(bvh->device,0), sgrids(scene->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), geomID_(geomID) {} + + void build() + { + /* we reset the allocator when the mesh size changed */ + if (mesh && mesh->numPrimitives != numPreviousPrimitives) { + bvh->alloc.clear(); + } + + /* if we use the primrefarray for allocations we have to take it back from the BVH */ + if (settings.primrefarrayalloc != size_t(inf)) + bvh->alloc.unshare(prims); + + const size_t numGridPrimitives = mesh ? mesh->size() : scene->getNumPrimitives(GridMesh::geom_type,false); + numPreviousPrimitives = numGridPrimitives; + + PrimInfo pinfo(empty); + size_t numPrimitives = 0; + + if (!mesh) + { + /* first run to get #primitives */ + + ParallelForForPrefixSumState pstate; + Scene::Iterator iter(scene); + + pstate.init(iter,size_t(1024)); + + /* iterate over all meshes in the scene */ + pinfo = parallel_for_for_prefix_sum0( pstate, iter, PrimInfo(empty), [&](GridMesh* mesh, const range& r, size_t k, size_t geomID) -> PrimInfo { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jvalid(j)) continue; + BBox3fa bounds = empty; + const PrimRef prim(bounds,(unsigned)geomID,(unsigned)j); + if (!mesh->valid(j)) continue; + pinfo.add_center2(prim,mesh->getNumSubGrids(j)); + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + numPrimitives = pinfo.size(); + + /* resize arrays */ + sgrids.resize(numPrimitives); + prims.resize(numPrimitives); + + /* second run to fill primrefs and SubGridBuildData arrays */ + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfo(empty), [&](GridMesh* mesh, const range& r, size_t k, size_t geomID, const PrimInfo& base) -> PrimInfo { + k = base.size(); + size_t p_index = k; + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jvalid(j)) continue; + const GridMesh::Grid &g = mesh->grid(j); + for (unsigned int y=0; ybuildBounds(g,x,y,bounds)) continue; // get bounds of subgrid + const PrimRef prim(bounds,(unsigned)geomID,(unsigned)p_index); + pinfo.add_center2(prim); + sgrids[p_index] = SubGridBuildData(x | g.get3x3FlagsX(x), y | g.get3x3FlagsY(y), unsigned(j)); + prims[p_index++] = prim; + } + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + assert(pinfo.size() == numPrimitives); + } + else + { + ParallelPrefixSumState pstate; + /* iterate over all grids in a single mesh */ + pinfo = parallel_prefix_sum( pstate, size_t(0), mesh->size(), size_t(1024), PrimInfo(empty), [&](const range& r, const PrimInfo& base) -> PrimInfo + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jvalid(j)) continue; + BBox3fa bounds = empty; + const PrimRef prim(bounds,geomID_,unsigned(j)); + pinfo.add_center2(prim,mesh->getNumSubGrids(j)); + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + numPrimitives = pinfo.size(); + /* 
resize arrays */ + sgrids.resize(numPrimitives); + prims.resize(numPrimitives); + + /* second run to fill primrefs and SubGridBuildData arrays */ + pinfo = parallel_prefix_sum( pstate, size_t(0), mesh->size(), size_t(1024), PrimInfo(empty), [&](const range& r, const PrimInfo& base) -> PrimInfo + { + + size_t p_index = base.size(); + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jvalid(j)) continue; + const GridMesh::Grid &g = mesh->grid(j); + for (unsigned int y=0; ybuildBounds(g,x,y,bounds)) continue; // get bounds of subgrid + const PrimRef prim(bounds,geomID_,unsigned(p_index)); + pinfo.add_center2(prim); + sgrids[p_index] = SubGridBuildData(x | g.get3x3FlagsX(x), y | g.get3x3FlagsY(y), unsigned(j)); + prims[p_index++] = prim; + } + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + } + + /* no primitives */ + if (numPrimitives == 0) { + bvh->clear(); + prims.clear(); + sgrids.clear(); + return; + } + + double t0 = bvh->preBuild(mesh ? "" : TOSTRING(isa) "::BVH" + toString(N) + "BuilderSAH"); + + /* create primref array */ + settings.primrefarrayalloc = numPrimitives/1000; + if (settings.primrefarrayalloc < 1000) + settings.primrefarrayalloc = inf; + + /* enable os_malloc for two level build */ + if (mesh) + bvh->alloc.setOSallocation(true); + + /* initialize allocator */ + const size_t node_bytes = numPrimitives*sizeof(typename BVH::AABBNodeMB)/(4*N); + const size_t leaf_bytes = size_t(1.2*(float)numPrimitives/N * sizeof(SubGridQBVHN)); + + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,numPrimitives,node_bytes+leaf_bytes); + + /* pinfo might has zero size due to invalid geometry */ + if (unlikely(pinfo.size() == 0)) + { + bvh->clear(); + sgrids.clear(); + prims.clear(); + return; + } + + /* call BVH builder */ + NodeRef root = BVHNBuilderVirtual::build(&bvh->alloc,CreateLeafGrid>(bvh,sgrids.data()),bvh->scene->progressInterface,prims.data(),pinfo,settings); + bvh->set(root,LBBox3fa(pinfo.geomBounds),pinfo.size()); + bvh->layoutLargeNodes(size_t(pinfo.size()*0.005f)); + + /* clear temporary array */ + sgrids.clear(); + + /* if we allocated using the primrefarray we have to keep it alive */ + if (settings.primrefarrayalloc != size_t(inf)) + bvh->alloc.share(prims); + + /* for static geometries we can do some cleanups */ + else if (scene && scene->isStaticAccel()) { + prims.clear(); + } + bvh->cleanup(); + bvh->postBuild(t0); + } + + void clear() { + prims.clear(); + } + }; + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + Builder* BVH4Triangle4MeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<4,Triangle4>((BVH4*)bvh,mesh,geomID,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH4Triangle4vMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<4,Triangle4v>((BVH4*)bvh,mesh,geomID,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH4Triangle4iMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new 
BVHNBuilderSAH<4,Triangle4i>((BVH4*)bvh,mesh,geomID,4,1.0f,4,inf,TriangleMesh::geom_type); } + + Builder* BVH4Triangle4SceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<4,Triangle4>((BVH4*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH4Triangle4vSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<4,Triangle4v>((BVH4*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH4Triangle4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<4,Triangle4i>((BVH4*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type,true); } + + + Builder* BVH4QuantizedTriangle4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHQuantized<4,Triangle4i>((BVH4*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } +#if defined(__AVX__) + Builder* BVH8Triangle4MeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<8,Triangle4>((BVH8*)bvh,mesh,geomID,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH8Triangle4vMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<8,Triangle4v>((BVH8*)bvh,mesh,geomID,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH8Triangle4iMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<8,Triangle4i>((BVH8*)bvh,mesh,geomID,4,1.0f,4,inf,TriangleMesh::geom_type); } + + Builder* BVH8Triangle4SceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<8,Triangle4>((BVH8*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH8Triangle4vSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<8,Triangle4v>((BVH8*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH8Triangle4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<8,Triangle4i>((BVH8*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type,true); } + Builder* BVH8QuantizedTriangle4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHQuantized<8,Triangle4i>((BVH8*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } + Builder* BVH8QuantizedTriangle4SceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHQuantized<8,Triangle4>((BVH8*)bvh,scene,4,1.0f,4,inf,TriangleMesh::geom_type); } + +#endif +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + Builder* BVH4Quad4vMeshBuilderSAH (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<4,Quad4v>((BVH4*)bvh,mesh,geomID,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH4Quad4iMeshBuilderSAH (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<4,Quad4i>((BVH4*)bvh,mesh,geomID,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH4Quad4vSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<4,Quad4v>((BVH4*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH4Quad4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<4,Quad4i>((BVH4*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type,true); } + Builder* BVH4QuantizedQuad4vSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHQuantized<4,Quad4v>((BVH4*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH4QuantizedQuad4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new 
BVHNBuilderSAHQuantized<4,Quad4i>((BVH4*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type); } + +#if defined(__AVX__) + Builder* BVH8Quad4vSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<8,Quad4v>((BVH8*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH8Quad4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAH<8,Quad4i>((BVH8*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type,true); } + Builder* BVH8QuantizedQuad4vSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHQuantized<8,Quad4v>((BVH8*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH8QuantizedQuad4iSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHQuantized<8,Quad4i>((BVH8*)bvh,scene,4,1.0f,4,inf,QuadMesh::geom_type); } + Builder* BVH8Quad4vMeshBuilderSAH (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAH<8,Quad4v>((BVH8*)bvh,mesh,geomID,4,1.0f,4,inf,QuadMesh::geom_type); } + +#endif +#endif + +#if defined(EMBREE_GEOMETRY_USER) + + Builder* BVH4VirtualSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { + int minLeafSize = scene->device->object_accel_min_leaf_size; + int maxLeafSize = scene->device->object_accel_max_leaf_size; + return new BVHNBuilderSAH<4,Object>((BVH4*)bvh,scene,4,1.0f,minLeafSize,maxLeafSize,UserGeometry::geom_type); + } + + Builder* BVH4VirtualMeshBuilderSAH (void* bvh, UserGeometry* mesh, unsigned int geomID, size_t mode) { + return new BVHNBuilderSAH<4,Object>((BVH4*)bvh,mesh,geomID,4,1.0f,1,inf,UserGeometry::geom_type); + } +#if defined(__AVX__) + + Builder* BVH8VirtualSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { + int minLeafSize = scene->device->object_accel_min_leaf_size; + int maxLeafSize = scene->device->object_accel_max_leaf_size; + return new BVHNBuilderSAH<8,Object>((BVH8*)bvh,scene,8,1.0f,minLeafSize,maxLeafSize,UserGeometry::geom_type); + } + + Builder* BVH8VirtualMeshBuilderSAH (void* bvh, UserGeometry* mesh, unsigned int geomID, size_t mode) { + return new BVHNBuilderSAH<8,Object>((BVH8*)bvh,mesh,geomID,8,1.0f,1,inf,UserGeometry::geom_type); + } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + Builder* BVH4InstanceSceneBuilderSAH (void* bvh, Scene* scene, Geometry::GTypeMask gtype) { return new BVHNBuilderSAH<4,InstancePrimitive>((BVH4*)bvh,scene,4,1.0f,1,1,gtype); } + Builder* BVH4InstanceMeshBuilderSAH (void* bvh, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode) { + return new BVHNBuilderSAH<4,InstancePrimitive>((BVH4*)bvh,mesh,geomID,4,1.0f,1,inf,gtype); + } +#if defined(__AVX__) + Builder* BVH8InstanceSceneBuilderSAH (void* bvh, Scene* scene, Geometry::GTypeMask gtype) { return new BVHNBuilderSAH<8,InstancePrimitive>((BVH8*)bvh,scene,8,1.0f,1,1,gtype); } + Builder* BVH8InstanceMeshBuilderSAH (void* bvh, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode) { + return new BVHNBuilderSAH<8,InstancePrimitive>((BVH8*)bvh,mesh,geomID,8,1.0f,1,inf,gtype); + } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_GRID) + Builder* BVH4GridMeshBuilderSAH (void* bvh, GridMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAHGrid<4>((BVH4*)bvh,mesh,geomID,4,1.0f,4,4,mode); } + Builder* BVH4GridSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHGrid<4>((BVH4*)bvh,scene,4,1.0f,4,4,mode); } // FIXME: check whether cost factors are correct + +#if defined(__AVX__) + Builder* BVH8GridMeshBuilderSAH (void* bvh, 
GridMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNBuilderSAHGrid<8>((BVH8*)bvh,mesh,geomID,8,1.0f,8,8,mode); } + Builder* BVH8GridSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderSAHGrid<8>((BVH8*)bvh,scene,8,1.0f,8,8,mode); } // FIXME: check whether cost factors are correct +#endif +#endif + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_sah_mb.cpp b/thirdparty/embree/kernels/bvh/bvh_builder_sah_mb.cpp new file mode 100644 index 000000000000..9c01553ec692 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder_sah_mb.cpp @@ -0,0 +1,705 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh.h" +#include "bvh_builder.h" +#include "../builders/bvh_builder_msmblur.h" + +#include "../builders/primrefgen.h" +#include "../builders/splitter.h" + +#include "../geometry/linei.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglev_mb.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" +#include "../geometry/subgrid.h" + +#include "../common/state.h" + +// FIXME: remove after removing BVHNBuilderMBlurRootTimeSplitsSAH +#include "../../common/algorithms/parallel_for_for.h" +#include "../../common/algorithms/parallel_for_for_prefix_sum.h" + + +namespace embree +{ + namespace isa + { + +#if 0 + template + struct CreateMBlurLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecordMB NodeRecordMB; + + __forceinline CreateMBlurLeaf (BVH* bvh, PrimRef* prims, size_t time) : bvh(bvh), prims(prims), time(time) {} + + __forceinline NodeRecordMB operator() (const PrimRef* prims, const range& set, const FastAllocator::CachedAllocator& alloc) const + { + size_t items = Primitive::blocks(set.size()); + size_t start = set.begin(); + for (size_t i=start; iencodeLeaf((char*)accel,items); + + LBBox3fa allBounds = empty; + for (size_t i=0; iscene, time)); + + return NodeRecordMB(node,allBounds); + } + + BVH* bvh; + PrimRef* prims; + size_t time; + }; +#endif + + template + struct CreateMSMBlurLeaf + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecordMB4D NodeRecordMB4D; + + __forceinline CreateMSMBlurLeaf (BVH* bvh) : bvh(bvh) {} + + __forceinline const NodeRecordMB4D operator() (const BVHBuilderMSMBlur::BuildRecord& current, const FastAllocator::CachedAllocator& alloc) const + { + size_t items = Primitive::blocks(current.prims.size()); + size_t start = current.prims.begin(); + size_t end = current.prims.end(); + for (size_t i=start; iencodeLeaf((char*)accel,items); + LBBox3fa allBounds = empty; + for (size_t i=0; idata(), start, current.prims.end(), bvh->scene, current.prims.time_range)); + return NodeRecordMB4D(node,allBounds,current.prims.time_range); + } + + BVH* bvh; + }; + + /* Motion blur BVH with 4D nodes and internal time splits */ + template + struct BVHNBuilderMBlurSAH : public Builder + { + typedef BVHN BVH; + typedef typename BVHN::NodeRef NodeRef; + typedef typename BVHN::NodeRecordMB NodeRecordMB; + typedef typename BVHN::AABBNodeMB AABBNodeMB; + + BVH* bvh; + Scene* scene; + const size_t sahBlockSize; + const float intCost; + const size_t minLeafSize; + const size_t maxLeafSize; + const Geometry::GTypeMask gtype_; + + BVHNBuilderMBlurSAH (BVH* bvh, Scene* scene, const size_t sahBlockSize, const float intCost, const size_t 
minLeafSize, const size_t maxLeafSize, const Geometry::GTypeMask gtype) + : bvh(bvh), scene(scene), sahBlockSize(sahBlockSize), intCost(intCost), minLeafSize(minLeafSize), maxLeafSize(min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks)), gtype_(gtype) {} + + void build() + { + /* skip build for empty scene */ + const size_t numPrimitives = scene->getNumPrimitives(gtype_,true); + if (numPrimitives == 0) { bvh->clear(); return; } + + double t0 = bvh->preBuild(TOSTRING(isa) "::BVH" + toString(N) + "BuilderMBlurSAH"); + +#if PROFILE + profile(2,PROFILE_RUNS,numPrimitives,[&] (ProfileTimer& timer) { +#endif + + //const size_t numTimeSteps = scene->getNumTimeSteps(); + //const size_t numTimeSegments = numTimeSteps-1; assert(numTimeSteps > 1); + + /*if (numTimeSegments == 1) + buildSingleSegment(numPrimitives); + else*/ + buildMultiSegment(numPrimitives); + +#if PROFILE + }); +#endif + + /* clear temporary data for static geometry */ + bvh->cleanup(); + bvh->postBuild(t0); + } + +#if 0 // No longer compatible when time_ranges are present for geometries. Would have to create temporal nodes sometimes, and put only a single geometry into leaf. + void buildSingleSegment(size_t numPrimitives) + { + /* create primref array */ + mvector prims(scene->device,numPrimitives); + const PrimInfo pinfo = createPrimRefArrayMBlur(scene,gtype_,prims,bvh->scene->progressInterface,0); + /* early out if no valid primitives */ + if (pinfo.size() == 0) { bvh->clear(); return; } + /* estimate acceleration structure size */ + const size_t node_bytes = pinfo.size()*sizeof(AABBNodeMB)/(4*N); + const size_t leaf_bytes = size_t(1.2*Primitive::blocks(pinfo.size())*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + + /* settings for BVH build */ + GeneralBVHBuilder::Settings settings; + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + settings.logBlockSize = bsr(sahBlockSize); + settings.minLeafSize = min(minLeafSize,maxLeafSize); + settings.maxLeafSize = maxLeafSize; + settings.travCost = travCost; + settings.intCost = intCost; + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,pinfo.size(),node_bytes+leaf_bytes); + + /* build hierarchy */ + auto root = BVHBuilderBinnedSAH::build + (typename BVH::CreateAlloc(bvh),typename BVH::AABBNodeMB::Create(),typename BVH::AABBNodeMB::Set(), + CreateMBlurLeaf(bvh,prims.data(),0),bvh->scene->progressInterface, + prims.data(),pinfo,settings); + + bvh->set(root.ref,root.lbounds,pinfo.size()); + } +#endif + + void buildMultiSegment(size_t numPrimitives) + { + /* create primref array */ + mvector prims(scene->device,numPrimitives); + PrimInfoMB pinfo = createPrimRefArrayMSMBlur(scene,gtype_,prims,bvh->scene->progressInterface); + + /* early out if no valid primitives */ + if (pinfo.size() == 0) { bvh->clear(); return; } + + /* estimate acceleration structure size */ + const size_t node_bytes = pinfo.num_time_segments*sizeof(AABBNodeMB)/(4*N); + const size_t leaf_bytes = size_t(1.2*Primitive::blocks(pinfo.num_time_segments)*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + + /* settings for BVH build */ + BVHBuilderMSMBlur::Settings settings; + settings.branchingFactor = N; + settings.maxDepth = BVH::maxDepth; + settings.logBlockSize = bsr(sahBlockSize); + settings.minLeafSize = min(minLeafSize,maxLeafSize); + settings.maxLeafSize = maxLeafSize; + settings.travCost = travCost; + settings.intCost = intCost; + settings.singleLeafTimeSegment = 
Primitive::singleTimeSegment; + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,pinfo.size(),node_bytes+leaf_bytes); + + /* build hierarchy */ + auto root = + BVHBuilderMSMBlur::build(prims,pinfo,scene->device, + RecalculatePrimRef(scene), + typename BVH::CreateAlloc(bvh), + typename BVH::AABBNodeMB4D::Create(), + typename BVH::AABBNodeMB4D::Set(), + CreateMSMBlurLeaf(bvh), + bvh->scene->progressInterface, + settings); + + bvh->set(root.ref,root.lbounds,pinfo.num_time_segments); + } + + void clear() { + } + }; + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + struct GridRecalculatePrimRef + { + Scene* scene; + const SubGridBuildData * const sgrids; + + __forceinline GridRecalculatePrimRef (Scene* scene, const SubGridBuildData * const sgrids) + : scene(scene), sgrids(sgrids) {} + + __forceinline PrimRefMB operator() (const PrimRefMB& prim, const BBox1f time_range) const + { + const unsigned int geomID = prim.geomID(); + const GridMesh* mesh = scene->get(geomID); + const unsigned int buildID = prim.primID(); + const SubGridBuildData &subgrid = sgrids[buildID]; + const unsigned int primID = subgrid.primID; + const size_t x = subgrid.x(); + const size_t y = subgrid.y(); + const LBBox3fa lbounds = mesh->linearBounds(mesh->grid(primID),x,y,time_range); + const unsigned num_time_segments = mesh->numTimeSegments(); + const range tbounds = mesh->timeSegmentRange(time_range); + return PrimRefMB (lbounds, tbounds.size(), mesh->time_range, num_time_segments, geomID, buildID); + } + + __forceinline LBBox3fa linearBounds(const PrimRefMB& prim, const BBox1f time_range) const { + const unsigned int geomID = prim.geomID(); + const GridMesh* mesh = scene->get(geomID); + const unsigned int buildID = prim.primID(); + const SubGridBuildData &subgrid = sgrids[buildID]; + const unsigned int primID = subgrid.primID; + const size_t x = subgrid.x(); + const size_t y = subgrid.y(); + return mesh->linearBounds(mesh->grid(primID),x,y,time_range); + } + + }; + + template + struct CreateMSMBlurLeafGrid + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecordMB4D NodeRecordMB4D; + + __forceinline CreateMSMBlurLeafGrid (Scene* scene, BVH* bvh, const SubGridBuildData * const sgrids) : scene(scene), bvh(bvh), sgrids(sgrids) {} + + __forceinline const NodeRecordMB4D operator() (const BVHBuilderMSMBlur::BuildRecord& current, const FastAllocator::CachedAllocator& alloc) const + { + const size_t items = current.prims.size(); + const size_t start = current.prims.begin(); + + const PrimRefMB* prims = current.prims.prims->data(); + /* collect all subsets with unique geomIDs */ + assert(items <= N); + unsigned int geomIDs[N]; + unsigned int num_geomIDs = 1; + geomIDs[0] = prims[start].geomID(); + + for (size_t i=1;i* accel = (SubGridMBQBVHN*) alloc.malloc1(num_geomIDs*sizeof(SubGridMBQBVHN),BVH::byteAlignment); + typename BVH::NodeRef node = bvh->encodeLeaf((char*)accel,num_geomIDs); + + LBBox3fa allBounds = empty; + + for (size_t g=0;gget(geomIDs[g]); + unsigned int x[N]; + unsigned int y[N]; + unsigned int primID[N]; + BBox3fa bounds0[N]; + BBox3fa bounds1[N]; + unsigned int pos = 0; + for (size_t 
i=0;ilinearBounds(mesh->grid(sgrid_bd.primID),x,y,current.prims.time_range); + allBounds.extend(newBounds); + bounds0[pos] = newBounds.bounds0; + bounds1[pos] = newBounds.bounds1; + pos++; + } + assert(pos <= N); + new (&accel[g]) SubGridMBQBVHN(x,y,primID,bounds0,bounds1,geomIDs[g],current.prims.time_range.lower,1.0f/current.prims.time_range.size(),pos); + } + return NodeRecordMB4D(node,allBounds,current.prims.time_range); + } + + Scene *scene; + BVH* bvh; + const SubGridBuildData * const sgrids; + }; + +#if 0 + template + struct CreateLeafGridMB + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::NodeRecordMB NodeRecordMB; + + __forceinline CreateLeafGridMB (Scene* scene, BVH* bvh, const SubGridBuildData * const sgrids) + : scene(scene), bvh(bvh), sgrids(sgrids) {} + + __forceinline NodeRecordMB operator() (const PrimRef* prims, const range& set, const FastAllocator::CachedAllocator& alloc) const + { + const size_t items = set.size(); + const size_t start = set.begin(); + + /* collect all subsets with unique geomIDs */ + assert(items <= N); + unsigned int geomIDs[N]; + unsigned int num_geomIDs = 1; + geomIDs[0] = prims[start].geomID(); + + for (size_t i=1;i* accel = (SubGridMBQBVHN*) alloc.malloc1(num_geomIDs*sizeof(SubGridMBQBVHN),BVH::byteAlignment); + typename BVH::NodeRef node = bvh->encodeLeaf((char*)accel,num_geomIDs); + + LBBox3fa allBounds = empty; + + for (size_t g=0;gget(geomIDs[g]); + + unsigned int x[N]; + unsigned int y[N]; + unsigned int primID[N]; + BBox3fa bounds0[N]; + BBox3fa bounds1[N]; + unsigned int pos = 0; + for (size_t i=0;ibuildBounds(mesh->grid(sgrid_bd.primID),x,y,0,bounds0[pos]); + bool MAYBE_UNUSED valid1 = mesh->buildBounds(mesh->grid(sgrid_bd.primID),x,y,1,bounds1[pos]); + assert(valid0); + assert(valid1); + allBounds.extend(LBBox3fa(bounds0[pos],bounds1[pos])); + pos++; + } + new (&accel[g]) SubGridMBQBVHN(x,y,primID,bounds0,bounds1,geomIDs[g],0.0f,1.0f,pos); + } + return NodeRecordMB(node,allBounds); + } + + Scene *scene; + BVH* bvh; + const SubGridBuildData * const sgrids; + }; +#endif + + + /* Motion blur BVH with 4D nodes and internal time splits */ + template + struct BVHNBuilderMBlurSAHGrid : public Builder + { + typedef BVHN BVH; + typedef typename BVHN::NodeRef NodeRef; + typedef typename BVHN::NodeRecordMB NodeRecordMB; + typedef typename BVHN::AABBNodeMB AABBNodeMB; + + BVH* bvh; + Scene* scene; + const size_t sahBlockSize; + const float intCost; + const size_t minLeafSize; + const size_t maxLeafSize; + mvector sgrids; + + + BVHNBuilderMBlurSAHGrid (BVH* bvh, Scene* scene, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize) + : bvh(bvh), scene(scene), sahBlockSize(sahBlockSize), intCost(intCost), minLeafSize(minLeafSize), maxLeafSize(min(maxLeafSize,BVH::maxLeafBlocks)), sgrids(scene->device,0) {} + + + PrimInfo createPrimRefArrayMBlurGrid(Scene* scene, mvector& prims, BuildProgressMonitor& progressMonitor, size_t itime) + { + /* first run to get #primitives */ + ParallelForForPrefixSumState pstate; + Scene::Iterator iter(scene); + + pstate.init(iter,size_t(1024)); + + /* iterate over all meshes in the scene */ + PrimInfo pinfo = parallel_for_for_prefix_sum0( pstate, iter, PrimInfo(empty), [&](GridMesh* mesh, const range& r, size_t k, size_t geomID) -> PrimInfo { + + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jvalid(j,range(0,1))) continue; + BBox3fa bounds = empty; + const PrimRef prim(bounds,unsigned(geomID),unsigned(j)); + 
pinfo.add_center2(prim,mesh->getNumSubGrids(j)); + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + size_t numPrimitives = pinfo.size(); + if (numPrimitives == 0) return pinfo; + + /* resize arrays */ + sgrids.resize(numPrimitives); + prims.resize(numPrimitives); + + /* second run to fill primrefs and SubGridBuildData arrays */ + pinfo = parallel_for_for_prefix_sum1( pstate, iter, PrimInfo(empty), [&](GridMesh* mesh, const range& r, size_t k, size_t geomID, const PrimInfo& base) -> PrimInfo { + + k = base.size(); + size_t p_index = k; + PrimInfo pinfo(empty); + for (size_t j=r.begin(); jgrid(j); + if (!mesh->valid(j,range(0,1))) continue; + + for (unsigned int y=0; ybuildBounds(g,x,y,itime,bounds)) continue; // get bounds of subgrid + const PrimRef prim(bounds,unsigned(geomID),unsigned(p_index)); + pinfo.add_center2(prim); + sgrids[p_index] = SubGridBuildData(x | g.get3x3FlagsX(x), y | g.get3x3FlagsY(y), unsigned(j)); + prims[p_index++] = prim; + } + } + return pinfo; + }, [](const PrimInfo& a, const PrimInfo& b) -> PrimInfo { return PrimInfo::merge(a,b); }); + + assert(pinfo.size() == numPrimitives); + return pinfo; + } + + PrimInfoMB createPrimRefArrayMSMBlurGrid(Scene* scene, mvector& prims, BuildProgressMonitor& progressMonitor, BBox1f t0t1 = BBox1f(0.0f,1.0f)) + { + /* first run to get #primitives */ + ParallelForForPrefixSumState pstate; + Scene::Iterator iter(scene); + + pstate.init(iter,size_t(1024)); + /* iterate over all meshes in the scene */ + PrimInfoMB pinfoMB = parallel_for_for_prefix_sum0( pstate, iter, PrimInfoMB(empty), [&](GridMesh* mesh, const range& r, size_t k, size_t /*geomID*/) -> PrimInfoMB { + + PrimInfoMB pinfoMB(empty); + for (size_t j=r.begin(); jvalid(j, mesh->timeSegmentRange(t0t1))) continue; + LBBox3fa bounds(empty); + PrimInfoMB gridMB(0,mesh->getNumSubGrids(j)); + pinfoMB.merge(gridMB); + } + return pinfoMB; + }, [](const PrimInfoMB& a, const PrimInfoMB& b) -> PrimInfoMB { return PrimInfoMB::merge2(a,b); }); + + size_t numPrimitives = pinfoMB.size(); + if (numPrimitives == 0) return pinfoMB; + + /* resize arrays */ + sgrids.resize(numPrimitives); + prims.resize(numPrimitives); + /* second run to fill primrefs and SubGridBuildData arrays */ + pinfoMB = parallel_for_for_prefix_sum1( pstate, iter, PrimInfoMB(empty), [&](GridMesh* mesh, const range& r, size_t k, size_t geomID, const PrimInfoMB& base) -> PrimInfoMB { + + k = base.size(); + size_t p_index = k; + PrimInfoMB pinfoMB(empty); + for (size_t j=r.begin(); jvalid(j, mesh->timeSegmentRange(t0t1))) continue; + const GridMesh::Grid &g = mesh->grid(j); + + for (unsigned int y=0; ylinearBounds(g,x,y,t0t1),mesh->numTimeSegments(),mesh->time_range,mesh->numTimeSegments(),unsigned(geomID),unsigned(p_index)); + pinfoMB.add_primref(prim); + sgrids[p_index] = SubGridBuildData(x | g.get3x3FlagsX(x), y | g.get3x3FlagsY(y), unsigned(j)); + prims[p_index++] = prim; + } + } + return pinfoMB; + }, [](const PrimInfoMB& a, const PrimInfoMB& b) -> PrimInfoMB { return PrimInfoMB::merge2(a,b); }); + + assert(pinfoMB.size() == numPrimitives); + pinfoMB.time_range = t0t1; + return pinfoMB; + } + + void build() + { + /* skip build for empty scene */ + const size_t numPrimitives = scene->getNumPrimitives(GridMesh::geom_type,true); + if (numPrimitives == 0) { bvh->clear(); return; } + + double t0 = bvh->preBuild(TOSTRING(isa) "::BVH" + toString(N) + "BuilderMBlurSAHGrid"); + + //const size_t numTimeSteps = scene->getNumTimeSteps(); + //const size_t numTimeSegments 
= numTimeSteps-1; assert(numTimeSteps > 1); + //if (numTimeSegments == 1) + // buildSingleSegment(numPrimitives); + //else + buildMultiSegment(numPrimitives); + + /* clear temporary data for static geometry */ + bvh->cleanup(); + bvh->postBuild(t0); + } + +#if 0 + void buildSingleSegment(size_t numPrimitives) + { + /* create primref array */ + mvector prims(scene->device,numPrimitives); + const PrimInfo pinfo = createPrimRefArrayMBlurGrid(scene,prims,bvh->scene->progressInterface,0); + /* early out if no valid primitives */ + if (pinfo.size() == 0) { bvh->clear(); return; } + + /* estimate acceleration structure size */ + const size_t node_bytes = pinfo.size()*sizeof(AABBNodeMB)/(4*N); + //TODO: check leaf_bytes + const size_t leaf_bytes = size_t(1.2*(float)numPrimitives/N * sizeof(SubGridQBVHN)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + + /* settings for BVH build */ + GeneralBVHBuilder::Settings settings; + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + settings.logBlockSize = bsr(sahBlockSize); + settings.minLeafSize = min(minLeafSize,maxLeafSize); + settings.maxLeafSize = maxLeafSize; + settings.travCost = travCost; + settings.intCost = intCost; + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,pinfo.size(),node_bytes+leaf_bytes); + + /* build hierarchy */ + auto root = BVHBuilderBinnedSAH::build + (typename BVH::CreateAlloc(bvh), + typename BVH::AABBNodeMB::Create(), + typename BVH::AABBNodeMB::Set(), + CreateLeafGridMB(scene,bvh,sgrids.data()), + bvh->scene->progressInterface, + prims.data(),pinfo,settings); + + bvh->set(root.ref,root.lbounds,pinfo.size()); + } +#endif + + void buildMultiSegment(size_t numPrimitives) + { + /* create primref array */ + mvector prims(scene->device,numPrimitives); + PrimInfoMB pinfo = createPrimRefArrayMSMBlurGrid(scene,prims,bvh->scene->progressInterface); + + /* early out if no valid primitives */ + if (pinfo.size() == 0) { bvh->clear(); return; } + + + + GridRecalculatePrimRef recalculatePrimRef(scene,sgrids.data()); + + /* estimate acceleration structure size */ + const size_t node_bytes = pinfo.num_time_segments*sizeof(AABBNodeMB)/(4*N); + //FIXME: check leaf_bytes + //const size_t leaf_bytes = size_t(1.2*Primitive::blocks(pinfo.num_time_segments)*sizeof(SubGridQBVHN)); + const size_t leaf_bytes = size_t(1.2*(float)numPrimitives/N * sizeof(SubGridQBVHN)); + + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + + /* settings for BVH build */ + BVHBuilderMSMBlur::Settings settings; + settings.branchingFactor = N; + settings.maxDepth = BVH::maxDepth; + settings.logBlockSize = bsr(sahBlockSize); + settings.minLeafSize = min(minLeafSize,maxLeafSize); + settings.maxLeafSize = maxLeafSize; + settings.travCost = travCost; + settings.intCost = intCost; + settings.singleLeafTimeSegment = false; + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,pinfo.size(),node_bytes+leaf_bytes); + + /* build hierarchy */ + auto root = + BVHBuilderMSMBlur::build(prims,pinfo,scene->device, + recalculatePrimRef, + typename BVH::CreateAlloc(bvh), + typename BVH::AABBNodeMB4D::Create(), + typename BVH::AABBNodeMB4D::Set(), + CreateMSMBlurLeafGrid(scene,bvh,sgrids.data()), + bvh->scene->progressInterface, + settings); + bvh->set(root.ref,root.lbounds,pinfo.num_time_segments); + } + + void clear() { + } + }; + + /************************************************************************************/ + 
/************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + Builder* BVH4Triangle4iMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAH<4,TriangleMesh,Triangle4i>((BVH4*)bvh,scene,4,1.0f,4,inf,Geometry::MTY_TRIANGLE_MESH); } + Builder* BVH4Triangle4vMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAH<4,TriangleMesh,Triangle4vMB>((BVH4*)bvh,scene,4,1.0f,4,inf,Geometry::MTY_TRIANGLE_MESH); } +#if defined(__AVX__) + Builder* BVH8Triangle4iMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAH<8,TriangleMesh,Triangle4i>((BVH8*)bvh,scene,4,1.0f,4,inf,Geometry::MTY_TRIANGLE_MESH); } + Builder* BVH8Triangle4vMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAH<8,TriangleMesh,Triangle4vMB>((BVH8*)bvh,scene,4,1.0f,4,inf,Geometry::MTY_TRIANGLE_MESH); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + Builder* BVH4Quad4iMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAH<4,QuadMesh,Quad4i>((BVH4*)bvh,scene,4,1.0f,4,inf,Geometry::MTY_QUAD_MESH); } +#if defined(__AVX__) + Builder* BVH8Quad4iMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAH<8,QuadMesh,Quad4i>((BVH8*)bvh,scene,4,1.0f,4,inf,Geometry::MTY_QUAD_MESH); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_USER) + Builder* BVH4VirtualMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { + int minLeafSize = scene->device->object_accel_mb_min_leaf_size; + int maxLeafSize = scene->device->object_accel_mb_max_leaf_size; + return new BVHNBuilderMBlurSAH<4,UserGeometry,Object>((BVH4*)bvh,scene,4,1.0f,minLeafSize,maxLeafSize,Geometry::MTY_USER_GEOMETRY); + } +#if defined(__AVX__) + Builder* BVH8VirtualMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { + int minLeafSize = scene->device->object_accel_mb_min_leaf_size; + int maxLeafSize = scene->device->object_accel_mb_max_leaf_size; + return new BVHNBuilderMBlurSAH<8,UserGeometry,Object>((BVH8*)bvh,scene,8,1.0f,minLeafSize,maxLeafSize,Geometry::MTY_USER_GEOMETRY); + } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + Builder* BVH4InstanceMBSceneBuilderSAH (void* bvh, Scene* scene, Geometry::GTypeMask gtype) { return new BVHNBuilderMBlurSAH<4,Instance,InstancePrimitive>((BVH4*)bvh,scene,4,1.0f,1,1,gtype); } +#if defined(__AVX__) + Builder* BVH8InstanceMBSceneBuilderSAH (void* bvh, Scene* scene, Geometry::GTypeMask gtype) { return new BVHNBuilderMBlurSAH<8,Instance,InstancePrimitive>((BVH8*)bvh,scene,8,1.0f,1,1,gtype); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_GRID) + Builder* BVH4GridMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAHGrid<4>((BVH4*)bvh,scene,4,1.0f,4,4); } +#if defined(__AVX__) + Builder* BVH8GridMBSceneBuilderSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderMBlurSAHGrid<8>((BVH8*)bvh,scene,8,1.0f,8,8); } +#endif +#endif + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_sah_spatial.cpp b/thirdparty/embree/kernels/bvh/bvh_builder_sah_spatial.cpp new file mode 100644 index 000000000000..285b38c39d6d --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder_sah_spatial.cpp @@ -0,0 +1,201 @@ +// Copyright 2009-2020 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh.h" +#include "bvh_builder.h" + +#include "../builders/primrefgen.h" +#include "../builders/primrefgen_presplit.h" +#include "../builders/splitter.h" + +#include "../geometry/linei.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglev_mb.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" +#include "../geometry/subgrid.h" + +#include "../common/state.h" + +namespace embree +{ + namespace isa + { + template + struct CreateLeafSpatial + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + + __forceinline CreateLeafSpatial (BVH* bvh) : bvh(bvh) {} + + __forceinline NodeRef operator() (const PrimRef* prims, const range& set, const FastAllocator::CachedAllocator& alloc) const + { + size_t n = set.size(); + size_t items = Primitive::blocks(n); + size_t start = set.begin(); + Primitive* accel = (Primitive*) alloc.malloc1(items*sizeof(Primitive),BVH::byteAlignment); + typename BVH::NodeRef node = BVH::encodeLeaf((char*)accel,items); + for (size_t i=0; iscene); + } + return node; + } + + BVH* bvh; + }; + + template + struct BVHNBuilderFastSpatialSAH : public Builder + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + BVH* bvh; + Scene* scene; + Mesh* mesh; + mvector prims0; + GeneralBVHBuilder::Settings settings; + const float splitFactor; + unsigned int geomID_ = std::numeric_limits::max(); + unsigned int numPreviousPrimitives = 0; + + BVHNBuilderFastSpatialSAH (BVH* bvh, Scene* scene, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, const size_t mode) + : bvh(bvh), scene(scene), mesh(nullptr), prims0(scene->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), + splitFactor(scene->device->max_spatial_split_replications) {} + + BVHNBuilderFastSpatialSAH (BVH* bvh, Mesh* mesh, const unsigned int geomID, const size_t sahBlockSize, const float intCost, const size_t minLeafSize, const size_t maxLeafSize, const size_t mode) + : bvh(bvh), scene(nullptr), mesh(mesh), prims0(bvh->device,0), settings(sahBlockSize, minLeafSize, min(maxLeafSize,Primitive::max_size()*BVH::maxLeafBlocks), travCost, intCost, DEFAULT_SINGLE_THREAD_THRESHOLD), + splitFactor(scene->device->max_spatial_split_replications), geomID_(geomID) {} + + // FIXME: shrink bvh->alloc in destructor here and in other builders too + + void build() + { + /* we reset the allocator when the mesh size changed */ + if (mesh && mesh->numPrimitives != numPreviousPrimitives) { + bvh->alloc.clear(); + } + + /* skip build for empty scene */ + const size_t numOriginalPrimitives = mesh ? mesh->size() : scene->getNumPrimitives(Mesh::geom_type,false); + numPreviousPrimitives = numOriginalPrimitives; + if (numOriginalPrimitives == 0) { + prims0.clear(); + bvh->clear(); + return; + } + + const unsigned int maxGeomID = mesh ? geomID_ : scene->getMaxGeomID(); + const bool usePreSplits = scene->device->useSpatialPreSplits || (maxGeomID >= ((unsigned int)1 << (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS))); + double t0 = bvh->preBuild(mesh ? "" : TOSTRING(isa) "::BVH" + toString(N) + (usePreSplits ? 
"BuilderFastSpatialPresplitSAH" : "BuilderFastSpatialSAH")); + + /* create primref array */ + const size_t numSplitPrimitives = max(numOriginalPrimitives,size_t(splitFactor*numOriginalPrimitives)); + prims0.resize(numSplitPrimitives); + + /* enable os_malloc for two level build */ + if (mesh) + bvh->alloc.setOSallocation(true); + + NodeRef root(0); + PrimInfo pinfo; + + + if (likely(usePreSplits)) + { + /* spatial presplit SAH BVH builder */ + pinfo = mesh ? + createPrimRefArray_presplit(mesh,maxGeomID,numOriginalPrimitives,prims0,bvh->scene->progressInterface) : + createPrimRefArray_presplit(scene,Mesh::geom_type,false,numOriginalPrimitives,prims0,bvh->scene->progressInterface); + + const size_t node_bytes = pinfo.size()*sizeof(typename BVH::AABBNode)/(4*N); + const size_t leaf_bytes = size_t(1.2*Primitive::blocks(pinfo.size())*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,pinfo.size(),node_bytes+leaf_bytes); + + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + + /* call BVH builder */ + root = BVHNBuilderVirtual::build(&bvh->alloc,CreateLeafSpatial(bvh),bvh->scene->progressInterface,prims0.data(),pinfo,settings); + } + else + { + /* standard spatial split SAH BVH builder */ + pinfo = mesh ? + createPrimRefArray(mesh,geomID_,/*numSplitPrimitives,*/prims0,bvh->scene->progressInterface) : + createPrimRefArray(scene,Mesh::geom_type,false,/*numSplitPrimitives,*/prims0,bvh->scene->progressInterface); + + Splitter splitter(scene); + + const size_t node_bytes = pinfo.size()*sizeof(typename BVH::AABBNode)/(4*N); + const size_t leaf_bytes = size_t(1.2*Primitive::blocks(pinfo.size())*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + settings.singleThreadThreshold = bvh->alloc.fixSingleThreadThreshold(N,DEFAULT_SINGLE_THREAD_THRESHOLD,pinfo.size(),node_bytes+leaf_bytes); + + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + + /* call BVH builder */ + root = BVHBuilderBinnedFastSpatialSAH::build( + typename BVH::CreateAlloc(bvh), + typename BVH::AABBNode::Create2(), + typename BVH::AABBNode::Set2(), + CreateLeafSpatial(bvh), + splitter, + bvh->scene->progressInterface, + prims0.data(), + numSplitPrimitives, + pinfo,settings); + + /* ==================== */ + } + + bvh->set(root,LBBox3fa(pinfo.geomBounds),pinfo.size()); + bvh->layoutLargeNodes(size_t(pinfo.size()*0.005f)); + + /* clear temporary data for static geometry */ + if (scene && scene->isStaticAccel()) { + prims0.clear(); + } + bvh->cleanup(); + bvh->postBuild(t0); + } + + void clear() { + prims0.clear(); + } + }; + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + + Builder* BVH4Triangle4SceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<4,TriangleMesh,Triangle4,TriangleSplitterFactory>((BVH4*)bvh,scene,4,1.0f,4,inf,mode); } + Builder* BVH4Triangle4vSceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<4,TriangleMesh,Triangle4v,TriangleSplitterFactory>((BVH4*)bvh,scene,4,1.0f,4,inf,mode); 
} + Builder* BVH4Triangle4iSceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<4,TriangleMesh,Triangle4i,TriangleSplitterFactory>((BVH4*)bvh,scene,4,1.0f,4,inf,mode); } + +#if defined(__AVX__) + Builder* BVH8Triangle4SceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<8,TriangleMesh,Triangle4,TriangleSplitterFactory>((BVH8*)bvh,scene,4,1.0f,4,inf,mode); } + Builder* BVH8Triangle4vSceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<8,TriangleMesh,Triangle4v,TriangleSplitterFactory>((BVH8*)bvh,scene,4,1.0f,4,inf,mode); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + Builder* BVH4Quad4vSceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<4,QuadMesh,Quad4v,QuadSplitterFactory>((BVH4*)bvh,scene,4,1.0f,4,inf,mode); } + +#if defined(__AVX__) + Builder* BVH8Quad4vSceneBuilderFastSpatialSAH (void* bvh, Scene* scene, size_t mode) { return new BVHNBuilderFastSpatialSAH<8,QuadMesh,Quad4v,QuadSplitterFactory>((BVH8*)bvh,scene,4,1.0f,4,inf,mode); } +#endif + +#endif + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_twolevel.cpp b/thirdparty/embree/kernels/bvh/bvh_builder_twolevel.cpp new file mode 100644 index 000000000000..1a78f347ac3d --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder_twolevel.cpp @@ -0,0 +1,377 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_builder_twolevel.h" +#include "bvh_statistics.h" +#include "../builders/bvh_builder_sah.h" +#include "../common/scene_line_segments.h" +#include "../common/scene_triangle_mesh.h" +#include "../common/scene_quad_mesh.h" + +#define PROFILE 0 + +namespace embree +{ + namespace isa + { + template + BVHNBuilderTwoLevel::BVHNBuilderTwoLevel (BVH* bvh, Scene* scene, Geometry::GTypeMask gtype, bool useMortonBuilder, const size_t singleThreadThreshold) + : bvh(bvh), scene(scene), refs(scene->device,0), prims(scene->device,0), singleThreadThreshold(singleThreadThreshold), gtype(gtype), useMortonBuilder_(useMortonBuilder) {} + + template + BVHNBuilderTwoLevel::~BVHNBuilderTwoLevel () { + } + + // =========================================================================== + // =========================================================================== + // =========================================================================== + + template + void BVHNBuilderTwoLevel::build() + { + /* delete some objects */ + size_t num = scene->size(); + if (num < bvh->objects.size()) { + parallel_for(num, bvh->objects.size(), [&] (const range& r) { + for (size_t i=r.begin(); iobjects[i]; bvh->objects[i] = nullptr; + } + }); + } + +#if PROFILE + while(1) +#endif + { + /* reset memory allocator */ + bvh->alloc.reset(); + + /* skip build for empty scene */ + const size_t numPrimitives = scene->getNumPrimitives(gtype,false); + + if (numPrimitives == 0) { + prims.resize(0); + bvh->set(BVH::emptyNode,empty,0); + return; + } + + /* calculate the size of the entire BVH */ + const size_t numLeafBlocks = Primitive::blocks(numPrimitives); + const size_t node_bytes = 2*numLeafBlocks*sizeof(typename BVH::AABBNode)/N; + const size_t leaf_bytes = size_t(1.2*numLeafBlocks*sizeof(Primitive)); + bvh->alloc.init_estimate(node_bytes+leaf_bytes); + + double t0 = bvh->preBuild(TOSTRING(isa) "::BVH" + toString(N) + "BuilderTwoLevel"); + + /* resize object array if scene got larger */ + if 
(bvh->objects.size() < num) bvh->objects.resize(num); + if (builders.size() < num) builders.resize(num); + resizeRefsList (); + nextRef.store(0); + + /* create acceleration structures */ + parallel_for(size_t(0), num, [&] (const range& r) + { + for (size_t objectID=r.begin(); objectIDgetSafe(objectID); + + /* ignore meshes we do not support */ + if (mesh == nullptr || mesh->numTimeSteps != 1) + continue; + + if (isSmallGeometry(mesh)) { + setupSmallBuildRefBuilder (objectID, mesh); + } else { + setupLargeBuildRefBuilder (objectID, mesh); + } + } + }); + + /* parallel build of acceleration structures */ + parallel_for(size_t(0), num, [&] (const range& r) + { + for (size_t objectID=r.begin(); objectIDgetSafe(objectID); + if (mesh == nullptr || !mesh->isEnabled() || mesh->numTimeSteps != 1) + continue; + + builders[objectID]->attachBuildRefs (this); + } + }); + + +#if PROFILE + double d0 = getSeconds(); +#endif + /* fast path for single geometry scenes */ + if (nextRef == 1) { + bvh->set(refs[0].node,LBBox3fa(refs[0].bounds()),numPrimitives); + } + + else + { + /* open all large nodes */ + refs.resize(nextRef); + + /* this probably needs some more tuning */ + const size_t extSize = max(max((size_t)SPLIT_MIN_EXT_SPACE,refs.size()*SPLIT_MEMORY_RESERVE_SCALE),size_t((float)numPrimitives / SPLIT_MEMORY_RESERVE_FACTOR)); + +#if !ENABLE_DIRECT_SAH_MERGE_BUILDER + +#if ENABLE_OPEN_SEQUENTIAL + open_sequential(extSize); +#endif + /* compute PrimRefs */ + prims.resize(refs.size()); +#endif + +#if defined(TASKING_TBB) && defined(__AVX512ER__) && USE_TASK_ARENA // KNL + tbb::task_arena limited(min(32,(int)TaskScheduler::threadCount())); + limited.execute([&] +#endif + { +#if ENABLE_DIRECT_SAH_MERGE_BUILDER + + const PrimInfo pinfo = parallel_reduce(size_t(0), refs.size(), PrimInfo(empty), [&] (const range& r) -> PrimInfo { + + PrimInfo pinfo(empty); + for (size_t i=r.begin(); i& r) -> PrimInfo { + + PrimInfo pinfo(empty); + for (size_t i=r.begin(); iset(BVH::emptyNode,empty,0); + + /* otherwise build toplevel hierarchy */ + else + { + /* settings for BVH build */ + GeneralBVHBuilder::Settings settings; + settings.branchingFactor = N; + settings.maxDepth = BVH::maxBuildDepthLeaf; + settings.logBlockSize = bsr(N); + settings.minLeafSize = 1; + settings.maxLeafSize = 1; + settings.travCost = 1.0f; + settings.intCost = 1.0f; + settings.singleThreadThreshold = singleThreadThreshold; + +#if ENABLE_DIRECT_SAH_MERGE_BUILDER + + refs.resize(extSize); + + NodeRef root = BVHBuilderBinnedOpenMergeSAH::build( + typename BVH::CreateAlloc(bvh), + typename BVH::AABBNode::Create2(), + typename BVH::AABBNode::Set2(), + + [&] (const BuildRef* refs, const range& range, const FastAllocator::CachedAllocator& alloc) -> NodeRef { + assert(range.size() == 1); + return (NodeRef) refs[range.begin()].node; + }, + [&] (BuildRef &bref, BuildRef *refs) -> size_t { + return openBuildRef(bref,refs); + }, + [&] (size_t dn) { bvh->scene->progressMonitor(0); }, + refs.data(),extSize,pinfo,settings); +#else + NodeRef root = BVHBuilderBinnedSAH::build( + typename BVH::CreateAlloc(bvh), + typename BVH::AABBNode::Create2(), + typename BVH::AABBNode::Set2(), + + [&] (const PrimRef* prims, const range& range, const FastAllocator::CachedAllocator& alloc) -> NodeRef { + assert(range.size() == 1); + return (NodeRef) prims[range.begin()].ID(); + }, + [&] (size_t dn) { bvh->scene->progressMonitor(0); }, + prims.data(),pinfo,settings); +#endif + + + bvh->set(root,LBBox3fa(pinfo.geomBounds),numPrimitives); + } + } +#if defined(TASKING_TBB) && 
defined(__AVX512ER__) && USE_TASK_ARENA // KNL + ); +#endif + + } + + bvh->alloc.cleanup(); + bvh->postBuild(t0); +#if PROFILE + double d1 = getSeconds(); + std::cout << "TOP_LEVEL OPENING/REBUILD TIME " << 1000.0*(d1-d0) << " ms" << std::endl; +#endif + } + + } + + template + void BVHNBuilderTwoLevel::deleteGeometry(size_t geomID) + { + if (geomID >= bvh->objects.size()) return; + if (builders[geomID]) builders[geomID].reset(); + delete bvh->objects [geomID]; bvh->objects [geomID] = nullptr; + } + + template + void BVHNBuilderTwoLevel::clear() + { + for (size_t i=0; iobjects.size(); i++) + if (bvh->objects[i]) bvh->objects[i]->clear(); + + for (size_t i=0; i + void BVHNBuilderTwoLevel::open_sequential(const size_t extSize) + { + if (refs.size() == 0) + return; + + refs.reserve(extSize); + +#if 1 + for (size_t i=0;ichild(i) == BVH::emptyNode) continue; + refs.push_back(BuildRef(node->bounds(i),node->child(i))); + +#if 1 + NodeRef ref_pre = node->child(i); + if (ref_pre.isAABBNode()) + ref_pre.prefetch(); +#endif + std::push_heap (refs.begin(),refs.end()); + } + } + } + + template + void BVHNBuilderTwoLevel::setupSmallBuildRefBuilder (size_t objectID, Mesh const * const /*mesh*/) + { + if (builders[objectID] == nullptr || // new mesh + dynamic_cast(builders[objectID].get()) == nullptr) // size change resulted in large->small change + { + builders[objectID].reset (new RefBuilderSmall(objectID)); + } + } + + template + void BVHNBuilderTwoLevel::setupLargeBuildRefBuilder (size_t objectID, Mesh const * const mesh) + { + if (bvh->objects[objectID] == nullptr || // new mesh + builders[objectID]->meshQualityChanged (mesh->quality) || // changed build quality + dynamic_cast(builders[objectID].get()) == nullptr) // size change resulted in small->large change + { + Builder* builder = nullptr; + delete bvh->objects[objectID]; + createMeshAccel(objectID, builder); + builders[objectID].reset (new RefBuilderLarge(objectID, builder, mesh->quality)); + } + } + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + Builder* BVH4BuilderTwoLevelTriangle4MeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) { + return new BVHNBuilderTwoLevel<4,TriangleMesh,Triangle4>((BVH4*)bvh,scene,TriangleMesh::geom_type,useMortonBuilder); + } + Builder* BVH4BuilderTwoLevelTriangle4vMeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) { + return new BVHNBuilderTwoLevel<4,TriangleMesh,Triangle4v>((BVH4*)bvh,scene,TriangleMesh::geom_type,useMortonBuilder); + } + Builder* BVH4BuilderTwoLevelTriangle4iMeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) { + return new BVHNBuilderTwoLevel<4,TriangleMesh,Triangle4i>((BVH4*)bvh,scene,TriangleMesh::geom_type,useMortonBuilder); + } +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + Builder* BVH4BuilderTwoLevelQuadMeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) { + return new BVHNBuilderTwoLevel<4,QuadMesh,Quad4v>((BVH4*)bvh,scene,QuadMesh::geom_type,useMortonBuilder); + } +#endif + +#if defined(EMBREE_GEOMETRY_USER) + Builder* BVH4BuilderTwoLevelVirtualSAH (void* bvh, Scene* scene, bool useMortonBuilder) { + return new BVHNBuilderTwoLevel<4,UserGeometry,Object>((BVH4*)bvh,scene,UserGeometry::geom_type,useMortonBuilder); + } +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + Builder* BVH4BuilderTwoLevelInstanceSAH (void* bvh, Scene* scene, Geometry::GTypeMask gtype, bool useMortonBuilder) { + return new BVHNBuilderTwoLevel<4,Instance,InstancePrimitive>((BVH4*)bvh,scene,gtype,useMortonBuilder); + } +#endif + +#if defined(__AVX__) +#if defined(EMBREE_GEOMETRY_TRIANGLE) + Builder* 
BVH8BuilderTwoLevelTriangle4MeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) {
+      return new BVHNBuilderTwoLevel<8,TriangleMesh,Triangle4>((BVH8*)bvh,scene,TriangleMesh::geom_type,useMortonBuilder);
+    }
+    Builder* BVH8BuilderTwoLevelTriangle4vMeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) {
+      return new BVHNBuilderTwoLevel<8,TriangleMesh,Triangle4v>((BVH8*)bvh,scene,TriangleMesh::geom_type,useMortonBuilder);
+    }
+    Builder* BVH8BuilderTwoLevelTriangle4iMeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) {
+      return new BVHNBuilderTwoLevel<8,TriangleMesh,Triangle4i>((BVH8*)bvh,scene,TriangleMesh::geom_type,useMortonBuilder);
+    }
+#endif
+
+#if defined(EMBREE_GEOMETRY_QUAD)
+    Builder* BVH8BuilderTwoLevelQuadMeshSAH (void* bvh, Scene* scene, bool useMortonBuilder) {
+      return new BVHNBuilderTwoLevel<8,QuadMesh,Quad4v>((BVH8*)bvh,scene,QuadMesh::geom_type,useMortonBuilder);
+    }
+#endif
+
+#if defined(EMBREE_GEOMETRY_USER)
+    Builder* BVH8BuilderTwoLevelVirtualSAH (void* bvh, Scene* scene, bool useMortonBuilder) {
+      return new BVHNBuilderTwoLevel<8,UserGeometry,Object>((BVH8*)bvh,scene,UserGeometry::geom_type,useMortonBuilder);
+    }
+#endif
+
+#if defined(EMBREE_GEOMETRY_INSTANCE)
+    Builder* BVH8BuilderTwoLevelInstanceSAH (void* bvh, Scene* scene, Geometry::GTypeMask gtype, bool useMortonBuilder) {
+      return new BVHNBuilderTwoLevel<8,Instance,InstancePrimitive>((BVH8*)bvh,scene,gtype,useMortonBuilder);
+    }
+#endif
+
+#endif
+  }
+}
diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_twolevel.h b/thirdparty/embree/kernels/bvh/bvh_builder_twolevel.h
new file mode 100644
index 000000000000..8f57c3b40691
--- /dev/null
+++ b/thirdparty/embree/kernels/bvh/bvh_builder_twolevel.h
@@ -0,0 +1,263 @@
+// Copyright 2009-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include
+
+#include "bvh_builder_twolevel_internal.h"
+#include "bvh.h"
+#include "../common/primref.h"
+#include "../builders/priminfo.h"
+#include "../builders/primrefgen.h"
+
+/* new open/merge builder */
+#define ENABLE_DIRECT_SAH_MERGE_BUILDER 1
+#define ENABLE_OPEN_SEQUENTIAL 0
+#define SPLIT_MEMORY_RESERVE_FACTOR 1000
+#define SPLIT_MEMORY_RESERVE_SCALE 2
+#define SPLIT_MIN_EXT_SPACE 1000
+
+namespace embree
+{
+  namespace isa
+  {
+    template<int N, typename Mesh, typename Primitive>
+    class BVHNBuilderTwoLevel : public Builder
+    {
+      typedef BVHN<N> BVH;
+      typedef typename BVH::AABBNode AABBNode;
+      typedef typename BVH::NodeRef NodeRef;
+
+      __forceinline static bool isSmallGeometry(Mesh* mesh) {
+        return mesh->size() <= 4;
+      }
+
+    public:
+
+      typedef void (*createMeshAccelTy)(Scene* scene, unsigned int geomID, AccelData*& accel, Builder*& builder);
+
+      struct BuildRef : public PrimRef
+      {
+      public:
+        __forceinline BuildRef () {}
+
+        __forceinline BuildRef (const BBox3fa& bounds, NodeRef node)
+          : PrimRef(bounds,(size_t)node), node(node)
+        {
+          if (node.isLeaf())
+            bounds_area = 0.0f;
+          else
+            bounds_area = area(this->bounds());
+        }
+
+        /* used by the open/merge bvh builder */
+        __forceinline BuildRef (const BBox3fa& bounds, NodeRef node, const unsigned int geomID, const unsigned int numPrimitives)
+          : PrimRef(bounds,geomID,numPrimitives), node(node)
+        {
+          /* important for relative buildref ordering */
+          if (node.isLeaf())
+            bounds_area = 0.0f;
+          else
+            bounds_area = area(this->bounds());
+        }
+
+        __forceinline size_t size() const {
+          return primID();
+        }
+
+        friend bool operator< (const BuildRef& a, const BuildRef& b) {
+          return a.bounds_area < b.bounds_area;
+        }
+
+        friend __forceinline embree_ostream operator<<(embree_ostream
cout, const BuildRef& ref) { + return cout << "{ lower = " << ref.lower << ", upper = " << ref.upper << ", center2 = " << ref.center2() << ", geomID = " << ref.geomID() << ", numPrimitives = " << ref.numPrimitives() << ", bounds_area = " << ref.bounds_area << " }"; + } + + __forceinline unsigned int numPrimitives() const { return primID(); } + + public: + NodeRef node; + float bounds_area; + }; + + + __forceinline size_t openBuildRef(BuildRef &bref, BuildRef *const refs) { + if (bref.node.isLeaf()) + { + refs[0] = bref; + return 1; + } + NodeRef ref = bref.node; + unsigned int geomID = bref.geomID(); + unsigned int numPrims = max((unsigned int)bref.numPrimitives() / N,(unsigned int)1); + AABBNode* node = ref.getAABBNode(); + size_t n = 0; + for (size_t i=0; ichild(i) == BVH::emptyNode) continue; + refs[i] = BuildRef(node->bounds(i),node->child(i),geomID,numPrims); + n++; + } + assert(n > 1); + return n; + } + + /*! Constructor. */ + BVHNBuilderTwoLevel (BVH* bvh, Scene* scene, Geometry::GTypeMask gtype = Mesh::geom_type, bool useMortonBuilder = false, const size_t singleThreadThreshold = DEFAULT_SINGLE_THREAD_THRESHOLD); + + /*! Destructor */ + ~BVHNBuilderTwoLevel (); + + /*! builder entry point */ + void build(); + void deleteGeometry(size_t geomID); + void clear(); + + void open_sequential(const size_t extSize); + + private: + + class RefBuilderBase { + public: + virtual ~RefBuilderBase () {} + virtual void attachBuildRefs (BVHNBuilderTwoLevel* builder) = 0; + virtual bool meshQualityChanged (RTCBuildQuality currQuality) = 0; + }; + + class RefBuilderSmall : public RefBuilderBase { + public: + + RefBuilderSmall (size_t objectID) + : objectID_ (objectID) {} + + void attachBuildRefs (BVHNBuilderTwoLevel* topBuilder) { + + Mesh* mesh = topBuilder->scene->template getSafe(objectID_); + size_t meshSize = mesh->size(); + assert(isSmallGeometry(mesh)); + + mvector prefs(topBuilder->scene->device, meshSize); + auto pinfo = createPrimRefArray(mesh,objectID_,prefs,topBuilder->bvh->scene->progressInterface); + + size_t begin=0; + while (begin < pinfo.size()) + { + Primitive* accel = (Primitive*) topBuilder->bvh->alloc.getCachedAllocator().malloc1(sizeof(Primitive),BVH::byteAlignment); + typename BVH::NodeRef node = BVH::encodeLeaf((char*)accel,1); + accel->fill(prefs.data(),begin,pinfo.size(),topBuilder->bvh->scene); + + /* create build primitive */ +#if ENABLE_DIRECT_SAH_MERGE_BUILDER + topBuilder->refs[topBuilder->nextRef++] = BVHNBuilderTwoLevel::BuildRef(pinfo.geomBounds,node,(unsigned int)objectID_,1); +#else + topBuilder->refs[topBuilder->nextRef++] = BVHNBuilderTwoLevel::BuildRef(pinfo.geomBounds,node); +#endif + } + assert(begin == pinfo.size()); + } + + bool meshQualityChanged (RTCBuildQuality /*currQuality*/) { + return false; + } + + size_t objectID_; + }; + + class RefBuilderLarge : public RefBuilderBase { + public: + + RefBuilderLarge (size_t objectID, const Ref& builder, RTCBuildQuality quality) + : objectID_ (objectID), builder_ (builder), quality_ (quality) {} + + void attachBuildRefs (BVHNBuilderTwoLevel* topBuilder) + { + BVH* object = topBuilder->getBVH(objectID_); assert(object); + + /* build object if it got modified */ + if (topBuilder->isGeometryModified(objectID_)) + builder_->build(); + + /* create build primitive */ + if (!object->getBounds().empty()) + { +#if ENABLE_DIRECT_SAH_MERGE_BUILDER + Mesh* mesh = topBuilder->getMesh(objectID_); + topBuilder->refs[topBuilder->nextRef++] = BVHNBuilderTwoLevel::BuildRef(object->getBounds(),object->root,(unsigned 
int)objectID_,(unsigned int)mesh->size()); +#else + topBuilder->refs[topBuilder->nextRef++] = BVHNBuilderTwoLevel::BuildRef(object->getBounds(),object->root); +#endif + } + } + + bool meshQualityChanged (RTCBuildQuality currQuality) { + return currQuality != quality_; + } + + private: + size_t objectID_; + Ref builder_; + RTCBuildQuality quality_; + }; + + void setupLargeBuildRefBuilder (size_t objectID, Mesh const * const mesh); + void setupSmallBuildRefBuilder (size_t objectID, Mesh const * const mesh); + + BVH* getBVH (size_t objectID) { + return this->bvh->objects[objectID]; + } + Mesh* getMesh (size_t objectID) { + return this->scene->template getSafe(objectID); + } + bool isGeometryModified (size_t objectID) { + return this->scene->isGeometryModified(objectID); + } + + void resizeRefsList () + { + size_t num = parallel_reduce (size_t(0), scene->size(), size_t(0), + [this](const range& r)->size_t { + size_t c = 0; + for (auto i=r.begin(); igetSafe(i); + if (mesh == nullptr || mesh->numTimeSteps != 1) + continue; + size_t meshSize = mesh->size(); + c += isSmallGeometry(mesh) ? Primitive::blocks(meshSize) : 1; + } + return c; + }, + std::plus() + ); + + if (refs.size() < num) { + refs.resize(num); + } + } + + void createMeshAccel (size_t geomID, Builder*& builder) + { + bvh->objects[geomID] = new BVH(Primitive::type,scene); + BVH* accel = bvh->objects[geomID]; + auto mesh = scene->getSafe(geomID); + if (nullptr == mesh) { + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"geomID does not return correct type"); + return; + } + + __internal_two_level_builder__::MeshBuilder()(accel, mesh, geomID, this->gtype, this->useMortonBuilder_, builder); + } + + using BuilderList = std::vector>; + + BuilderList builders; + BVH* bvh; + Scene* scene; + mvector refs; + mvector prims; + std::atomic nextRef; + const size_t singleThreadThreshold; + Geometry::GTypeMask gtype; + bool useMortonBuilder_ = false; + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_builder_twolevel_internal.h b/thirdparty/embree/kernels/bvh/bvh_builder_twolevel_internal.h new file mode 100644 index 000000000000..1c1ae8d6a729 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_builder_twolevel_internal.h @@ -0,0 +1,267 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/quadi.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" + +namespace embree +{ + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4MeshBuilderMortonGeneral,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4MeshBuilderSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4MeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vMeshBuilderMortonGeneral,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vMeshBuilderSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4vMeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iMeshBuilderMortonGeneral,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iMeshBuilderSAH,void* COMMA 
TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Triangle4iMeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4vMeshBuilderMortonGeneral,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4vMeshBuilderSAH,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4Quad4vMeshRefitSAH,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4VirtualMeshBuilderMortonGeneral,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4VirtualMeshBuilderSAH,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4VirtualMeshRefitSAH,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4InstanceMeshBuilderMortonGeneral,void* COMMA Instance* COMMA Geometry::GTypeMask COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4InstanceMeshBuilderSAH,void* COMMA Instance* COMMA Geometry::GTypeMask COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH4InstanceMeshRefitSAH,void* COMMA Instance* COMMA Geometry::GTypeMask COMMA unsigned int COMMA size_t) + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4MeshBuilderMortonGeneral,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4MeshBuilderSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4MeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4vMeshBuilderMortonGeneral,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4vMeshBuilderSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4vMeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4iMeshBuilderMortonGeneral,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4iMeshBuilderSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Triangle4iMeshRefitSAH,void* COMMA TriangleMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4vMeshBuilderMortonGeneral,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4vMeshBuilderSAH,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8Quad4vMeshRefitSAH,void* COMMA QuadMesh* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8VirtualMeshBuilderMortonGeneral,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8VirtualMeshBuilderSAH,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8VirtualMeshRefitSAH,void* COMMA UserGeometry* COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8InstanceMeshBuilderMortonGeneral,void* COMMA Instance* COMMA Geometry::GTypeMask COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8InstanceMeshBuilderSAH,void* COMMA Instance* COMMA Geometry::GTypeMask COMMA unsigned int COMMA size_t); + DECLARE_ISA_FUNCTION(Builder*,BVH8InstanceMeshRefitSAH,void* COMMA Instance* COMMA Geometry::GTypeMask COMMA unsigned int COMMA size_t) 
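+ // The declarations above expose the per-ISA mesh builder factories (Morton, SAH and refit
+ // variants for triangle, quad, user-geometry and instance meshes, for both BVH4 and BVH8).
+ // The helper templates in the isa namespace below map a (branching factor, mesh type,
+ // primitive type) triple onto one of these factories; MeshBuilder then picks the Morton,
+ // SAH or refit path based on the mesh's RTCBuildQuality.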
+ + namespace isa + { + + namespace __internal_two_level_builder__ { + + template + struct MortonBuilder {}; + template<> + struct MortonBuilder<4,TriangleMesh,Triangle4> { + MortonBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4MeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<4,TriangleMesh,Triangle4v> { + MortonBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4vMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<4,TriangleMesh,Triangle4i> { + MortonBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4iMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<4,QuadMesh,Quad4v> { + MortonBuilder () {} + Builder* operator () (void* bvh, QuadMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Quad4vMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<4,UserGeometry,Object> { + MortonBuilder () {} + Builder* operator () (void* bvh, UserGeometry* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4VirtualMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<4,Instance,InstancePrimitive> { + MortonBuilder () {} + Builder* operator () (void* bvh, Instance* mesh, size_t geomID, Geometry::GTypeMask gtype) { return BVH4InstanceMeshBuilderMortonGeneral(bvh,mesh,gtype,geomID,0);} + }; + template<> + struct MortonBuilder<8,TriangleMesh,Triangle4> { + MortonBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4MeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<8,TriangleMesh,Triangle4v> { + MortonBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4vMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<8,TriangleMesh,Triangle4i> { + MortonBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4iMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<8,QuadMesh,Quad4v> { + MortonBuilder () {} + Builder* operator () (void* bvh, QuadMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Quad4vMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<8,UserGeometry,Object> { + MortonBuilder () {} + Builder* operator () (void* bvh, UserGeometry* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8VirtualMeshBuilderMortonGeneral(bvh,mesh,geomID,0);} + }; + template<> + struct MortonBuilder<8,Instance,InstancePrimitive> { + MortonBuilder () {} + Builder* operator () (void* bvh, Instance* mesh, size_t geomID, Geometry::GTypeMask gtype) { return BVH8InstanceMeshBuilderMortonGeneral(bvh,mesh,gtype,geomID,0);} + }; + + template + struct SAHBuilder {}; + template<> + struct SAHBuilder<4,TriangleMesh,Triangle4> { + SAHBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4MeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<4,TriangleMesh,Triangle4v> { + 
SAHBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4vMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<4,TriangleMesh,Triangle4i> { + SAHBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4iMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<4,QuadMesh,Quad4v> { + SAHBuilder () {} + Builder* operator () (void* bvh, QuadMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Quad4vMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<4,UserGeometry,Object> { + SAHBuilder () {} + Builder* operator () (void* bvh, UserGeometry* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4VirtualMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<4,Instance,InstancePrimitive> { + SAHBuilder () {} + Builder* operator () (void* bvh, Instance* mesh, size_t geomID, Geometry::GTypeMask gtype) { return BVH4InstanceMeshBuilderSAH(bvh,mesh,gtype,geomID,0);} + }; + template<> + struct SAHBuilder<8,TriangleMesh,Triangle4> { + SAHBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4MeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<8,TriangleMesh,Triangle4v> { + SAHBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4vMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<8,TriangleMesh,Triangle4i> { + SAHBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4iMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<8,QuadMesh,Quad4v> { + SAHBuilder () {} + Builder* operator () (void* bvh, QuadMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Quad4vMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<8,UserGeometry,Object> { + SAHBuilder () {} + Builder* operator () (void* bvh, UserGeometry* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8VirtualMeshBuilderSAH(bvh,mesh,geomID,0);} + }; + template<> + struct SAHBuilder<8,Instance,InstancePrimitive> { + SAHBuilder () {} + Builder* operator () (void* bvh, Instance* mesh, size_t geomID, Geometry::GTypeMask gtype) { return BVH8InstanceMeshBuilderSAH(bvh,mesh,gtype,geomID,0);} + }; + + template + struct RefitBuilder {}; + template<> + struct RefitBuilder<4,TriangleMesh,Triangle4> { + RefitBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4MeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<4,TriangleMesh,Triangle4v> { + RefitBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4vMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<4,TriangleMesh,Triangle4i> { + RefitBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4Triangle4iMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<4,QuadMesh,Quad4v> { + RefitBuilder () {} + Builder* operator () (void* bvh, QuadMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { 
return BVH4Quad4vMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<4,UserGeometry,Object> { + RefitBuilder () {} + Builder* operator () (void* bvh, UserGeometry* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH4VirtualMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<4,Instance,InstancePrimitive> { + RefitBuilder () {} + Builder* operator () (void* bvh, Instance* mesh, size_t geomID, Geometry::GTypeMask gtype) { return BVH4InstanceMeshRefitSAH(bvh,mesh,gtype,geomID,0);} + }; + template<> + struct RefitBuilder<8,TriangleMesh,Triangle4> { + RefitBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4MeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<8,TriangleMesh,Triangle4v> { + RefitBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4vMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<8,TriangleMesh,Triangle4i> { + RefitBuilder () {} + Builder* operator () (void* bvh, TriangleMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Triangle4iMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<8,QuadMesh,Quad4v> { + RefitBuilder () {} + Builder* operator () (void* bvh, QuadMesh* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8Quad4vMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<8,UserGeometry,Object> { + RefitBuilder () {} + Builder* operator () (void* bvh, UserGeometry* mesh, size_t geomID, Geometry::GTypeMask /*gtype*/) { return BVH8VirtualMeshRefitSAH(bvh,mesh,geomID,0);} + }; + template<> + struct RefitBuilder<8,Instance,InstancePrimitive> { + RefitBuilder () {} + Builder* operator () (void* bvh, Instance* mesh, size_t geomID, Geometry::GTypeMask gtype) { return BVH8InstanceMeshRefitSAH(bvh,mesh,gtype,geomID,0);} + }; + + template + struct MeshBuilder { + MeshBuilder () {} + void operator () (void* bvh, Mesh* mesh, size_t geomID, Geometry::GTypeMask gtype, bool useMortonBuilder, Builder*& builder) { + if(useMortonBuilder) { + builder = MortonBuilder()(bvh,mesh,geomID,gtype); + return; + } + switch (mesh->quality) { + case RTC_BUILD_QUALITY_LOW: builder = MortonBuilder()(bvh,mesh,geomID,gtype); break; + case RTC_BUILD_QUALITY_MEDIUM: + case RTC_BUILD_QUALITY_HIGH: builder = SAHBuilder()(bvh,mesh,geomID,gtype); break; + case RTC_BUILD_QUALITY_REFIT: builder = RefitBuilder()(bvh,mesh,geomID,gtype); break; + default: throw_RTCError(RTC_ERROR_UNKNOWN,"invalid build quality"); + } + } + }; + } + } +} \ No newline at end of file diff --git a/thirdparty/embree/kernels/bvh/bvh_collider.cpp b/thirdparty/embree/kernels/bvh/bvh_collider.cpp new file mode 100644 index 000000000000..a27be8bae839 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_collider.cpp @@ -0,0 +1,375 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_collider.h" +#include "../geometry/triangle_triangle_intersector.h" + +namespace embree +{ + namespace isa + { +#define CSTAT(x) + + size_t parallel_depth_threshold = 3; + CSTAT(std::atomic bvh_collide_traversal_steps(0)); + CSTAT(std::atomic bvh_collide_leaf_pairs(0)); + CSTAT(std::atomic bvh_collide_leaf_iterations(0)); + CSTAT(std::atomic bvh_collide_prim_intersections1(0)); + CSTAT(std::atomic bvh_collide_prim_intersections2(0)); + CSTAT(std::atomic 
bvh_collide_prim_intersections3(0)); + CSTAT(std::atomic bvh_collide_prim_intersections4(0)); + CSTAT(std::atomic bvh_collide_prim_intersections5(0)); + CSTAT(std::atomic bvh_collide_prim_intersections(0)); + + struct Collision + { + __forceinline Collision() {} + + __forceinline Collision (unsigned geomID0, unsigned primID0, unsigned geomID1, unsigned primID1) + : geomID0(geomID0), primID0(primID0), geomID1(geomID1), primID1(primID1) {} + + unsigned geomID0; + unsigned primID0; + unsigned geomID1; + unsigned primID1; + }; + + template + __forceinline size_t overlap(const BBox3fa& box0, const typename BVHN::AABBNode& node1) + { + const vfloat lower_x = max(vfloat(box0.lower.x),node1.lower_x); + const vfloat lower_y = max(vfloat(box0.lower.y),node1.lower_y); + const vfloat lower_z = max(vfloat(box0.lower.z),node1.lower_z); + const vfloat upper_x = min(vfloat(box0.upper.x),node1.upper_x); + const vfloat upper_y = min(vfloat(box0.upper.y),node1.upper_y); + const vfloat upper_z = min(vfloat(box0.upper.z),node1.upper_z); + return movemask((lower_x <= upper_x) & (lower_y <= upper_y) & (lower_z <= upper_z)); + } + + template + __forceinline size_t overlap(const BBox3fa& box0, const BBox>>& box1) + { + const vfloat lower_x = max(vfloat(box0.lower.x),box1.lower.x); + const vfloat lower_y = max(vfloat(box0.lower.y),box1.lower.y); + const vfloat lower_z = max(vfloat(box0.lower.z),box1.lower.z); + const vfloat upper_x = min(vfloat(box0.upper.x),box1.upper.x); + const vfloat upper_y = min(vfloat(box0.upper.y),box1.upper.y); + const vfloat upper_z = min(vfloat(box0.upper.z),box1.upper.z); + return movemask((lower_x <= upper_x) & (lower_y <= upper_y) & (lower_z <= upper_z)); + } + + template + __forceinline size_t overlap(const BBox>>& box0, size_t i, const BBox>>& box1) + { + const vfloat lower_x = max(vfloat(box0.lower.x[i]),box1.lower.x); + const vfloat lower_y = max(vfloat(box0.lower.y[i]),box1.lower.y); + const vfloat lower_z = max(vfloat(box0.lower.z[i]),box1.lower.z); + const vfloat upper_x = min(vfloat(box0.upper.x[i]),box1.upper.x); + const vfloat upper_y = min(vfloat(box0.upper.y[i]),box1.upper.y); + const vfloat upper_z = min(vfloat(box0.upper.z[i]),box1.upper.z); + return movemask((lower_x <= upper_x) & (lower_y <= upper_y) & (lower_z <= upper_z)); + } + + bool intersect_triangle_triangle (Scene* scene0, unsigned geomID0, unsigned primID0, Scene* scene1, unsigned geomID1, unsigned primID1) + { + CSTAT(bvh_collide_prim_intersections1++); + const TriangleMesh* mesh0 = scene0->get(geomID0); + const TriangleMesh* mesh1 = scene1->get(geomID1); + const TriangleMesh::Triangle& tri0 = mesh0->triangle(primID0); + const TriangleMesh::Triangle& tri1 = mesh1->triangle(primID1); + + /* special culling for scene intersection with itself */ + if (scene0 == scene1 && geomID0 == geomID1) + { + /* ignore self intersections */ + if (primID0 == primID1) + return false; + } + CSTAT(bvh_collide_prim_intersections2++); + + if (scene0 == scene1 && geomID0 == geomID1) + { + /* ignore intersection with topological neighbors */ + const vint4 t0(tri0.v[0],tri0.v[1],tri0.v[2],tri0.v[2]); + if (any(vint4(tri1.v[0]) == t0)) return false; + if (any(vint4(tri1.v[1]) == t0)) return false; + if (any(vint4(tri1.v[2]) == t0)) return false; + } + CSTAT(bvh_collide_prim_intersections3++); + + const Vec3fa a0 = mesh0->vertex(tri0.v[0]); + const Vec3fa a1 = mesh0->vertex(tri0.v[1]); + const Vec3fa a2 = mesh0->vertex(tri0.v[2]); + const Vec3fa b0 = mesh1->vertex(tri1.v[0]); + const Vec3fa b1 = mesh1->vertex(tri1.v[1]); + const Vec3fa 
b2 = mesh1->vertex(tri1.v[2]); + + return TriangleTriangleIntersector::intersect_triangle_triangle(a0,a1,a2,b0,b1,b2); + } + + template + __forceinline void BVHNColliderUserGeom::processLeaf(NodeRef node0, NodeRef node1) + { + Collision collisions[16]; + size_t num_collisions = 0; + + size_t N0; Object* leaf0 = (Object*) node0.leaf(N0); + size_t N1; Object* leaf1 = (Object*) node1.leaf(N1); + for (size_t i=0; iscene0 == this->scene1 && geomID0 == geomID1 && primID0 == primID1) continue; + collisions[num_collisions++] = Collision(geomID0,primID0,geomID1,primID1); + if (num_collisions == 16) { + this->callback(this->userPtr,(RTCCollision*)&collisions,num_collisions); + num_collisions = 0; + } + } + } + if (num_collisions) + this->callback(this->userPtr,(RTCCollision*)&collisions,num_collisions); + } + + template + void BVHNCollider::collide_recurse(NodeRef ref0, const BBox3fa& bounds0, NodeRef ref1, const BBox3fa& bounds1, size_t depth0, size_t depth1) + { + CSTAT(bvh_collide_traversal_steps++); + if (unlikely(ref0.isLeaf())) { + if (unlikely(ref1.isLeaf())) { + CSTAT(bvh_collide_leaf_pairs++); + processLeaf(ref0,ref1); + return; + } else goto recurse_node1; + + } else { + if (unlikely(ref1.isLeaf())) { + goto recurse_node0; + } else { + if (area(bounds0) > area(bounds1)) { + goto recurse_node0; + } + else { + goto recurse_node1; + } + } + } + + { + recurse_node0: + AABBNode* node0 = ref0.getAABBNode(); + size_t mask = overlap(bounds1,*node0); + //for (size_t m=mask, i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) { + //for (size_t i=0; i::prefetch(node0->child(i),BVH_FLAG_ALIGNED_NODE); + collide_recurse(node0->child(i),node0->bounds(i),ref1,bounds1,depth0+1,depth1); + } + }); + } + else +#endif + { + for (size_t m=mask, i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) { + BVHN::prefetch(node0->child(i),BVH_FLAG_ALIGNED_NODE); + collide_recurse(node0->child(i),node0->bounds(i),ref1,bounds1,depth0+1,depth1); + } + } + return; + } + + { + recurse_node1: + AABBNode* node1 = ref1.getAABBNode(); + size_t mask = overlap(bounds0,*node1); + //for (size_t m=mask, i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) { + //for (size_t i=0; i::prefetch(node1->child(i),BVH_FLAG_ALIGNED_NODE); + collide_recurse(ref0,bounds0,node1->child(i),node1->bounds(i),depth0,depth1+1); + } + }); + } + else +#endif + { + for (size_t m=mask, i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) { + BVHN::prefetch(node1->child(i),BVH_FLAG_ALIGNED_NODE); + collide_recurse(ref0,bounds0,node1->child(i),node1->bounds(i),depth0,depth1+1); + } + } + return; + } + } + + template + void BVHNCollider::split(const CollideJob& job, jobvector& jobs) + { + if (unlikely(job.ref0.isLeaf())) { + if (unlikely(job.ref1.isLeaf())) { + jobs.push_back(job); + return; + } else goto recurse_node1; + } else { + if (unlikely(job.ref1.isLeaf())) { + goto recurse_node0; + } else { + if (area(job.bounds0) > area(job.bounds1)) { + goto recurse_node0; + } + else { + goto recurse_node1; + } + } + } + + { + recurse_node0: + const AABBNode* node0 = job.ref0.getAABBNode(); + size_t mask = overlap(job.bounds1,*node0); + for (size_t m=mask, i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) { + jobs.push_back(CollideJob(node0->child(i),node0->bounds(i),job.depth0+1,job.ref1,job.bounds1,job.depth1)); + } + return; + } + + { + recurse_node1: + const AABBNode* node1 = job.ref1.getAABBNode(); + size_t mask = overlap(job.bounds0,*node1); + for (size_t m=mask, i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) { + jobs.push_back(CollideJob(job.ref0,job.bounds0,job.depth0,node1->child(i),node1->bounds(i),job.depth1+1)); + } + return; + } 
+ } + + template + void BVHNCollider::collide_recurse_entry(NodeRef ref0, const BBox3fa& bounds0, NodeRef ref1, const BBox3fa& bounds1) + { + CSTAT(bvh_collide_traversal_steps = 0); + CSTAT(bvh_collide_leaf_pairs = 0); + CSTAT(bvh_collide_leaf_iterations = 0); + CSTAT(bvh_collide_prim_intersections1 = 0); + CSTAT(bvh_collide_prim_intersections2 = 0); + CSTAT(bvh_collide_prim_intersections3 = 0); + CSTAT(bvh_collide_prim_intersections4 = 0); + CSTAT(bvh_collide_prim_intersections5 = 0); + CSTAT(bvh_collide_prim_intersections = 0); +#if 0 + collide_recurse(ref0,bounds0,ref1,bounds1,0,0); +#else + const int M = 2048; + jobvector jobs[2]; + jobs[0].reserve(M); + jobs[1].reserve(M); + jobs[0].push_back(CollideJob(ref0,bounds0,0,ref1,bounds1,0)); + int source = 0; + int target = 1; + + /* try to split job until job list is full */ + while (jobs[source].size()+8 <= M) + { + for (size_t i=0; i M) { + jobs[target].push_back(job); + } else { + split(job,jobs[target]); + } + } + + /* stop splitting jobs if we reached only leaves and cannot make progress anymore */ + if (jobs[target].size() == jobs[source].size()) + break; + + jobs[source].resize(0); + std::swap(source,target); + } + + /* parallel processing of all jobs */ + parallel_for(size_t(jobs[source].size()), [&] ( size_t i ) { + CollideJob& j = jobs[source][i]; + collide_recurse(j.ref0,j.bounds0,j.ref1,j.bounds1,j.depth0,j.depth1); + }); + + +#endif + CSTAT(PRINT(bvh_collide_traversal_steps)); + CSTAT(PRINT(bvh_collide_leaf_pairs)); + CSTAT(PRINT(bvh_collide_leaf_iterations)); + CSTAT(PRINT(bvh_collide_prim_intersections1)); + CSTAT(PRINT(bvh_collide_prim_intersections2)); + CSTAT(PRINT(bvh_collide_prim_intersections3)); + CSTAT(PRINT(bvh_collide_prim_intersections4)); + CSTAT(PRINT(bvh_collide_prim_intersections5)); + CSTAT(PRINT(bvh_collide_prim_intersections)); + } + + template + void BVHNColliderUserGeom::collide(BVH* __restrict__ bvh0, BVH* __restrict__ bvh1, RTCCollideFunc callback, void* userPtr) + { + BVHNColliderUserGeom(bvh0->scene,bvh1->scene,callback,userPtr). 
+ collide_recurse_entry(bvh0->root,bvh0->bounds.bounds(),bvh1->root,bvh1->bounds.bounds()); + } + +#if defined (EMBREE_LOWEST_ISA) + struct collision_regression_test : public RegressionTest + { + collision_regression_test(const char* name) : RegressionTest(name) { + registerRegressionTest(this); + } + + bool run () + { + bool passed = true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(-0.008815f, 0.041848f, -2.49875e-06f), Vec3fa(-0.008276f, 0.053318f, -2.49875e-06f), Vec3fa(0.003023f, 0.048969f, -2.49875e-06f), + Vec3fa(0.00245f, 0.037612f, -2.49875e-06f), Vec3fa(0.01434f, 0.042634f, -2.49875e-06f), Vec3fa(0.013499f, 0.031309f, -2.49875e-06f)) == false; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0,0,1),Vec3fa(1,0,1),Vec3fa(0,1,1)) == false; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0,0,1),Vec3fa(1,0,0),Vec3fa(0,1,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0,0,0),Vec3fa(1,0,1),Vec3fa(0,1,1)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0.1f,0.1f,0),Vec3fa(1,0,1),Vec3fa(0,1,1)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0.1f,0.1f,-0.1f),Vec3fa(1,0,1),Vec3fa(0,1,1)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0,0,0),Vec3fa(0.5f,0,0),Vec3fa(0,0.5f,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0.1f,0.1f,0),Vec3fa(0.5f,0,0),Vec3fa(0,0.5f,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0.1f,0.1f,0),Vec3fa(0.5f,0.1f,0),Vec3fa(0.1f,0.5f,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(0.1f,-0.1f,0),Vec3fa(0.5f,0.1f,0),Vec3fa(0.1f,0.5f,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), Vec3fa(-0.1f,0.1f,0),Vec3fa(0.5f,0.1f,0),Vec3fa(0.1f,0.5f,0)) == true; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), + Vec3fa(-1,1,0) + Vec3fa(0,0,0),Vec3fa(-1,1,0) + Vec3fa(0.1f,0,0),Vec3fa(-1,1,0) + Vec3fa(0,0.1f,0)) == false; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), + Vec3fa( 2,0.5f,0) + Vec3fa(0,0,0),Vec3fa( 2,0.5f,0) + Vec3fa(0.1f,0,0),Vec3fa( 2,0.5f,0) + Vec3fa(0,0.1f,0)) == false; + passed &= TriangleTriangleIntersector::intersect_triangle_triangle (Vec3fa(0,0,0),Vec3fa(1,0,0),Vec3fa(0,1,0), + Vec3fa(0.5f,-2.0f,0) + Vec3fa(0,0,0),Vec3fa(0.5f,-2.0f,0) + Vec3fa(0.1f,0,0),Vec3fa(0.5f,-2.0f,0) + Vec3fa(0,0.1f,0)) == false; + return passed; + } + }; + + collision_regression_test 
collision_regression("collision_regression_test"); +#endif + + //////////////////////////////////////////////////////////////////////////////// + /// Collider Definitions + //////////////////////////////////////////////////////////////////////////////// + + DEFINE_COLLIDER(BVH4ColliderUserGeom,BVHNColliderUserGeom<4>); + +#if defined(__AVX__) + DEFINE_COLLIDER(BVH8ColliderUserGeom,BVHNColliderUserGeom<8>); +#endif + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_collider.h b/thirdparty/embree/kernels/bvh/bvh_collider.h new file mode 100644 index 000000000000..ac4f99c96a7a --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_collider.h @@ -0,0 +1,72 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include "../geometry/trianglev.h" +#include "../geometry/object.h" + +namespace embree +{ + namespace isa + { + template + class BVHNCollider + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::AABBNode AABBNode; + + struct CollideJob + { + CollideJob () {} + + CollideJob (NodeRef ref0, const BBox3fa& bounds0, size_t depth0, + NodeRef ref1, const BBox3fa& bounds1, size_t depth1) + : ref0(ref0), bounds0(bounds0), depth0(depth0), ref1(ref1), bounds1(bounds1), depth1(depth1) {} + + NodeRef ref0; + BBox3fa bounds0; + size_t depth0; + NodeRef ref1; + BBox3fa bounds1; + size_t depth1; + }; + + typedef vector_t> jobvector; + + void split(const CollideJob& job, jobvector& jobs); + + public: + __forceinline BVHNCollider (Scene* scene0, Scene* scene1, RTCCollideFunc callback, void* userPtr) + : scene0(scene0), scene1(scene1), callback(callback), userPtr(userPtr) {} + + public: + virtual void processLeaf(NodeRef leaf0, NodeRef leaf1) = 0; + void collide_recurse(NodeRef node0, const BBox3fa& bounds0, NodeRef node1, const BBox3fa& bounds1, size_t depth0, size_t depth1); + void collide_recurse_entry(NodeRef node0, const BBox3fa& bounds0, NodeRef node1, const BBox3fa& bounds1); + + protected: + Scene* scene0; + Scene* scene1; + RTCCollideFunc callback; + void* userPtr; + }; + + template + class BVHNColliderUserGeom : public BVHNCollider + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::AABBNode AABBNode; + + __forceinline BVHNColliderUserGeom (Scene* scene0, Scene* scene1, RTCCollideFunc callback, void* userPtr) + : BVHNCollider(scene0,scene1,callback,userPtr) {} + + virtual void processLeaf(NodeRef leaf0, NodeRef leaf1); + public: + static void collide(BVH* __restrict__ bvh0, BVH* __restrict__ bvh1, RTCCollideFunc callback, void* userPtr); + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_factory.h b/thirdparty/embree/kernels/bvh/bvh_factory.h new file mode 100644 index 000000000000..54021ca6eb9f --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_factory.h @@ -0,0 +1,21 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../bvh/bvh.h" +#include "../common/isa.h" +#include "../common/accel.h" +#include "../common/scene.h" +#include "../geometry/curve_intersector_virtual.h" + +namespace embree +{ + /*! 
BVH instantiations */ + class BVHFactory + { + public: + enum class BuildVariant { STATIC, DYNAMIC, HIGH_QUALITY }; + enum class IntersectVariant { FAST, ROBUST }; + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_intersector1.cpp b/thirdparty/embree/kernels/bvh/bvh_intersector1.cpp new file mode 100644 index 000000000000..ea6adc271746 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_intersector1.cpp @@ -0,0 +1,330 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_intersector1.h" +#include "node_intersector1.h" +#include "bvh_traverser1.h" + +#include "../geometry/intersector_iterators.h" +#include "../geometry/triangle_intersector.h" +#include "../geometry/trianglev_intersector.h" +#include "../geometry/trianglev_mb_intersector.h" +#include "../geometry/trianglei_intersector.h" +#include "../geometry/quadv_intersector.h" +#include "../geometry/quadi_intersector.h" +#include "../geometry/curveNv_intersector.h" +#include "../geometry/curveNi_intersector.h" +#include "../geometry/curveNi_mb_intersector.h" +#include "../geometry/linei_intersector.h" +#include "../geometry/subdivpatch1_intersector.h" +#include "../geometry/object_intersector.h" +#include "../geometry/instance_intersector.h" +#include "../geometry/subgrid_intersector.h" +#include "../geometry/subgrid_mb_intersector.h" +#include "../geometry/curve_intersector_virtual.h" + +namespace embree +{ + namespace isa + { + template + void BVHNIntersector1::intersect(const Accel::Intersectors* __restrict__ This, + RayHit& __restrict__ ray, + IntersectContext* __restrict__ context) + { + const BVH* __restrict__ bvh = (const BVH*)This->ptr; + + /* we may traverse an empty BVH in case all geometry was invalid */ + if (bvh->root == BVH::emptyNode) + return; + + /* perform per ray precalculations required by the primitive intersector */ + Precalculations pre(ray, bvh); + + /* stack state */ + StackItemT stack[stackSize]; // stack of nodes + StackItemT* stackPtr = stack+1; // current stack pointer + StackItemT* stackEnd = stack+stackSize; + stack[0].ptr = bvh->root; + stack[0].dist = neg_inf; + + if (bvh->root == BVH::emptyNode) + return; + + /* filter out invalid rays */ +#if defined(EMBREE_IGNORE_INVALID_RAYS) + if (!ray.valid()) return; +#endif + /* verify correct input */ + assert(ray.valid()); + assert(ray.tnear() >= 0.0f); + assert(!(types & BVH_MB) || (ray.time() >= 0.0f && ray.time() <= 1.0f)); + + /* load the ray into SIMD registers */ + TravRay tray(ray.org, ray.dir, max(ray.tnear(), 0.0f), max(ray.tfar, 0.0f)); + + /* initialize the node traverser */ + BVHNNodeTraverser1Hit nodeTraverser; + + /* pop loop */ + while (true) pop: + { + /* pop next node */ + if (unlikely(stackPtr == stack)) break; + stackPtr--; + NodeRef cur = NodeRef(stackPtr->ptr); + + /* if popped node is too far, pop next one */ +#if defined(__AVX512ER__) + /* much faster on KNL */ + if (unlikely(any(vfloat(*(float*)&stackPtr->dist) > tray.tfar))) + continue; +#else + if (unlikely(*(float*)&stackPtr->dist > ray.tfar)) + continue; +#endif + + /* downtraversal loop */ + while (true) + { + /* intersect node */ + size_t mask; vfloat tNear; + STAT3(normal.trav_nodes,1,1,1); + bool nodeIntersected = BVHNNodeIntersector1::intersect(cur, tray, ray.time(), tNear, mask); + if (unlikely(!nodeIntersected)) { STAT3(normal.trav_nodes,-1,-1,-1); break; } + + /* if no child is hit, pop next node */ + if (unlikely(mask == 0)) + goto pop; + + /* select next child and push other children */ + 
nodeTraverser.traverseClosestHit(cur, mask, tNear, stackPtr, stackEnd); + } + + /* this is a leaf node */ + assert(cur != BVH::emptyNode); + STAT3(normal.trav_leaves,1,1,1); + size_t num; Primitive* prim = (Primitive*)cur.leaf(num); + size_t lazy_node = 0; + PrimitiveIntersector1::intersect(This, pre, ray, context, prim, num, tray, lazy_node); + tray.tfar = ray.tfar; + + /* push lazy node onto stack */ + if (unlikely(lazy_node)) { + stackPtr->ptr = lazy_node; + stackPtr->dist = neg_inf; + stackPtr++; + } + } + } + + template + void BVHNIntersector1::occluded(const Accel::Intersectors* __restrict__ This, + Ray& __restrict__ ray, + IntersectContext* __restrict__ context) + { + const BVH* __restrict__ bvh = (const BVH*)This->ptr; + + /* we may traverse an empty BVH in case all geometry was invalid */ + if (bvh->root == BVH::emptyNode) + return; + + /* early out for already occluded rays */ + if (unlikely(ray.tfar < 0.0f)) + return; + + /* perform per ray precalculations required by the primitive intersector */ + Precalculations pre(ray, bvh); + + /* stack state */ + NodeRef stack[stackSize]; // stack of nodes that still need to get traversed + NodeRef* stackPtr = stack+1; // current stack pointer + NodeRef* stackEnd = stack+stackSize; + stack[0] = bvh->root; + + /* filter out invalid rays */ +#if defined(EMBREE_IGNORE_INVALID_RAYS) + if (!ray.valid()) return; +#endif + + /* verify correct input */ + assert(ray.valid()); + assert(ray.tnear() >= 0.0f); + assert(!(types & BVH_MB) || (ray.time() >= 0.0f && ray.time() <= 1.0f)); + + /* load the ray into SIMD registers */ + TravRay tray(ray.org, ray.dir, max(ray.tnear(), 0.0f), max(ray.tfar, 0.0f)); + + /* initialize the node traverser */ + BVHNNodeTraverser1Hit nodeTraverser; + + /* pop loop */ + while (true) pop: + { + /* pop next node */ + if (unlikely(stackPtr == stack)) break; + stackPtr--; + NodeRef cur = (NodeRef)*stackPtr; + + /* downtraversal loop */ + while (true) + { + /* intersect node */ + size_t mask; vfloat tNear; + STAT3(shadow.trav_nodes,1,1,1); + bool nodeIntersected = BVHNNodeIntersector1::intersect(cur, tray, ray.time(), tNear, mask); + if (unlikely(!nodeIntersected)) { STAT3(shadow.trav_nodes,-1,-1,-1); break; } + + /* if no child is hit, pop next node */ + if (unlikely(mask == 0)) + goto pop; + + /* select next child and push other children */ + nodeTraverser.traverseAnyHit(cur, mask, tNear, stackPtr, stackEnd); + } + + /* this is a leaf node */ + assert(cur != BVH::emptyNode); + STAT3(shadow.trav_leaves,1,1,1); + size_t num; Primitive* prim = (Primitive*)cur.leaf(num); + size_t lazy_node = 0; + if (PrimitiveIntersector1::occluded(This, pre, ray, context, prim, num, tray, lazy_node)) { + ray.tfar = neg_inf; + break; + } + + /* push lazy node onto stack */ + if (unlikely(lazy_node)) { + *stackPtr = (NodeRef)lazy_node; + stackPtr++; + } + } + } + + template + struct PointQueryDispatch + { + typedef typename PrimitiveIntersector1::Precalculations Precalculations; + typedef typename PrimitiveIntersector1::Primitive Primitive; + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::AABBNodeMB4D AABBNodeMB4D; + + static const size_t stackSize = 1+(N-1)*BVH::maxDepth+3; // +3 due to 16-wide store + + /* right now AVX512KNL SIMD extension only for standard node types */ + static const size_t Nx = (types == BVH_AN1 || types == BVH_QN1) ? 
vextend::size : N; + + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context) + { + const BVH* __restrict__ bvh = (const BVH*)This->ptr; + + /* we may traverse an empty BVH in case all geometry was invalid */ + if (bvh->root == BVH::emptyNode) + return false; + + /* stack state */ + StackItemT stack[stackSize]; // stack of nodes + StackItemT* stackPtr = stack+1; // current stack pointer + StackItemT* stackEnd = stack+stackSize; + stack[0].ptr = bvh->root; + stack[0].dist = neg_inf; + + /* verify correct input */ + assert(!(types & BVH_MB) || (query->time >= 0.0f && query->time <= 1.0f)); + + /* load the point query into SIMD registers */ + TravPointQuery tquery(query->p, context->query_radius); + + /* initialize the node traverser */ + BVHNNodeTraverser1Hit nodeTraverser; + + bool changed = false; + float cull_radius = context->query_type == POINT_QUERY_TYPE_SPHERE + ? query->radius * query->radius + : dot(context->query_radius, context->query_radius); + + /* pop loop */ + while (true) pop: + { + /* pop next node */ + if (unlikely(stackPtr == stack)) break; + stackPtr--; + NodeRef cur = NodeRef(stackPtr->ptr); + + /* if popped node is too far, pop next one */ + if (unlikely(*(float*)&stackPtr->dist > cull_radius)) + continue; + + /* downtraversal loop */ + while (true) + { + /* intersect node */ + size_t mask; vfloat tNear; + STAT3(point_query.trav_nodes,1,1,1); + bool nodeIntersected; + if (likely(context->query_type == POINT_QUERY_TYPE_SPHERE)) { + nodeIntersected = BVHNNodePointQuerySphere1::pointQuery(cur, tquery, query->time, tNear, mask); + } else { + nodeIntersected = BVHNNodePointQueryAABB1 ::pointQuery(cur, tquery, query->time, tNear, mask); + } + if (unlikely(!nodeIntersected)) { STAT3(point_query.trav_nodes,-1,-1,-1); break; } + + /* if no child is hit, pop next node */ + if (unlikely(mask == 0)) + goto pop; + + /* select next child and push other children */ + nodeTraverser.traverseClosestHit(cur, mask, tNear, stackPtr, stackEnd); + } + + /* this is a leaf node */ + assert(cur != BVH::emptyNode); + STAT3(point_query.trav_leaves,1,1,1); + size_t num; Primitive* prim = (Primitive*)cur.leaf(num); + size_t lazy_node = 0; + if (PrimitiveIntersector1::pointQuery(This, query, context, prim, num, tquery, lazy_node)) + { + changed = true; + tquery.rad = context->query_radius; + cull_radius = context->query_type == POINT_QUERY_TYPE_SPHERE + ? 
query->radius * query->radius + : dot(context->query_radius, context->query_radius); + } + + /* push lazy node onto stack */ + if (unlikely(lazy_node)) { + stackPtr->ptr = lazy_node; + stackPtr->dist = neg_inf; + stackPtr++; + } + } + return changed; + } + }; + + /* disable point queries for not yet supported geometry types */ + template + struct PointQueryDispatch { + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context) { return false; } + }; + + template + struct PointQueryDispatch { + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context) { return false; } + }; + + template + struct PointQueryDispatch { + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context) { return false; } + }; + + template + bool BVHNIntersector1::pointQuery( + const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context) + { + return PointQueryDispatch::pointQuery(This, query, context); + } + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_intersector1.h b/thirdparty/embree/kernels/bvh/bvh_intersector1.h new file mode 100644 index 000000000000..1a269c319a7b --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_intersector1.h @@ -0,0 +1,37 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include "../common/ray.h" +#include "../common/point_query.h" + +namespace embree +{ + namespace isa + { + /*! BVH single ray intersector. */ + template + class BVHNIntersector1 + { + /* shortcuts for frequently used types */ + typedef typename PrimitiveIntersector1::Precalculations Precalculations; + typedef typename PrimitiveIntersector1::Primitive Primitive; + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::AABBNodeMB4D AABBNodeMB4D; + + static const size_t stackSize = 1+(N-1)*BVH::maxDepth+3; // +3 due to 16-wide store + + /* right now AVX512KNL SIMD extension only for standard node types */ + static const size_t Nx = (types == BVH_AN1 || types == BVH_QN1) ? 
vextend::size : N; + + public: + static void intersect (const Accel::Intersectors* This, RayHit& ray, IntersectContext* context); + static void occluded (const Accel::Intersectors* This, Ray& ray, IntersectContext* context); + static bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context); + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_intersector1_bvh4.cpp b/thirdparty/embree/kernels/bvh/bvh_intersector1_bvh4.cpp new file mode 100644 index 000000000000..989f7354fda2 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_intersector1_bvh4.cpp @@ -0,0 +1,61 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_intersector1.cpp" + +namespace embree +{ + namespace isa + { + int getISA() { + return VerifyMultiTargetLinking::getISA(); + } + + //////////////////////////////////////////////////////////////////////////////// + /// BVH4Intersector1 Definitions + //////////////////////////////////////////////////////////////////////////////// + + IF_ENABLED_CURVES_OR_POINTS(DEFINE_INTERSECTOR1(BVH4OBBVirtualCurveIntersector1,BVHNIntersector1<4 COMMA BVH_AN1_UN1 COMMA false COMMA VirtualCurveIntersector1 >)); + IF_ENABLED_CURVES_OR_POINTS(DEFINE_INTERSECTOR1(BVH4OBBVirtualCurveIntersector1MB,BVHNIntersector1<4 COMMA BVH_AN2_AN4D_UN2 COMMA false COMMA VirtualCurveIntersector1 >)); + + IF_ENABLED_CURVES_OR_POINTS(DEFINE_INTERSECTOR1(BVH4OBBVirtualCurveIntersectorRobust1,BVHNIntersector1<4 COMMA BVH_AN1_UN1 COMMA true COMMA VirtualCurveIntersector1 >)); + IF_ENABLED_CURVES_OR_POINTS(DEFINE_INTERSECTOR1(BVH4OBBVirtualCurveIntersectorRobust1MB,BVHNIntersector1<4 COMMA BVH_AN2_AN4D_UN2 COMMA true COMMA VirtualCurveIntersector1 >)); + + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4Intersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4iIntersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4vIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN1 COMMA true COMMA ArrayIntersector1 > >)); + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4iIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN1 COMMA true COMMA ArrayIntersector1 > >)); + + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4vMBIntersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4iMBIntersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4vMBIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA true COMMA ArrayIntersector1 > >)); + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(BVH4Triangle4iMBIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA true COMMA ArrayIntersector1 > >)); + + IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(BVH4Quad4vIntersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(BVH4Quad4iIntersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(BVH4Quad4vIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN1 COMMA true COMMA ArrayIntersector1 > >)); + IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(BVH4Quad4iIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN1 COMMA true COMMA ArrayIntersector1 > >)); + + 
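+ // Motion-blurred quad, subdivision-patch, user-geometry, instance, quantized-node (QBVH)
+ // and grid intersector instantiations follow.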
IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(BVH4Quad4iMBIntersector1Moeller, BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(BVH4Quad4iMBIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA true COMMA ArrayIntersector1 > >)); + + IF_ENABLED_SUBDIV(DEFINE_INTERSECTOR1(BVH4SubdivPatch1Intersector1,BVHNIntersector1<4 COMMA BVH_AN1 COMMA true COMMA SubdivPatch1Intersector1>)); + IF_ENABLED_SUBDIV(DEFINE_INTERSECTOR1(BVH4SubdivPatch1MBIntersector1,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA true COMMA SubdivPatch1MBIntersector1>)); + + IF_ENABLED_USER(DEFINE_INTERSECTOR1(BVH4VirtualIntersector1,BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA ArrayIntersector1> >)); + IF_ENABLED_USER(DEFINE_INTERSECTOR1(BVH4VirtualMBIntersector1,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA false COMMA ArrayIntersector1> >)); + + IF_ENABLED_INSTANCE(DEFINE_INTERSECTOR1(BVH4InstanceIntersector1,BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA ArrayIntersector1 >)); + IF_ENABLED_INSTANCE(DEFINE_INTERSECTOR1(BVH4InstanceMBIntersector1,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA false COMMA ArrayIntersector1 >)); + + IF_ENABLED_TRIS(DEFINE_INTERSECTOR1(QBVH4Triangle4iIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_QN1 COMMA false COMMA ArrayIntersector1 > >)); + IF_ENABLED_QUADS(DEFINE_INTERSECTOR1(QBVH4Quad4iIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_QN1 COMMA false COMMA ArrayIntersector1 > >)); + + IF_ENABLED_GRIDS(DEFINE_INTERSECTOR1(BVH4GridIntersector1Moeller,BVHNIntersector1<4 COMMA BVH_AN1 COMMA false COMMA SubGridIntersector1Moeller<4 COMMA true> >)); + IF_ENABLED_GRIDS(DEFINE_INTERSECTOR1(BVH4GridMBIntersector1Moeller,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA true COMMA SubGridMBIntersector1Pluecker<4 COMMA true> >)); + + IF_ENABLED_GRIDS(DEFINE_INTERSECTOR1(BVH4GridIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN1 COMMA true COMMA SubGridIntersector1Pluecker<4 COMMA true> >)); + //IF_ENABLED_GRIDS(DEFINE_INTERSECTOR1(BVH4GridMBIntersector1Pluecker,BVHNIntersector1<4 COMMA BVH_AN2_AN4D COMMA false COMMA SubGridMBIntersector1Pluecker<4 COMMA true> >)); + + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_intersector_hybrid.h b/thirdparty/embree/kernels/bvh/bvh_intersector_hybrid.h new file mode 100644 index 000000000000..d764cc928d7e --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_intersector_hybrid.h @@ -0,0 +1,61 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include "../common/ray.h" +#include "../common/stack_item.h" +#include "node_intersector_frustum.h" + +namespace embree +{ + namespace isa + { + template + struct TravRayK; + + /*! BVH hybrid packet intersector. Switches between packet and single ray traversal (optional). */ + template + class BVHNIntersectorKHybrid + { + /* right now AVX512KNL SIMD extension only for standard node types */ + static const size_t Nx = types == BVH_AN1 ? 
vextend::size : N; + + /* shortcuts for frequently used types */ + typedef typename PrimitiveIntersectorK::Precalculations Precalculations; + typedef typename PrimitiveIntersectorK::Primitive Primitive; + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::BaseNode BaseNode; + typedef typename BVH::AABBNode AABBNode; + + static const size_t stackSizeSingle = 1+(N-1)*BVH::maxDepth+3; // +3 due to 16-wide store + static const size_t stackSizeChunk = 1+(N-1)*BVH::maxDepth; + + static const size_t switchThresholdIncoherent = \ + (K==4) ? 3 : + (K==8) ? ((N==4) ? 5 : 7) : + (K==16) ? 14 : // 14 seems to work best for KNL due to better ordered chunk traversal + 0; + + private: + static void intersect1(Accel::Intersectors* This, const BVH* bvh, NodeRef root, size_t k, Precalculations& pre, + RayHitK& ray, const TravRayK& tray, IntersectContext* context); + static bool occluded1(Accel::Intersectors* This, const BVH* bvh, NodeRef root, size_t k, Precalculations& pre, + RayK& ray, const TravRayK& tray, IntersectContext* context); + + public: + static void intersect(vint* valid, Accel::Intersectors* This, RayHitK& ray, IntersectContext* context); + static void occluded (vint* valid, Accel::Intersectors* This, RayK& ray, IntersectContext* context); + + static void intersectCoherent(vint* valid, Accel::Intersectors* This, RayHitK& ray, IntersectContext* context); + static void occludedCoherent (vint* valid, Accel::Intersectors* This, RayK& ray, IntersectContext* context); + + }; + + /*! BVH packet intersector. */ + template + class BVHNIntersectorKChunk : public BVHNIntersectorKHybrid {}; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_intersector_stream.h b/thirdparty/embree/kernels/bvh/bvh_intersector_stream.h new file mode 100644 index 000000000000..f5beb6ca9110 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_intersector_stream.h @@ -0,0 +1,284 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "node_intersector_packet_stream.h" +#include "node_intersector_frustum.h" +#include "bvh_traverser_stream.h" + +namespace embree +{ + namespace isa + { + /*! BVH ray stream intersector. */ + template + class BVHNIntersectorStream + { + static const int Nxd = (Nx == N) ? 
N : Nx/2; + + /* shortcuts for frequently used types */ + template using PrimitiveIntersectorK = typename PrimitiveIntersector::template Type; + template using PrimitiveK = typename PrimitiveIntersectorK::PrimitiveK; + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::BaseNode BaseNode; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::AABBNodeMB AABBNodeMB; + + template + __forceinline static size_t initPacketsAndFrustum(RayK** inputPackets, size_t numOctantRays, + TravRayKStream* packets, Frustum& frustum, bool& commonOctant) + { + const size_t numPackets = (numOctantRays+K-1)/K; + + Vec3vf tmp_min_rdir(pos_inf); + Vec3vf tmp_max_rdir(neg_inf); + Vec3vf tmp_min_org(pos_inf); + Vec3vf tmp_max_org(neg_inf); + vfloat tmp_min_dist(pos_inf); + vfloat tmp_max_dist(neg_inf); + + size_t m_active = 0; + for (size_t i = 0; i < numPackets; i++) + { + const vfloat tnear = inputPackets[i]->tnear(); + const vfloat tfar = inputPackets[i]->tfar; + vbool m_valid = (tnear <= tfar) & (tnear >= 0.0f); + +#if defined(EMBREE_IGNORE_INVALID_RAYS) + m_valid &= inputPackets[i]->valid(); +#endif + + m_active |= (size_t)movemask(m_valid) << (i*K); + + vfloat packet_min_dist = max(tnear, 0.0f); + vfloat packet_max_dist = select(m_valid, tfar, neg_inf); + tmp_min_dist = min(tmp_min_dist, packet_min_dist); + tmp_max_dist = max(tmp_max_dist, packet_max_dist); + + const Vec3vf& org = inputPackets[i]->org; + const Vec3vf& dir = inputPackets[i]->dir; + + new (&packets[i]) TravRayKStream(org, dir, packet_min_dist, packet_max_dist); + + tmp_min_rdir = min(tmp_min_rdir, select(m_valid, packets[i].rdir, Vec3vf(pos_inf))); + tmp_max_rdir = max(tmp_max_rdir, select(m_valid, packets[i].rdir, Vec3vf(neg_inf))); + tmp_min_org = min(tmp_min_org , select(m_valid,org , Vec3vf(pos_inf))); + tmp_max_org = max(tmp_max_org , select(m_valid,org , Vec3vf(neg_inf))); + } + + m_active &= (numOctantRays == (8 * sizeof(size_t))) ? 
(size_t)-1 : (((size_t)1 << numOctantRays)-1); + + + const Vec3fa reduced_min_rdir(reduce_min(tmp_min_rdir.x), + reduce_min(tmp_min_rdir.y), + reduce_min(tmp_min_rdir.z)); + + const Vec3fa reduced_max_rdir(reduce_max(tmp_max_rdir.x), + reduce_max(tmp_max_rdir.y), + reduce_max(tmp_max_rdir.z)); + + const Vec3fa reduced_min_origin(reduce_min(tmp_min_org.x), + reduce_min(tmp_min_org.y), + reduce_min(tmp_min_org.z)); + + const Vec3fa reduced_max_origin(reduce_max(tmp_max_org.x), + reduce_max(tmp_max_org.y), + reduce_max(tmp_max_org.z)); + + commonOctant = + (reduced_max_rdir.x < 0.0f || reduced_min_rdir.x >= 0.0f) && + (reduced_max_rdir.y < 0.0f || reduced_min_rdir.y >= 0.0f) && + (reduced_max_rdir.z < 0.0f || reduced_min_rdir.z >= 0.0f); + + const float frustum_min_dist = reduce_min(tmp_min_dist); + const float frustum_max_dist = reduce_max(tmp_max_dist); + + frustum.init(reduced_min_origin, reduced_max_origin, + reduced_min_rdir, reduced_max_rdir, + frustum_min_dist, frustum_max_dist, + N); + + return m_active; + } + + template + __forceinline static size_t intersectAABBNodePacket(size_t m_active, + const TravRayKStream* packets, + const AABBNode* __restrict__ node, + size_t boxID, + const NearFarPrecalculations& nf) + { + assert(m_active); + const size_t startPacketID = bsf(m_active) / K; + const size_t endPacketID = bsr(m_active) / K; + size_t m_trav_active = 0; + for (size_t i = startPacketID; i <= endPacketID; i++) + { + const size_t m_hit = intersectNodeK(node, boxID, packets[i], nf); + m_trav_active |= m_hit << (i*K); + } + return m_trav_active; + } + + template + __forceinline static size_t traverseCoherentStream(size_t m_active, + TravRayKStream* packets, + const AABBNode* __restrict__ node, + const Frustum& frustum, + size_t* maskK, + vfloat& dist) + { + size_t m_node_hit = intersectNodeFrustum(node, frustum, dist); + const size_t first_index = bsf(m_active); + const size_t first_packetID = first_index / K; + const size_t first_rayID = first_index % K; + size_t m_first_hit = intersectNode1(node, packets[first_packetID], first_rayID, frustum.nf); + + /* this make traversal independent of the ordering of rays */ + size_t m_node = m_node_hit ^ m_first_hit; + while (unlikely(m_node)) + { + const size_t boxID = bscf(m_node); + const size_t m_current = m_active & intersectAABBNodePacket(m_active, packets, node, boxID, frustum.nf); + m_node_hit ^= m_current ? 
(size_t)0 : ((size_t)1 << boxID); + maskK[boxID] = m_current; + } + return m_node_hit; + } + + // TODO: explicit 16-wide path for KNL + template + __forceinline static vint traverseIncoherentStream(size_t m_active, + TravRayKStreamFast* __restrict__ packets, + const AABBNode* __restrict__ node, + const NearFarPrecalculations& nf, + const int shiftTable[32]) + { + const vfloat bminX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearX)); + const vfloat bminY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearY)); + const vfloat bminZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearZ)); + const vfloat bmaxX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farX)); + const vfloat bmaxY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farY)); + const vfloat bmaxZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farZ)); + assert(m_active); + vint vmask(zero); + do + { + STAT3(shadow.trav_nodes,1,1,1); + const size_t rayID = bscf(m_active); + assert(rayID < MAX_INTERNAL_STREAM_SIZE); + TravRayKStream &p = packets[rayID / K]; + const size_t i = rayID % K; + const vint bitmask(shiftTable[rayID]); + const vfloat tNearX = msub(bminX, p.rdir.x[i], p.org_rdir.x[i]); + const vfloat tNearY = msub(bminY, p.rdir.y[i], p.org_rdir.y[i]); + const vfloat tNearZ = msub(bminZ, p.rdir.z[i], p.org_rdir.z[i]); + const vfloat tFarX = msub(bmaxX, p.rdir.x[i], p.org_rdir.x[i]); + const vfloat tFarY = msub(bmaxY, p.rdir.y[i], p.org_rdir.y[i]); + const vfloat tFarZ = msub(bmaxZ, p.rdir.z[i], p.org_rdir.z[i]); + const vfloat tNear = maxi(tNearX, tNearY, tNearZ, vfloat(p.tnear[i])); + const vfloat tFar = mini(tFarX , tFarY , tFarZ, vfloat(p.tfar[i])); + +#if defined(__AVX512ER__) + const vboolx m_node((1 << N)-1); + const vbool hit_mask = le(m_node, tNear, tFar); + vmask = mask_or(hit_mask, vmask, vmask, bitmask); +#else + const vbool hit_mask = tNear <= tFar; +#if defined(__AVX2__) + vmask = vmask | (bitmask & vint(hit_mask)); +#else + vmask = select(hit_mask, vmask | bitmask, vmask); +#endif +#endif + } while(m_active); + return vmask; + } + + template + __forceinline static vint traverseIncoherentStream(size_t m_active, + TravRayKStreamRobust* __restrict__ packets, + const AABBNode* __restrict__ node, + const NearFarPrecalculations& nf, + const int shiftTable[32]) + { + const vfloat bminX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearX)); + const vfloat bminY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearY)); + const vfloat bminZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearZ)); + const vfloat bmaxX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farX)); + const vfloat bmaxY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farY)); + const vfloat bmaxZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farZ)); + assert(m_active); + vint vmask(zero); + do + { + STAT3(shadow.trav_nodes,1,1,1); + const size_t rayID = bscf(m_active); + assert(rayID < MAX_INTERNAL_STREAM_SIZE); + TravRayKStream &p = packets[rayID / K]; + const size_t i = rayID % K; + const vint bitmask(shiftTable[rayID]); + const vfloat tNearX = (bminX - p.org.x[i]) * p.rdir.x[i]; + const vfloat tNearY = (bminY - p.org.y[i]) * p.rdir.y[i]; + const vfloat tNearZ = (bminZ - p.org.z[i]) * p.rdir.z[i]; + const vfloat tFarX = (bmaxX - p.org.x[i]) * p.rdir.x[i]; + const vfloat tFarY = (bmaxY - p.org.y[i]) * p.rdir.y[i]; + const vfloat tFarZ = (bmaxZ - p.org.z[i]) * p.rdir.z[i]; + const vfloat tNear = maxi(tNearX, 
tNearY, tNearZ, vfloat(p.tnear[i])); + const vfloat tFar = mini(tFarX , tFarY , tFarZ, vfloat(p.tfar[i])); + const float round_down = 1.0f-2.0f*float(ulp); + const float round_up = 1.0f+2.0f*float(ulp); +#if defined(__AVX512ER__) + const vboolx m_node((1 << N)-1); + const vbool hit_mask = le(m_node, round_down*tNear, round_up*tFar); + vmask = mask_or(hit_mask, vmask, vmask, bitmask); +#else + const vbool hit_mask = round_down*tNear <= round_up*tFar; +#if defined(__AVX2__) + vmask = vmask | (bitmask & vint(hit_mask)); +#else + vmask = select(hit_mask, vmask | bitmask, vmask); +#endif +#endif + } while(m_active); + return vmask; + } + + + static const size_t stackSizeSingle = 1+(N-1)*BVH::maxDepth; + + public: + static void intersect(Accel::Intersectors* This, RayHitN** inputRays, size_t numRays, IntersectContext* context); + static void occluded (Accel::Intersectors* This, RayN** inputRays, size_t numRays, IntersectContext* context); + + private: + template + static void intersectCoherent(Accel::Intersectors* This, RayHitK** inputRays, size_t numRays, IntersectContext* context); + + template + static void occludedCoherent(Accel::Intersectors* This, RayK** inputRays, size_t numRays, IntersectContext* context); + + template + static void occludedIncoherent(Accel::Intersectors* This, RayK** inputRays, size_t numRays, IntersectContext* context); + }; + + + /*! BVH ray stream intersector with direct fallback to packets. */ + template + class BVHNIntersectorStreamPacketFallback + { + public: + static void intersect(Accel::Intersectors* This, RayHitN** inputRays, size_t numRays, IntersectContext* context); + static void occluded (Accel::Intersectors* This, RayN** inputRays, size_t numRays, IntersectContext* context); + + private: + template + static void intersectK(Accel::Intersectors* This, RayHitK** inputRays, size_t numRays, IntersectContext* context); + + template + static void occludedK(Accel::Intersectors* This, RayK** inputRays, size_t numRays, IntersectContext* context); + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_intersector_stream_filters.h b/thirdparty/embree/kernels/bvh/bvh_intersector_stream_filters.h new file mode 100644 index 000000000000..cdeb9236370e --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_intersector_stream_filters.h @@ -0,0 +1,41 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "../common/ray.h" +#include "../common/scene.h" + +namespace embree +{ + namespace isa + { + class RayStreamFilter + { + public: + static void intersectAOS(Scene* scene, RTCRayHit* rays, size_t N, size_t stride, IntersectContext* context); + static void intersectAOP(Scene* scene, RTCRayHit** rays, size_t N, IntersectContext* context); + static void intersectSOA(Scene* scene, char* rays, size_t N, size_t numPackets, size_t stride, IntersectContext* context); + static void intersectSOP(Scene* scene, const RTCRayHitNp* rays, size_t N, IntersectContext* context); + + static void occludedAOS(Scene* scene, RTCRay* rays, size_t N, size_t stride, IntersectContext* context); + static void occludedAOP(Scene* scene, RTCRay** rays, size_t N, IntersectContext* context); + static void occludedSOA(Scene* scene, char* rays, size_t N, size_t numPackets, size_t stride, IntersectContext* context); + static void occludedSOP(Scene* scene, const RTCRayNp* rays, size_t N, IntersectContext* context); + + private: + template + static void filterAOS(Scene* scene, void* rays, size_t N, size_t stride, 
IntersectContext* context); + + template + static void filterAOP(Scene* scene, void** rays, size_t N, IntersectContext* context); + + template + static void filterSOA(Scene* scene, char* rays, size_t N, size_t numPackets, size_t stride, IntersectContext* context); + + template + static void filterSOP(Scene* scene, const void* rays, size_t N, IntersectContext* context); + }; + } +}; diff --git a/thirdparty/embree/kernels/bvh/bvh_node_aabb.h b/thirdparty/embree/kernels/bvh/bvh_node_aabb.h new file mode 100644 index 000000000000..baa4a8d80539 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_aabb.h @@ -0,0 +1,213 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_node_base.h" + +namespace embree +{ + /*! BVHN AABBNode */ + template + struct AABBNode_t : public BaseNode_t + { + using BaseNode_t::children; + + struct Create + { + __forceinline NodeRef operator() (const FastAllocator::CachedAllocator& alloc, size_t numChildren = 0) const + { + AABBNode_t* node = (AABBNode_t*) alloc.malloc0(sizeof(AABBNode_t),NodeRef::byteNodeAlignment); node->clear(); + return NodeRef::encodeNode(node); + } + }; + + struct Set + { + __forceinline void operator() (NodeRef node, size_t i, NodeRef child, const BBox3fa& bounds) const { + node.getAABBNode()->setRef(i,child); + node.getAABBNode()->setBounds(i,bounds); + } + }; + + struct Create2 + { + template + __forceinline NodeRef operator() (BuildRecord* children, const size_t num, const FastAllocator::CachedAllocator& alloc) const + { + AABBNode_t* node = (AABBNode_t*) alloc.malloc0(sizeof(AABBNode_t), NodeRef::byteNodeAlignment); node->clear(); + for (size_t i=0; isetBounds(i,children[i].bounds()); + return NodeRef::encodeNode(node); + } + }; + + struct Set2 + { + template + __forceinline NodeRef operator() (const BuildRecord& precord, const BuildRecord* crecords, NodeRef ref, NodeRef* children, const size_t num) const + { + AABBNode_t* node = ref.getAABBNode(); + for (size_t i=0; isetRef(i,children[i]); + return ref; + } + }; + + struct Set3 + { + Set3 (FastAllocator* allocator, PrimRef* prims) + : allocator(allocator), prims(prims) {} + + template + __forceinline NodeRef operator() (const BuildRecord& precord, const BuildRecord* crecords, NodeRef ref, NodeRef* children, const size_t num) const + { + AABBNode_t* node = ref.getAABBNode(); + for (size_t i=0; isetRef(i,children[i]); + + if (unlikely(precord.alloc_barrier)) + { + PrimRef* begin = &prims[precord.prims.begin()]; + PrimRef* end = &prims[precord.prims.end()]; // FIXME: extended end for spatial split builder!!!!! + size_t bytes = (size_t)end - (size_t)begin; + allocator->addBlock(begin,bytes); + } + + return ref; + } + + FastAllocator* const allocator; + PrimRef* const prims; + }; + + /*! Clears the node. */ + __forceinline void clear() { + lower_x = lower_y = lower_z = pos_inf; + upper_x = upper_y = upper_z = neg_inf; + BaseNode_t::clear(); + } + + /*! Sets bounding box and ID of child. */ + __forceinline void setRef(size_t i, const NodeRef& ref) { + assert(i < N); + children[i] = ref; + } + + /*! Sets bounding box of child. */ + __forceinline void setBounds(size_t i, const BBox3fa& bounds) + { + assert(i < N); + lower_x[i] = bounds.lower.x; lower_y[i] = bounds.lower.y; lower_z[i] = bounds.lower.z; + upper_x[i] = bounds.upper.x; upper_y[i] = bounds.upper.y; upper_z[i] = bounds.upper.z; + } + + /*! Sets bounding box and ID of child. 
*/ + __forceinline void set(size_t i, const NodeRef& ref, const BBox3fa& bounds) { + setBounds(i,bounds); + children[i] = ref; + } + + /*! Returns bounds of node. */ + __forceinline BBox3fa bounds() const { + const Vec3fa lower(reduce_min(lower_x),reduce_min(lower_y),reduce_min(lower_z)); + const Vec3fa upper(reduce_max(upper_x),reduce_max(upper_y),reduce_max(upper_z)); + return BBox3fa(lower,upper); + } + + /*! Returns bounds of specified child. */ + __forceinline BBox3fa bounds(size_t i) const + { + assert(i < N); + const Vec3fa lower(lower_x[i],lower_y[i],lower_z[i]); + const Vec3fa upper(upper_x[i],upper_y[i],upper_z[i]); + return BBox3fa(lower,upper); + } + + /*! Returns extent of bounds of specified child. */ + __forceinline Vec3fa extend(size_t i) const { + return bounds(i).size(); + } + + /*! Returns bounds of all children (implemented later as specializations) */ + __forceinline void bounds(BBox& bounds0, BBox& bounds1, BBox& bounds2, BBox& bounds3) const; + + /*! swap two children of the node */ + __forceinline void swap(size_t i, size_t j) + { + assert(ichildren[i],b->children[j]); + std::swap(a->lower_x[i],b->lower_x[j]); + std::swap(a->lower_y[i],b->lower_y[j]); + std::swap(a->lower_z[i],b->lower_z[j]); + std::swap(a->upper_x[i],b->upper_x[j]); + std::swap(a->upper_y[i],b->upper_y[j]); + std::swap(a->upper_z[i],b->upper_z[j]); + } + + /*! compacts a node (moves empty children to the end) */ + __forceinline static void compact(AABBNode_t* a) + { + /* find right most filled node */ + ssize_t j=N; + for (j=j-1; j>=0; j--) + if (a->child(j) != NodeRef::emptyNode) + break; + + /* replace empty nodes with filled nodes */ + for (ssize_t i=0; ichild(i) == NodeRef::emptyNode) { + a->swap(i,j); + for (j=j-1; j>i; j--) + if (a->child(j) != NodeRef::emptyNode) + break; + } + } + } + + /*! Returns reference to specified child */ + __forceinline NodeRef& child(size_t i) { assert(i lower_x; //!< X dimension of lower bounds of all N children. + vfloat upper_x; //!< X dimension of upper bounds of all N children. + vfloat lower_y; //!< Y dimension of lower bounds of all N children. + vfloat upper_y; //!< Y dimension of upper bounds of all N children. + vfloat lower_z; //!< Z dimension of lower bounds of all N children. + vfloat upper_z; //!< Z dimension of upper bounds of all N children. + }; + + template<> + __forceinline void AABBNode_t,4>::bounds(BBox& bounds0, BBox& bounds1, BBox& bounds2, BBox& bounds3) const { + transpose(lower_x,lower_y,lower_z,vfloat4(zero),bounds0.lower,bounds1.lower,bounds2.lower,bounds3.lower); + transpose(upper_x,upper_y,upper_z,vfloat4(zero),bounds0.upper,bounds1.upper,bounds2.upper,bounds3.upper); + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_node_aabb_mb.h b/thirdparty/embree/kernels/bvh/bvh_node_aabb_mb.h new file mode 100644 index 000000000000..501f4bce5bde --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_aabb_mb.h @@ -0,0 +1,247 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_node_base.h" + +namespace embree +{ + /*! 
Motion Blur AABBNode */ + template + struct AABBNodeMB_t : public BaseNode_t + { + using BaseNode_t::children; + typedef BVHNodeRecord NodeRecord; + typedef BVHNodeRecordMB NodeRecordMB; + typedef BVHNodeRecordMB4D NodeRecordMB4D; + + struct Create + { + template + __forceinline NodeRef operator() (BuildRecord* children, const size_t num, const FastAllocator::CachedAllocator& alloc) const + { + AABBNodeMB_t* node = (AABBNodeMB_t*) alloc.malloc0(sizeof(AABBNodeMB_t),NodeRef::byteNodeAlignment); node->clear(); + return NodeRef::encodeNode(node); + } + }; + + struct Set + { + template + __forceinline NodeRecordMB operator() (const BuildRecord& precord, const BuildRecord* crecords, NodeRef ref, NodeRecordMB* children, const size_t num) const + { + AABBNodeMB_t* node = ref.getAABBNodeMB(); + + LBBox3fa bounds = empty; + for (size_t i=0; isetRef(i,children[i].ref); + node->setBounds(i,children[i].lbounds); + bounds.extend(children[i].lbounds); + } + return NodeRecordMB(ref,bounds); + } + }; + + struct SetTimeRange + { + __forceinline SetTimeRange(BBox1f tbounds) : tbounds(tbounds) {} + + template + __forceinline NodeRecordMB operator() (const BuildRecord& precord, const BuildRecord* crecords, NodeRef ref, NodeRecordMB* children, const size_t num) const + { + AABBNodeMB_t* node = ref.getAABBNodeMB(); + + LBBox3fa bounds = empty; + for (size_t i=0; isetRef(i, children[i].ref); + node->setBounds(i, children[i].lbounds, tbounds); + bounds.extend(children[i].lbounds); + } + return NodeRecordMB(ref,bounds); + } + + BBox1f tbounds; + }; + + /*! Clears the node. */ + __forceinline void clear() { + lower_x = lower_y = lower_z = vfloat(pos_inf); + upper_x = upper_y = upper_z = vfloat(neg_inf); + lower_dx = lower_dy = lower_dz = vfloat(0.0f); + upper_dx = upper_dy = upper_dz = vfloat(0.0f); + BaseNode_t::clear(); + } + + /*! Sets ID of child. */ + __forceinline void setRef(size_t i, NodeRef ref) { + children[i] = ref; + } + + /*! Sets bounding box of child. */ + __forceinline void setBounds(size_t i, const BBox3fa& bounds0_i, const BBox3fa& bounds1_i) + { + /*! for empty bounds we have to avoid inf-inf=nan */ + BBox3fa bounds0(min(bounds0_i.lower,Vec3fa(+FLT_MAX)),max(bounds0_i.upper,Vec3fa(-FLT_MAX))); + BBox3fa bounds1(min(bounds1_i.lower,Vec3fa(+FLT_MAX)),max(bounds1_i.upper,Vec3fa(-FLT_MAX))); + bounds0 = bounds0.enlarge_by(4.0f*float(ulp)); + bounds1 = bounds1.enlarge_by(4.0f*float(ulp)); + Vec3fa dlower = bounds1.lower-bounds0.lower; + Vec3fa dupper = bounds1.upper-bounds0.upper; + + lower_x[i] = bounds0.lower.x; lower_y[i] = bounds0.lower.y; lower_z[i] = bounds0.lower.z; + upper_x[i] = bounds0.upper.x; upper_y[i] = bounds0.upper.y; upper_z[i] = bounds0.upper.z; + + lower_dx[i] = dlower.x; lower_dy[i] = dlower.y; lower_dz[i] = dlower.z; + upper_dx[i] = dupper.x; upper_dy[i] = dupper.y; upper_dz[i] = dupper.z; + } + + /*! Sets bounding box of child. */ + __forceinline void setBounds(size_t i, const LBBox3fa& bounds) { + setBounds(i, bounds.bounds0, bounds.bounds1); + } + + /*! Sets bounding box of child. */ + __forceinline void setBounds(size_t i, const LBBox3fa& bounds, const BBox1f& tbounds) { + setBounds(i, bounds.global(tbounds)); + } + + /*! Sets bounding box and ID of child. */ + __forceinline void set(size_t i, NodeRef ref, const BBox3fa& bounds) { + lower_x[i] = bounds.lower.x; lower_y[i] = bounds.lower.y; lower_z[i] = bounds.lower.z; + upper_x[i] = bounds.upper.x; upper_y[i] = bounds.upper.y; upper_z[i] = bounds.upper.z; + children[i] = ref; + } + + /*! Sets bounding box and ID of child. 
*/ + __forceinline void set(size_t i, const NodeRecordMB4D& child) + { + setRef(i, child.ref); + setBounds(i, child.lbounds, child.dt); + } + + /*! Return bounding box for time 0 */ + __forceinline BBox3fa bounds0(size_t i) const { + return BBox3fa(Vec3fa(lower_x[i],lower_y[i],lower_z[i]), + Vec3fa(upper_x[i],upper_y[i],upper_z[i])); + } + + /*! Return bounding box for time 1 */ + __forceinline BBox3fa bounds1(size_t i) const { + return BBox3fa(Vec3fa(lower_x[i]+lower_dx[i],lower_y[i]+lower_dy[i],lower_z[i]+lower_dz[i]), + Vec3fa(upper_x[i]+upper_dx[i],upper_y[i]+upper_dy[i],upper_z[i]+upper_dz[i])); + } + + /*! Returns bounds of node. */ + __forceinline BBox3fa bounds() const { + return BBox3fa(Vec3fa(reduce_min(min(lower_x,lower_x+lower_dx)), + reduce_min(min(lower_y,lower_y+lower_dy)), + reduce_min(min(lower_z,lower_z+lower_dz))), + Vec3fa(reduce_max(max(upper_x,upper_x+upper_dx)), + reduce_max(max(upper_y,upper_y+upper_dy)), + reduce_max(max(upper_z,upper_z+upper_dz)))); + } + + /*! Return bounding box of child i */ + __forceinline BBox3fa bounds(size_t i) const { + return merge(bounds0(i),bounds1(i)); + } + + /*! Return linear bounding box of child i */ + __forceinline LBBox3fa lbounds(size_t i) const { + return LBBox3fa(bounds0(i),bounds1(i)); + } + + /*! Return bounding box of child i at specified time */ + __forceinline BBox3fa bounds(size_t i, float time) const { + return lerp(bounds0(i),bounds1(i),time); + } + + /*! Returns the expected surface area when randomly sampling the time. */ + __forceinline float expectedHalfArea(size_t i) const { + return lbounds(i).expectedHalfArea(); + } + + /*! Returns the expected surface area when randomly sampling the time. */ + __forceinline float expectedHalfArea(size_t i, const BBox1f& t0t1) const { + return lbounds(i).expectedHalfArea(t0t1); + } + + /*! swap two children of the node */ + __forceinline void swap(size_t i, size_t j) + { + assert(i=0; j--) + if (a->child(j) != NodeRef::emptyNode) + break; + + /* replace empty nodes with filled nodes */ + for (ssize_t i=0; ichild(i) == NodeRef::emptyNode) { + a->swap(i,j); + for (j=j-1; j>i; j--) + if (a->child(j) != NodeRef::emptyNode) + break; + } + } + } + + /*! Returns reference to specified child */ + __forceinline NodeRef& child(size_t i) { assert(i lower_x; //!< X dimension of lower bounds of all N children. + vfloat upper_x; //!< X dimension of upper bounds of all N children. + vfloat lower_y; //!< Y dimension of lower bounds of all N children. + vfloat upper_y; //!< Y dimension of upper bounds of all N children. + vfloat lower_z; //!< Z dimension of lower bounds of all N children. + vfloat upper_z; //!< Z dimension of upper bounds of all N children. + + vfloat lower_dx; //!< X dimension of lower bounds of all N children. + vfloat upper_dx; //!< X dimension of upper bounds of all N children. + vfloat lower_dy; //!< Y dimension of lower bounds of all N children. + vfloat upper_dy; //!< Y dimension of upper bounds of all N children. + vfloat lower_dz; //!< Z dimension of lower bounds of all N children. + vfloat upper_dz; //!< Z dimension of upper bounds of all N children. 
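+      /*! Note: lower_dx..upper_dz hold per-child deltas rather than absolute values: the time-1 bounds of child i are the time-0 bounds plus these deltas (e.g. upper_x[i] + upper_dx[i]), and bounds(i,time) interpolates linearly between the two endpoints. */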
+ }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_node_aabb_mb4d.h b/thirdparty/embree/kernels/bvh/bvh_node_aabb_mb4d.h new file mode 100644 index 000000000000..e968bbbc399f --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_aabb_mb4d.h @@ -0,0 +1,107 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_node_aabb_mb.h" + +namespace embree +{ + /*! Aligned 4D Motion Blur Node */ + template + struct AABBNodeMB4D_t : public AABBNodeMB_t + { + using BaseNode_t::children; + using AABBNodeMB_t::set; + + typedef BVHNodeRecord NodeRecord; + typedef BVHNodeRecordMB NodeRecordMB; + typedef BVHNodeRecordMB4D NodeRecordMB4D; + + struct Create + { + template + __forceinline NodeRef operator() (BuildRecord*, const size_t, const FastAllocator::CachedAllocator& alloc, bool hasTimeSplits = true) const + { + if (hasTimeSplits) + { + AABBNodeMB4D_t* node = (AABBNodeMB4D_t*) alloc.malloc0(sizeof(AABBNodeMB4D_t),NodeRef::byteNodeAlignment); node->clear(); + return NodeRef::encodeNode(node); + } + else + { + AABBNodeMB_t* node = (AABBNodeMB_t*) alloc.malloc0(sizeof(AABBNodeMB_t),NodeRef::byteNodeAlignment); node->clear(); + return NodeRef::encodeNode(node); + } + } + }; + + struct Set + { + template + __forceinline void operator() (const BuildRecord&, const BuildRecord*, NodeRef ref, NodeRecordMB4D* children, const size_t num) const + { + if (likely(ref.isAABBNodeMB())) { + for (size_t i=0; iset(i, children[i]); + } else { + for (size_t i=0; iset(i, children[i]); + } + } + }; + + /*! Clears the node. */ + __forceinline void clear() { + lower_t = vfloat(pos_inf); + upper_t = vfloat(neg_inf); + AABBNodeMB_t::clear(); + } + + /*! Sets bounding box of child. */ + __forceinline void setBounds(size_t i, const LBBox3fa& bounds, const BBox1f& tbounds) + { + AABBNodeMB_t::setBounds(i, bounds.global(tbounds)); + lower_t[i] = tbounds.lower; + upper_t[i] = tbounds.upper == 1.0f ? 1.0f+float(ulp) : tbounds.upper; + } + + /*! Sets bounding box and ID of child. */ + __forceinline void set(size_t i, const NodeRecordMB4D& child) { + AABBNodeMB_t::setRef(i,child.ref); + setBounds(i, child.lbounds, child.dt); + } + + /*! Returns the expected surface area when randomly sampling the time. */ + __forceinline float expectedHalfArea(size_t i) const { + return AABBNodeMB_t::lbounds(i).expectedHalfArea(timeRange(i)); + } + + /*! returns time range for specified child */ + __forceinline BBox1f timeRange(size_t i) const { + return BBox1f(lower_t[i],upper_t[i]); + } + + /*! stream output operator */ + friend embree_ostream operator<<(embree_ostream cout, const AABBNodeMB4D_t& n) + { + cout << "AABBNodeMB4D {" << embree_endl; + for (size_t i=0; i lower_t; //!< time dimension of lower bounds of all N children + vfloat upper_t; //!< time dimension of upper bounds of all N children + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_node_base.h b/thirdparty/embree/kernels/bvh/bvh_node_base.h new file mode 100644 index 000000000000..8268f3b93289 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_base.h @@ -0,0 +1,43 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_node_ref.h" + +namespace embree +{ + + /*! BVHN Base Node */ + template + struct BaseNode_t + { + /*! Clears the node. 
*/ + __forceinline void clear() + { + for (size_t i=0; i + struct OBBNode_t : public BaseNode_t + { + using BaseNode_t::children; + + struct Create + { + __forceinline NodeRef operator() (const FastAllocator::CachedAllocator& alloc) const + { + OBBNode_t* node = (OBBNode_t*) alloc.malloc0(sizeof(OBBNode_t),NodeRef::byteNodeAlignment); node->clear(); + return NodeRef::encodeNode(node); + } + }; + + struct Set + { + __forceinline void operator() (NodeRef node, size_t i, NodeRef child, const OBBox3fa& bounds) const { + node.ungetAABBNode()->setRef(i,child); + node.ungetAABBNode()->setBounds(i,bounds); + } + }; + + /*! Clears the node. */ + __forceinline void clear() + { + naabb.l.vx = Vec3fa(nan); + naabb.l.vy = Vec3fa(nan); + naabb.l.vz = Vec3fa(nan); + naabb.p = Vec3fa(nan); + BaseNode_t::clear(); + } + + /*! Sets bounding box. */ + __forceinline void setBounds(size_t i, const OBBox3fa& b) + { + assert(i < N); + + AffineSpace3fa space = b.space; + space.p -= b.bounds.lower; + space = AffineSpace3fa::scale(1.0f/max(Vec3fa(1E-19f),b.bounds.upper-b.bounds.lower))*space; + + naabb.l.vx.x[i] = space.l.vx.x; + naabb.l.vx.y[i] = space.l.vx.y; + naabb.l.vx.z[i] = space.l.vx.z; + + naabb.l.vy.x[i] = space.l.vy.x; + naabb.l.vy.y[i] = space.l.vy.y; + naabb.l.vy.z[i] = space.l.vy.z; + + naabb.l.vz.x[i] = space.l.vz.x; + naabb.l.vz.y[i] = space.l.vz.y; + naabb.l.vz.z[i] = space.l.vz.z; + + naabb.p.x[i] = space.p.x; + naabb.p.y[i] = space.p.y; + naabb.p.z[i] = space.p.z; + } + + /*! Sets ID of child. */ + __forceinline void setRef(size_t i, const NodeRef& ref) { + assert(i < N); + children[i] = ref; + } + + /*! Returns the extent of the bounds of the ith child */ + __forceinline Vec3fa extent(size_t i) const { + assert(i naabb; //!< non-axis aligned bounding boxes (bounds are [0,1] in specified space) + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_node_obb_mb.h b/thirdparty/embree/kernels/bvh/bvh_node_obb_mb.h new file mode 100644 index 000000000000..834cf5ec28b6 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_obb_mb.h @@ -0,0 +1,90 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_node_base.h" + +namespace embree +{ + template + struct OBBNodeMB_t : public BaseNode_t + { + using BaseNode_t::children; + + struct Create + { + __forceinline NodeRef operator() (const FastAllocator::CachedAllocator& alloc) const + { + OBBNodeMB_t* node = (OBBNodeMB_t*) alloc.malloc0(sizeof(OBBNodeMB_t),NodeRef::byteNodeAlignment); node->clear(); + return NodeRef::encodeNode(node); + } + }; + + struct Set + { + __forceinline void operator() (NodeRef node, size_t i, NodeRef child, const LinearSpace3fa& space, const LBBox3fa& lbounds, const BBox1f dt) const { + node.ungetAABBNodeMB()->setRef(i,child); + node.ungetAABBNodeMB()->setBounds(i,space,lbounds.global(dt)); + } + }; + + /*! Clears the node. */ + __forceinline void clear() + { + space0 = one; + //b0.lower = b0.upper = Vec3fa(nan); + b1.lower = b1.upper = Vec3fa(nan); + BaseNode_t::clear(); + } + + /*! Sets space and bounding boxes. */ + __forceinline void setBounds(size_t i, const AffineSpace3fa& space, const LBBox3fa& lbounds) { + setBounds(i,space,lbounds.bounds0,lbounds.bounds1); + } + + /*! Sets space and bounding boxes. 
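+        (the stored space maps the child into a normalized frame in which the time-0 box a becomes the unit box; b1 keeps the time-1 box c expressed in that same frame)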
*/ + __forceinline void setBounds(size_t i, const AffineSpace3fa& s0, const BBox3fa& a, const BBox3fa& c) + { + assert(i < N); + + AffineSpace3fa space = s0; + space.p -= a.lower; + Vec3fa scale = 1.0f/max(Vec3fa(1E-19f),a.upper-a.lower); + space = AffineSpace3fa::scale(scale)*space; + BBox3fa a1((a.lower-a.lower)*scale,(a.upper-a.lower)*scale); + BBox3fa c1((c.lower-a.lower)*scale,(c.upper-a.lower)*scale); + + space0.l.vx.x[i] = space.l.vx.x; space0.l.vx.y[i] = space.l.vx.y; space0.l.vx.z[i] = space.l.vx.z; + space0.l.vy.x[i] = space.l.vy.x; space0.l.vy.y[i] = space.l.vy.y; space0.l.vy.z[i] = space.l.vy.z; + space0.l.vz.x[i] = space.l.vz.x; space0.l.vz.y[i] = space.l.vz.y; space0.l.vz.z[i] = space.l.vz.z; + space0.p .x[i] = space.p .x; space0.p .y[i] = space.p .y; space0.p .z[i] = space.p .z; + + /*b0.lower.x[i] = a1.lower.x; b0.lower.y[i] = a1.lower.y; b0.lower.z[i] = a1.lower.z; + b0.upper.x[i] = a1.upper.x; b0.upper.y[i] = a1.upper.y; b0.upper.z[i] = a1.upper.z;*/ + + b1.lower.x[i] = c1.lower.x; b1.lower.y[i] = c1.lower.y; b1.lower.z[i] = c1.lower.z; + b1.upper.x[i] = c1.upper.x; b1.upper.y[i] = c1.upper.y; b1.upper.z[i] = c1.upper.z; + } + + /*! Sets ID of child. */ + __forceinline void setRef(size_t i, const NodeRef& ref) { + assert(i < N); + children[i] = ref; + } + + /*! Returns the extent of the bounds of the ith child */ + __forceinline Vec3fa extent0(size_t i) const { + assert(i < N); + const Vec3fa vx(space0.l.vx.x[i],space0.l.vx.y[i],space0.l.vx.z[i]); + const Vec3fa vy(space0.l.vy.x[i],space0.l.vy.y[i],space0.l.vy.z[i]); + const Vec3fa vz(space0.l.vz.x[i],space0.l.vz.y[i],space0.l.vz.z[i]); + return rsqrt(vx*vx + vy*vy + vz*vz); + } + + public: + AffineSpace3vf space0; + //BBox3vf b0; // these are the unit bounds + BBox3vf b1; + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_node_qaabb.h b/thirdparty/embree/kernels/bvh/bvh_node_qaabb.h new file mode 100644 index 000000000000..5212821f3fe1 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_qaabb.h @@ -0,0 +1,265 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh_node_base.h" + +namespace embree +{ + /*! BVHN Quantized Node */ + template + struct __aligned(8) QuantizedBaseNode_t + { + typedef unsigned char T; + static const T MIN_QUAN = 0; + static const T MAX_QUAN = 255; + + /*! Clears the node. */ + __forceinline void clear() { + for (size_t i=0; i &lower, + const vfloat &upper, + T lower_quant[N], + T upper_quant[N], + float &start, + float &scale) + { + /* quantize bounds */ + const vbool m_valid = lower != vfloat(pos_inf); + const float minF = reduce_min(lower); + const float maxF = reduce_max(upper); + float diff = (1.0f+2.0f*float(ulp))*(maxF - minF); + float decode_scale = diff / float(MAX_QUAN); + if (decode_scale == 0.0f) decode_scale = 2.0f*FLT_MIN; // result may have been flushed to zero + assert(madd(decode_scale,float(MAX_QUAN),minF) >= maxF); + const float encode_scale = diff > 0 ? 
(float(MAX_QUAN) / diff) : 0.0f; + vint ilower = max(vint(floor((lower - vfloat(minF))*vfloat(encode_scale))),MIN_QUAN); + vint iupper = min(vint(ceil ((upper - vfloat(minF))*vfloat(encode_scale))),MAX_QUAN); + + /* lower/upper correction */ + vbool m_lower_correction = (madd(vfloat(ilower),decode_scale,minF)) > lower; + vbool m_upper_correction = (madd(vfloat(iupper),decode_scale,minF)) < upper; + ilower = max(select(m_lower_correction,ilower-1,ilower),MIN_QUAN); + iupper = min(select(m_upper_correction,iupper+1,iupper),MAX_QUAN); + + /* disable invalid lanes */ + ilower = select(m_valid,ilower,MAX_QUAN); + iupper = select(m_valid,iupper,MIN_QUAN); + + /* store as uchar to memory */ + vint::store(lower_quant,ilower); + vint::store(upper_quant,iupper); + start = minF; + scale = decode_scale; + +#if defined(DEBUG) + vfloat extract_lower( vint::loadu(lower_quant) ); + vfloat extract_upper( vint::loadu(upper_quant) ); + vfloat final_extract_lower = madd(extract_lower,decode_scale,minF); + vfloat final_extract_upper = madd(extract_upper,decode_scale,minF); + assert( (movemask(final_extract_lower <= lower ) & movemask(m_valid)) == movemask(m_valid)); + assert( (movemask(final_extract_upper >= upper ) & movemask(m_valid)) == movemask(m_valid)); +#endif + } + + __forceinline void init_dim(AABBNode_t,N>& node) + { + init_dim(node.lower_x,node.upper_x,lower_x,upper_x,start.x,scale.x); + init_dim(node.lower_y,node.upper_y,lower_y,upper_y,start.y,scale.y); + init_dim(node.lower_z,node.upper_z,lower_z,upper_z,start.z,scale.z); + } + + __forceinline vbool validMask() const { return vint::loadu(lower_x) <= vint::loadu(upper_x); } + +#if defined(__AVX512F__) // KNL + __forceinline vbool16 validMask16() const { return le(0xff,vint<16>::loadu(lower_x),vint<16>::loadu(upper_x)); } +#endif + __forceinline vfloat dequantizeLowerX() const { return madd(vfloat(vint::loadu(lower_x)),scale.x,vfloat(start.x)); } + + __forceinline vfloat dequantizeUpperX() const { return madd(vfloat(vint::loadu(upper_x)),scale.x,vfloat(start.x)); } + + __forceinline vfloat dequantizeLowerY() const { return madd(vfloat(vint::loadu(lower_y)),scale.y,vfloat(start.y)); } + + __forceinline vfloat dequantizeUpperY() const { return madd(vfloat(vint::loadu(upper_y)),scale.y,vfloat(start.y)); } + + __forceinline vfloat dequantizeLowerZ() const { return madd(vfloat(vint::loadu(lower_z)),scale.z,vfloat(start.z)); } + + __forceinline vfloat dequantizeUpperZ() const { return madd(vfloat(vint::loadu(upper_z)),scale.z,vfloat(start.z)); } + + template + __forceinline vfloat dequantize(const size_t offset) const { return vfloat(vint::loadu(all_planes+offset)); } + +#if defined(__AVX512F__) + __forceinline vfloat16 dequantizeLowerUpperX(const vint16 &p) const { return madd(vfloat16(permute(vint<16>::loadu(lower_x),p)),scale.x,vfloat16(start.x)); } + __forceinline vfloat16 dequantizeLowerUpperY(const vint16 &p) const { return madd(vfloat16(permute(vint<16>::loadu(lower_y),p)),scale.y,vfloat16(start.y)); } + __forceinline vfloat16 dequantizeLowerUpperZ(const vint16 &p) const { return madd(vfloat16(permute(vint<16>::loadu(lower_z),p)),scale.z,vfloat16(start.z)); } +#endif + + union { + struct { + T lower_x[N]; //!< 8bit discretized X dimension of lower bounds of all N children + T upper_x[N]; //!< 8bit discretized X dimension of upper bounds of all N children + T lower_y[N]; //!< 8bit discretized Y dimension of lower bounds of all N children + T upper_y[N]; //!< 8bit discretized Y dimension of upper bounds of all N children + T lower_z[N]; //!< 8bit 
discretized Z dimension of lower bounds of all N children + T upper_z[N]; //!< 8bit discretized Z dimension of upper bounds of all N children + }; + T all_planes[6*N]; + }; + + Vec3f start; + Vec3f scale; + + friend embree_ostream operator<<(embree_ostream o, const QuantizedBaseNode_t& n) + { + o << "QuantizedBaseNode { " << embree_endl; + o << " start " << n.start << embree_endl; + o << " scale " << n.scale << embree_endl; + o << " lower_x " << vuint::loadu(n.lower_x) << embree_endl; + o << " upper_x " << vuint::loadu(n.upper_x) << embree_endl; + o << " lower_y " << vuint::loadu(n.lower_y) << embree_endl; + o << " upper_y " << vuint::loadu(n.upper_y) << embree_endl; + o << " lower_z " << vuint::loadu(n.lower_z) << embree_endl; + o << " upper_z " << vuint::loadu(n.upper_z) << embree_endl; + o << "}" << embree_endl; + return o; + } + + }; + + template + struct __aligned(8) QuantizedNode_t : public BaseNode_t, QuantizedBaseNode_t + { + using BaseNode_t::children; + using QuantizedBaseNode_t::lower_x; + using QuantizedBaseNode_t::upper_x; + using QuantizedBaseNode_t::lower_y; + using QuantizedBaseNode_t::upper_y; + using QuantizedBaseNode_t::lower_z; + using QuantizedBaseNode_t::upper_z; + using QuantizedBaseNode_t::start; + using QuantizedBaseNode_t::scale; + using QuantizedBaseNode_t::init_dim; + + __forceinline void setRef(size_t i, const NodeRef& ref) { + assert(i < N); + children[i] = ref; + } + + struct Create2 + { + template + __forceinline NodeRef operator() (BuildRecord* children, const size_t n, const FastAllocator::CachedAllocator& alloc) const + { + __aligned(64) AABBNode_t node; + node.clear(); + for (size_t i=0; iinit(node); + + return (size_t)qnode | NodeRef::tyQuantizedNode; + } + }; + + struct Set2 + { + template + __forceinline NodeRef operator() (const BuildRecord& precord, const BuildRecord* crecords, NodeRef ref, NodeRef* children, const size_t num) const + { + QuantizedNode_t* node = ref.quantizedNode(); + for (size_t i=0; isetRef(i,children[i]); + return ref; + } + }; + + __forceinline void init(AABBNode_t& node) + { + for (size_t i=0;i + struct __aligned(8) QuantizedBaseNodeMB_t + { + QuantizedBaseNode_t node0; + QuantizedBaseNode_t node1; + + /*! Clears the node. */ + __forceinline void clear() { + node0.clear(); + node1.clear(); + } + + /*! Returns bounds of specified child. */ + __forceinline BBox3fa bounds(size_t i) const + { + assert(i < N); + BBox3fa bounds0 = node0.bounds(i); + BBox3fa bounds1 = node1.bounds(i); + bounds0.extend(bounds1); + return bounds0; + } + + /*! Returns extent of bounds of specified child. 
*/ + __forceinline Vec3fa extent(size_t i) const { + return bounds(i).size(); + } + + __forceinline vbool validMask() const { return node0.validMask(); } + + template + __forceinline vfloat dequantizeLowerX(const T t) const { return lerp(node0.dequantizeLowerX(),node1.dequantizeLowerX(),t); } + template + __forceinline vfloat dequantizeUpperX(const T t) const { return lerp(node0.dequantizeUpperX(),node1.dequantizeUpperX(),t); } + template + __forceinline vfloat dequantizeLowerY(const T t) const { return lerp(node0.dequantizeLowerY(),node1.dequantizeLowerY(),t); } + template + __forceinline vfloat dequantizeUpperY(const T t) const { return lerp(node0.dequantizeUpperY(),node1.dequantizeUpperY(),t); } + template + __forceinline vfloat dequantizeLowerZ(const T t) const { return lerp(node0.dequantizeLowerZ(),node1.dequantizeLowerZ(),t); } + template + __forceinline vfloat dequantizeUpperZ(const T t) const { return lerp(node0.dequantizeUpperZ(),node1.dequantizeUpperZ(),t); } + + + template + __forceinline vfloat dequantizeLowerX(const size_t i, const vfloat &t) const { return lerp(vfloat(node0.dequantizeLowerX()[i]),vfloat(node1.dequantizeLowerX()[i]),t); } + template + __forceinline vfloat dequantizeUpperX(const size_t i, const vfloat &t) const { return lerp(vfloat(node0.dequantizeUpperX()[i]),vfloat(node1.dequantizeUpperX()[i]),t); } + template + __forceinline vfloat dequantizeLowerY(const size_t i, const vfloat &t) const { return lerp(vfloat(node0.dequantizeLowerY()[i]),vfloat(node1.dequantizeLowerY()[i]),t); } + template + __forceinline vfloat dequantizeUpperY(const size_t i, const vfloat &t) const { return lerp(vfloat(node0.dequantizeUpperY()[i]),vfloat(node1.dequantizeUpperY()[i]),t); } + template + __forceinline vfloat dequantizeLowerZ(const size_t i, const vfloat &t) const { return lerp(vfloat(node0.dequantizeLowerZ()[i]),vfloat(node1.dequantizeLowerZ()[i]),t); } + template + __forceinline vfloat dequantizeUpperZ(const size_t i, const vfloat &t) const { return lerp(vfloat(node0.dequantizeUpperZ()[i]),vfloat(node1.dequantizeUpperZ()[i]),t); } + + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_node_ref.h b/thirdparty/embree/kernels/bvh/bvh_node_ref.h new file mode 100644 index 000000000000..5efc9c72c7c0 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_node_ref.h @@ -0,0 +1,242 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "../common/alloc.h" +#include "../common/accel.h" +#include "../common/device.h" +#include "../common/scene.h" +#include "../geometry/primitive.h" +#include "../common/ray.h" + +namespace embree +{ + /* BVH node reference with bounds */ + template + struct BVHNodeRecord + { + __forceinline BVHNodeRecord() {} + __forceinline BVHNodeRecord(NodeRef ref, const BBox3fa& bounds) : ref(ref), bounds((BBox3fx)bounds) {} + __forceinline BVHNodeRecord(NodeRef ref, const BBox3fx& bounds) : ref(ref), bounds(bounds) {} + + NodeRef ref; + BBox3fx bounds; + }; + + template + struct BVHNodeRecordMB + { + __forceinline BVHNodeRecordMB() {} + __forceinline BVHNodeRecordMB(NodeRef ref, const LBBox3fa& lbounds) : ref(ref), lbounds(lbounds) {} + + NodeRef ref; + LBBox3fa lbounds; + }; + + template + struct BVHNodeRecordMB4D + { + __forceinline BVHNodeRecordMB4D() {} + __forceinline BVHNodeRecordMB4D(NodeRef ref, const LBBox3fa& lbounds, const BBox1f& dt) : ref(ref), lbounds(lbounds), dt(dt) {} + + NodeRef ref; + LBBox3fa lbounds; + BBox1f dt; + }; + + template struct BaseNode_t; + 
template struct AABBNode_t; + template struct AABBNodeMB_t; + template struct AABBNodeMB4D_t; + template struct OBBNode_t; + template struct OBBNodeMB_t; + template struct QuantizedNode_t; + template struct QuantizedNodeMB_t; + + /*! Pointer that points to a node or a list of primitives */ + template + struct NodeRefPtr + { + //template friend class BVHN; + + /*! Number of bytes the nodes and primitives are minimally aligned to.*/ + static const size_t byteAlignment = 16; + static const size_t byteNodeAlignment = 4*N; + + /*! highest address bit is used as barrier for some algorithms */ + static const size_t barrier_mask = (1LL << (8*sizeof(size_t)-1)); + + /*! Masks the bits that store the number of items per leaf. */ + static const size_t align_mask = byteAlignment-1; + static const size_t items_mask = byteAlignment-1; + + /*! different supported node types */ + static const size_t tyAABBNode = 0; + static const size_t tyAABBNodeMB = 1; + static const size_t tyAABBNodeMB4D = 6; + static const size_t tyOBBNode = 2; + static const size_t tyOBBNodeMB = 3; + static const size_t tyQuantizedNode = 5; + static const size_t tyLeaf = 8; + + /*! Empty node */ + static const size_t emptyNode = tyLeaf; + + /*! Invalid node, used as marker in traversal */ + static const size_t invalidNode = (((size_t)-1) & (~items_mask)) | (tyLeaf+0); + static const size_t popRay = (((size_t)-1) & (~items_mask)) | (tyLeaf+1); + + /*! Maximum number of primitive blocks in a leaf. */ + static const size_t maxLeafBlocks = items_mask-tyLeaf; + + /*! Default constructor */ + __forceinline NodeRefPtr () {} + + /*! Construction from integer */ + __forceinline NodeRefPtr (size_t ptr) : ptr(ptr) {} + + /*! Cast to size_t */ + __forceinline operator size_t() const { return ptr; } + + /*! Sets the barrier bit. */ + __forceinline void setBarrier() { +#if defined(__X86_64__) + assert(!isBarrier()); + ptr |= barrier_mask; +#else + assert(false); +#endif + } + + /*! Clears the barrier bit. */ + __forceinline void clearBarrier() { +#if defined(__X86_64__) + ptr &= ~barrier_mask; +#else + assert(false); +#endif + } + + /*! Checks if this is an barrier. A barrier tells the top level tree rotations how deep to enter the tree. */ + __forceinline bool isBarrier() const { return (ptr & barrier_mask) != 0; } + + /*! checks if this is a leaf */ + __forceinline size_t isLeaf() const { return ptr & tyLeaf; } + + /*! returns node type */ + __forceinline int type() const { return ptr & (size_t)align_mask; } + + /*! checks if this is a node */ + __forceinline int isAABBNode() const { return (ptr & (size_t)align_mask) == tyAABBNode; } + + /*! checks if this is a motion blur node */ + __forceinline int isAABBNodeMB() const { return (ptr & (size_t)align_mask) == tyAABBNodeMB; } + + /*! checks if this is a 4D motion blur node */ + __forceinline int isAABBNodeMB4D() const { return (ptr & (size_t)align_mask) == tyAABBNodeMB4D; } + + /*! checks if this is a node with unaligned bounding boxes */ + __forceinline int isOBBNode() const { return (ptr & (size_t)align_mask) == tyOBBNode; } + + /*! checks if this is a motion blur node with unaligned bounding boxes */ + __forceinline int isOBBNodeMB() const { return (ptr & (size_t)align_mask) == tyOBBNodeMB; } + + /*! checks if this is a quantized node */ + __forceinline int isQuantizedNode() const { return (ptr & (size_t)align_mask) == tyQuantizedNode; } + + /*! 
Encodes a node */ + static __forceinline NodeRefPtr encodeNode(AABBNode_t* node) { + assert(!((size_t)node & align_mask)); + return NodeRefPtr((size_t) node); + } + + static __forceinline NodeRefPtr encodeNode(AABBNodeMB_t* node) { + assert(!((size_t)node & align_mask)); + return NodeRefPtr((size_t) node | tyAABBNodeMB); + } + + static __forceinline NodeRefPtr encodeNode(AABBNodeMB4D_t* node) { + assert(!((size_t)node & align_mask)); + return NodeRefPtr((size_t) node | tyAABBNodeMB4D); + } + + /*! Encodes an unaligned node */ + static __forceinline NodeRefPtr encodeNode(OBBNode_t* node) { + return NodeRefPtr((size_t) node | tyOBBNode); + } + + /*! Encodes an unaligned motion blur node */ + static __forceinline NodeRefPtr encodeNode(OBBNodeMB_t* node) { + return NodeRefPtr((size_t) node | tyOBBNodeMB); + } + + /*! Encodes a leaf */ + static __forceinline NodeRefPtr encodeLeaf(void* tri, size_t num) { + assert(!((size_t)tri & align_mask)); + assert(num <= maxLeafBlocks); + return NodeRefPtr((size_t)tri | (tyLeaf+min(num,(size_t)maxLeafBlocks))); + } + + /*! Encodes a leaf */ + static __forceinline NodeRefPtr encodeTypedLeaf(void* ptr, size_t ty) { + assert(!((size_t)ptr & align_mask)); + return NodeRefPtr((size_t)ptr | (tyLeaf+ty)); + } + + /*! returns base node pointer */ + __forceinline BaseNode_t* baseNode() + { + assert(!isLeaf()); + return (BaseNode_t*)(ptr & ~(size_t)align_mask); + } + __forceinline const BaseNode_t* baseNode() const + { + assert(!isLeaf()); + return (const BaseNode_t*)(ptr & ~(size_t)align_mask); + } + + /*! returns node pointer */ + __forceinline AABBNode_t* getAABBNode() { assert(isAABBNode()); return ( AABBNode_t*)ptr; } + __forceinline const AABBNode_t* getAABBNode() const { assert(isAABBNode()); return (const AABBNode_t*)ptr; } + + /*! returns motion blur node pointer */ + __forceinline AABBNodeMB_t* getAABBNodeMB() { assert(isAABBNodeMB() || isAABBNodeMB4D()); return ( AABBNodeMB_t*)(ptr & ~(size_t)align_mask); } + __forceinline const AABBNodeMB_t* getAABBNodeMB() const { assert(isAABBNodeMB() || isAABBNodeMB4D()); return (const AABBNodeMB_t*)(ptr & ~(size_t)align_mask); } + + /*! returns 4D motion blur node pointer */ + __forceinline AABBNodeMB4D_t* getAABBNodeMB4D() { assert(isAABBNodeMB4D()); return ( AABBNodeMB4D_t*)(ptr & ~(size_t)align_mask); } + __forceinline const AABBNodeMB4D_t* getAABBNodeMB4D() const { assert(isAABBNodeMB4D()); return (const AABBNodeMB4D_t*)(ptr & ~(size_t)align_mask); } + + /*! returns unaligned node pointer */ + __forceinline OBBNode_t* ungetAABBNode() { assert(isOBBNode()); return ( OBBNode_t*)(ptr & ~(size_t)align_mask); } + __forceinline const OBBNode_t* ungetAABBNode() const { assert(isOBBNode()); return (const OBBNode_t*)(ptr & ~(size_t)align_mask); } + + /*! returns unaligned motion blur node pointer */ + __forceinline OBBNodeMB_t* ungetAABBNodeMB() { assert(isOBBNodeMB()); return ( OBBNodeMB_t*)(ptr & ~(size_t)align_mask); } + __forceinline const OBBNodeMB_t* ungetAABBNodeMB() const { assert(isOBBNodeMB()); return (const OBBNodeMB_t*)(ptr & ~(size_t)align_mask); } + + /*! returns quantized node pointer */ + __forceinline QuantizedNode_t* quantizedNode() { assert(isQuantizedNode()); return ( QuantizedNode_t*)(ptr & ~(size_t)align_mask ); } + __forceinline const QuantizedNode_t* quantizedNode() const { assert(isQuantizedNode()); return (const QuantizedNode_t*)(ptr & ~(size_t)align_mask ); } + + /*! 
returns leaf pointer */ + __forceinline char* leaf(size_t& num) const { + assert(isLeaf()); + num = (ptr & (size_t)items_mask)-tyLeaf; + return (char*)(ptr & ~(size_t)align_mask); + } + + /*! clear all bit flags */ + __forceinline void clearFlags() { + ptr &= ~(size_t)align_mask; + } + + /*! returns the wideness */ + __forceinline size_t getN() const { return N; } + + public: + size_t ptr; + }; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_refit.cpp b/thirdparty/embree/kernels/bvh/bvh_refit.cpp new file mode 100644 index 000000000000..a273c21e8ba2 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_refit.cpp @@ -0,0 +1,247 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_refit.h" +#include "bvh_statistics.h" + +#include "../geometry/linei.h" +#include "../geometry/triangle.h" +#include "../geometry/trianglev.h" +#include "../geometry/trianglei.h" +#include "../geometry/quadv.h" +#include "../geometry/object.h" +#include "../geometry/instance.h" + +namespace embree +{ + namespace isa + { + static const size_t SINGLE_THREAD_THRESHOLD = 4*1024; + + template + __forceinline bool compare(const typename BVHN::NodeRef* a, const typename BVHN::NodeRef* b) + { + size_t sa = *(size_t*)&a->node()->lower_x; + size_t sb = *(size_t*)&b->node()->lower_x; + return sa < sb; + } + + template + BVHNRefitter::BVHNRefitter (BVH* bvh, const LeafBoundsInterface& leafBounds) + : bvh(bvh), leafBounds(leafBounds), numSubTrees(0) + { + } + + template + void BVHNRefitter::refit() + { + if (bvh->numPrimitives <= SINGLE_THREAD_THRESHOLD) { + bvh->bounds = LBBox3fa(recurse_bottom(bvh->root)); + } + else + { + BBox3fa subTreeBounds[MAX_NUM_SUB_TREES]; + numSubTrees = 0; + gather_subtree_refs(bvh->root,numSubTrees,0); + if (numSubTrees) + parallel_for(size_t(0), numSubTrees, size_t(1), [&](const range& r) { + for (size_t i=r.begin(); ibounds = LBBox3fa(refit_toplevel(bvh->root,numSubTrees,subTreeBounds,0)); + } + } + + template + void BVHNRefitter::gather_subtree_refs(NodeRef& ref, + size_t &subtrees, + const size_t depth) + { + if (depth >= MAX_SUB_TREE_EXTRACTION_DEPTH) + { + assert(subtrees < MAX_NUM_SUB_TREES); + subTrees[subtrees++] = ref; + return; + } + + if (ref.isAABBNode()) + { + AABBNode* node = ref.getAABBNode(); + for (size_t i=0; ichild(i); + if (unlikely(child == BVH::emptyNode)) continue; + gather_subtree_refs(child,subtrees,depth+1); + } + } + } + + template + BBox3fa BVHNRefitter::refit_toplevel(NodeRef& ref, + size_t &subtrees, + const BBox3fa *const subTreeBounds, + const size_t depth) + { + if (depth >= MAX_SUB_TREE_EXTRACTION_DEPTH) + { + assert(subtrees < MAX_NUM_SUB_TREES); + assert(subTrees[subtrees] == ref); + return subTreeBounds[subtrees++]; + } + + if (ref.isAABBNode()) + { + AABBNode* node = ref.getAABBNode(); + BBox3fa bounds[N]; + + for (size_t i=0; ichild(i); + + if (unlikely(child == BVH::emptyNode)) + bounds[i] = BBox3fa(empty); + else + bounds[i] = refit_toplevel(child,subtrees,subTreeBounds,depth+1); + } + + BBox3vf boundsT = transpose(bounds); + + /* set new bounds */ + node->lower_x = boundsT.lower.x; + node->lower_y = boundsT.lower.y; + node->lower_z = boundsT.lower.z; + node->upper_x = boundsT.upper.x; + node->upper_y = boundsT.upper.y; + node->upper_z = boundsT.upper.z; + + return merge(bounds); + } + else + return leafBounds.leafBounds(ref); + } + + // ========================================================= + // ========================================================= + // 
========================================================= + + + template + BBox3fa BVHNRefitter::recurse_bottom(NodeRef& ref) + { + /* this is a leaf node */ + if (unlikely(ref.isLeaf())) + return leafBounds.leafBounds(ref); + + /* recurse if this is an internal node */ + AABBNode* node = ref.getAABBNode(); + + /* enable exclusive prefetch for >= AVX platforms */ +#if defined(__AVX__) + BVH::prefetchW(ref); +#endif + BBox3fa bounds[N]; + + for (size_t i=0; ichild(i) == BVH::emptyNode)) + { + bounds[i] = BBox3fa(empty); + } + else + bounds[i] = recurse_bottom(node->child(i)); + + /* AOS to SOA transform */ + BBox3vf boundsT = transpose(bounds); + + /* set new bounds */ + node->lower_x = boundsT.lower.x; + node->lower_y = boundsT.lower.y; + node->lower_z = boundsT.lower.z; + node->upper_x = boundsT.upper.x; + node->upper_y = boundsT.upper.y; + node->upper_z = boundsT.upper.z; + + return merge(bounds); + } + + template + BVHNRefitT::BVHNRefitT (BVH* bvh, Builder* builder, Mesh* mesh, size_t mode) + : bvh(bvh), builder(builder), refitter(new BVHNRefitter(bvh,*(typename BVHNRefitter::LeafBoundsInterface*)this)), mesh(mesh), topologyVersion(0) {} + + template + void BVHNRefitT::clear() + { + if (builder) + builder->clear(); + } + + template + void BVHNRefitT::build() + { + if (mesh->topologyChanged(topologyVersion)) { + topologyVersion = mesh->getTopologyVersion(); + builder->build(); + } + else + refitter->refit(); + } + + template class BVHNRefitter<4>; +#if defined(__AVX__) + template class BVHNRefitter<8>; +#endif + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + Builder* BVH4Triangle4MeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode); + Builder* BVH4Triangle4vMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode); + Builder* BVH4Triangle4iMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode); + + Builder* BVH4Triangle4MeshRefitSAH (void* accel, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<4,TriangleMesh,Triangle4> ((BVH4*)accel,BVH4Triangle4MeshBuilderSAH (accel,mesh,geomID,mode),mesh,mode); } + Builder* BVH4Triangle4vMeshRefitSAH (void* accel, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<4,TriangleMesh,Triangle4v>((BVH4*)accel,BVH4Triangle4vMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } + Builder* BVH4Triangle4iMeshRefitSAH (void* accel, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<4,TriangleMesh,Triangle4i>((BVH4*)accel,BVH4Triangle4iMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } +#if defined(__AVX__) + Builder* BVH8Triangle4MeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode); + Builder* BVH8Triangle4vMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode); + Builder* BVH8Triangle4iMeshBuilderSAH (void* bvh, TriangleMesh* mesh, unsigned int geomID, size_t mode); + + Builder* BVH8Triangle4MeshRefitSAH (void* accel, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<8,TriangleMesh,Triangle4> ((BVH8*)accel,BVH8Triangle4MeshBuilderSAH (accel,mesh,geomID,mode),mesh,mode); } + Builder* BVH8Triangle4vMeshRefitSAH (void* accel, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<8,TriangleMesh,Triangle4v>((BVH8*)accel,BVH8Triangle4vMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } + Builder* BVH8Triangle4iMeshRefitSAH (void* accel, TriangleMesh* mesh, unsigned int geomID, size_t mode) { return new 
BVHNRefitT<8,TriangleMesh,Triangle4i>((BVH8*)accel,BVH8Triangle4iMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + Builder* BVH4Quad4vMeshBuilderSAH (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode); + Builder* BVH4Quad4vMeshRefitSAH (void* accel, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<4,QuadMesh,Quad4v>((BVH4*)accel,BVH4Quad4vMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } + +#if defined(__AVX__) + Builder* BVH8Quad4vMeshBuilderSAH (void* bvh, QuadMesh* mesh, unsigned int geomID, size_t mode); + Builder* BVH8Quad4vMeshRefitSAH (void* accel, QuadMesh* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<8,QuadMesh,Quad4v>((BVH8*)accel,BVH8Quad4vMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } +#endif + +#endif + +#if defined(EMBREE_GEOMETRY_USER) + Builder* BVH4VirtualMeshBuilderSAH (void* bvh, UserGeometry* mesh, unsigned int geomID, size_t mode); + Builder* BVH4VirtualMeshRefitSAH (void* accel, UserGeometry* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<4,UserGeometry,Object>((BVH4*)accel,BVH4VirtualMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } + +#if defined(__AVX__) + Builder* BVH8VirtualMeshBuilderSAH (void* bvh, UserGeometry* mesh, unsigned int geomID, size_t mode); + Builder* BVH8VirtualMeshRefitSAH (void* accel, UserGeometry* mesh, unsigned int geomID, size_t mode) { return new BVHNRefitT<8,UserGeometry,Object>((BVH8*)accel,BVH8VirtualMeshBuilderSAH(accel,mesh,geomID,mode),mesh,mode); } +#endif +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + Builder* BVH4InstanceMeshBuilderSAH (void* bvh, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode); + Builder* BVH4InstanceMeshRefitSAH (void* accel, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode) { return new BVHNRefitT<4,Instance,InstancePrimitive>((BVH4*)accel,BVH4InstanceMeshBuilderSAH(accel,mesh,gtype,geomID,mode),mesh,mode); } + +#if defined(__AVX__) + Builder* BVH8InstanceMeshBuilderSAH (void* bvh, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode); + Builder* BVH8InstanceMeshRefitSAH (void* accel, Instance* mesh, Geometry::GTypeMask gtype, unsigned int geomID, size_t mode) { return new BVHNRefitT<8,Instance,InstancePrimitive>((BVH8*)accel,BVH8InstanceMeshBuilderSAH(accel,mesh,gtype,geomID,mode),mesh,mode); } +#endif +#endif + + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_refit.h b/thirdparty/embree/kernels/bvh/bvh_refit.h new file mode 100644 index 000000000000..4aa9bdd7cc89 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_refit.h @@ -0,0 +1,95 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../bvh/bvh.h" + +namespace embree +{ + namespace isa + { + template + class BVHNRefitter + { + public: + + /*! Type shortcuts */ + typedef BVHN BVH; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::NodeRef NodeRef; + + struct LeafBoundsInterface { + virtual const BBox3fa leafBounds(NodeRef& ref) const = 0; + }; + + public: + + /*! Constructor. */ + BVHNRefitter (BVH* bvh, const LeafBoundsInterface& leafBounds); + + /*! 
refits the BVH */ + void refit(); + + private: + /* single-threaded subtree extraction based on BVH depth */ + void gather_subtree_refs(NodeRef& ref, + size_t &subtrees, + const size_t depth = 0); + + /* single-threaded top-level refit */ + BBox3fa refit_toplevel(NodeRef& ref, + size_t &subtrees, + const BBox3fa *const subTreeBounds, + const size_t depth = 0); + + /* single-threaded subtree refit */ + BBox3fa recurse_bottom(NodeRef& ref); + + public: + BVH* bvh; //!< BVH to refit + const LeafBoundsInterface& leafBounds; //!< calculates bounds of leaves + + static const size_t MAX_SUB_TREE_EXTRACTION_DEPTH = (N==4) ? 4 : (N==8) ? 3 : 3; + static const size_t MAX_NUM_SUB_TREES = (N==4) ? 256 : (N==8) ? 512 : N*N*N; // N ^ MAX_SUB_TREE_EXTRACTION_DEPTH + size_t numSubTrees; + NodeRef subTrees[MAX_NUM_SUB_TREES]; + }; + + template + class BVHNRefitT : public Builder, public BVHNRefitter::LeafBoundsInterface + { + public: + + /*! Type shortcuts */ + typedef BVHN BVH; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::NodeRef NodeRef; + + public: + BVHNRefitT (BVH* bvh, Builder* builder, Mesh* mesh, size_t mode); + + virtual void build(); + + virtual void clear(); + + virtual const BBox3fa leafBounds (NodeRef& ref) const + { + size_t num; char* prim = ref.leaf(num); + if (unlikely(ref == BVH::emptyNode)) return empty; + + BBox3fa bounds = empty; + for (size_t i=0; i builder; + std::unique_ptr> refitter; + Mesh* mesh; + unsigned int topologyVersion; + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_rotate.cpp b/thirdparty/embree/kernels/bvh/bvh_rotate.cpp new file mode 100644 index 000000000000..2bb431bf0e4f --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_rotate.cpp @@ -0,0 +1,127 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_rotate.h" + +namespace embree +{ + namespace isa + { + /*! Computes half surface area of box. */ + __forceinline float halfArea3f(const BBox& box) { + const vfloat4 d = box.size(); + const vfloat4 a = d*shuffle<1,2,0,3>(d); + return a[0]+a[1]+a[2]; + } + + size_t BVHNRotate<4>::rotate(NodeRef parentRef, size_t depth) + { + /*! nothing to rotate if we reached a leaf node. */ + if (parentRef.isBarrier()) return 0; + if (parentRef.isLeaf()) return 0; + AABBNode* parent = parentRef.getAABBNode(); + + /*! rotate all children first */ + vint4 cdepth; + for (size_t c=0; c<4; c++) + cdepth[c] = (int)rotate(parent->child(c),depth+1); + + /* compute current areas of all children */ + vfloat4 sizeX = parent->upper_x-parent->lower_x; + vfloat4 sizeY = parent->upper_y-parent->lower_y; + vfloat4 sizeZ = parent->upper_z-parent->lower_z; + vfloat4 childArea = madd(sizeX,(sizeY + sizeZ),sizeY*sizeZ); + + /*! get node bounds */ + BBox child1_0,child1_1,child1_2,child1_3; + parent->bounds(child1_0,child1_1,child1_2,child1_3); + + /*! Find best rotation. We pick a first child (child1) and a sub-child + (child2child) of a different second child (child2), and swap child1 + and child2child. We perform the best such swap. */ + float bestArea = 0; + size_t bestChild1 = -1, bestChild2 = -1, bestChild2Child = -1; + for (size_t c2=0; c2<4; c2++) + { + /*! ignore leaf nodes as we cannot descent into them */ + if (parent->child(c2).isBarrier()) continue; + if (parent->child(c2).isLeaf()) continue; + AABBNode* child2 = parent->child(c2).getAABBNode(); + + /*! transpose child bounds */ + BBox child2c0,child2c1,child2c2,child2c3; + child2->bounds(child2c0,child2c1,child2c2,child2c3); + + /*! 
put child1_0 at each child2 position */ + float cost00 = halfArea3f(merge(child1_0,child2c1,child2c2,child2c3)); + float cost01 = halfArea3f(merge(child2c0,child1_0,child2c2,child2c3)); + float cost02 = halfArea3f(merge(child2c0,child2c1,child1_0,child2c3)); + float cost03 = halfArea3f(merge(child2c0,child2c1,child2c2,child1_0)); + vfloat4 cost0 = vfloat4(cost00,cost01,cost02,cost03); + vfloat4 min0 = vreduce_min(cost0); + int pos0 = (int)bsf(movemask(min0 == cost0)); + + /*! put child1_1 at each child2 position */ + float cost10 = halfArea3f(merge(child1_1,child2c1,child2c2,child2c3)); + float cost11 = halfArea3f(merge(child2c0,child1_1,child2c2,child2c3)); + float cost12 = halfArea3f(merge(child2c0,child2c1,child1_1,child2c3)); + float cost13 = halfArea3f(merge(child2c0,child2c1,child2c2,child1_1)); + vfloat4 cost1 = vfloat4(cost10,cost11,cost12,cost13); + vfloat4 min1 = vreduce_min(cost1); + int pos1 = (int)bsf(movemask(min1 == cost1)); + + /*! put child1_2 at each child2 position */ + float cost20 = halfArea3f(merge(child1_2,child2c1,child2c2,child2c3)); + float cost21 = halfArea3f(merge(child2c0,child1_2,child2c2,child2c3)); + float cost22 = halfArea3f(merge(child2c0,child2c1,child1_2,child2c3)); + float cost23 = halfArea3f(merge(child2c0,child2c1,child2c2,child1_2)); + vfloat4 cost2 = vfloat4(cost20,cost21,cost22,cost23); + vfloat4 min2 = vreduce_min(cost2); + int pos2 = (int)bsf(movemask(min2 == cost2)); + + /*! put child1_3 at each child2 position */ + float cost30 = halfArea3f(merge(child1_3,child2c1,child2c2,child2c3)); + float cost31 = halfArea3f(merge(child2c0,child1_3,child2c2,child2c3)); + float cost32 = halfArea3f(merge(child2c0,child2c1,child1_3,child2c3)); + float cost33 = halfArea3f(merge(child2c0,child2c1,child2c2,child1_3)); + vfloat4 cost3 = vfloat4(cost30,cost31,cost32,cost33); + vfloat4 min3 = vreduce_min(cost3); + int pos3 = (int)bsf(movemask(min3 == cost3)); + + /*! find best other child */ + vfloat4 area0123 = vfloat4(extract<0>(min0),extract<0>(min1),extract<0>(min2),extract<0>(min3)) - vfloat4(childArea[c2]); + int pos[4] = { pos0,pos1,pos2,pos3 }; + const size_t mbd = BVH4::maxBuildDepth; + vbool4 valid = vint4(int(depth+1))+cdepth <= vint4(mbd); // only select swaps that fulfill depth constraints + valid &= vint4(int(c2)) != vint4(step); + if (none(valid)) continue; + size_t c1 = select_min(valid,area0123); + float area = area0123[c1]; + if (c1 == c2) continue; // can happen if bounds are NANs + + /*! accept a swap when it reduces cost and is not swapping a node with itself */ + if (area < bestArea) { + bestArea = area; + bestChild1 = c1; + bestChild2 = c2; + bestChild2Child = pos[c1]; + } + } + + /*! if we did not find a swap that improves the SAH then do nothing */ + if (bestChild1 == size_t(-1)) return 1+reduce_max(cdepth); + + /*! perform the best found tree rotation */ + AABBNode* child2 = parent->child(bestChild2).getAABBNode(); + AABBNode::swap(parent,bestChild1,child2,bestChild2Child); + parent->setBounds(bestChild2,child2->bounds()); + AABBNode::compact(parent); + AABBNode::compact(child2); + + /*! This returned depth is conservative as the child that was + * pulled up in the tree could have been on the critical path. 
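+       * Only the depth counter of bestChild1 is incremented below (its subtree moved one level down); the depth recorded for the subtree that moved up is not decreased, which keeps the estimate conservative.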
*/ + cdepth[bestChild1]++; // bestChild1 was pushed down one level + return 1+reduce_max(cdepth); + } + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_rotate.h b/thirdparty/embree/kernels/bvh/bvh_rotate.h new file mode 100644 index 000000000000..009bef339ee1 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_rotate.h @@ -0,0 +1,37 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" + +namespace embree +{ + namespace isa + { + template + class BVHNRotate + { + typedef typename BVHN::NodeRef NodeRef; + + public: + static const bool enabled = false; + + static __forceinline size_t rotate(NodeRef parentRef, size_t depth = 1) { return 0; } + static __forceinline void restructure(NodeRef ref, size_t depth = 1) {} + }; + + /* BVH4 tree rotations */ + template<> + class BVHNRotate<4> + { + typedef BVH4::AABBNode AABBNode; + typedef BVH4::NodeRef NodeRef; + + public: + static const bool enabled = true; + + static size_t rotate(NodeRef parentRef, size_t depth = 1); + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_statistics.cpp b/thirdparty/embree/kernels/bvh/bvh_statistics.cpp new file mode 100644 index 000000000000..05460843af22 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_statistics.cpp @@ -0,0 +1,165 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "bvh_statistics.h" +#include "../../common/algorithms/parallel_reduce.h" + +namespace embree +{ + template + BVHNStatistics::BVHNStatistics (BVH* bvh) : bvh(bvh) + { + double A = max(0.0f,bvh->getLinearBounds().expectedHalfArea()); + stat = statistics(bvh->root,A,BBox1f(0.0f,1.0f)); + } + + template + std::string BVHNStatistics::str() + { + std::ostringstream stream; + stream.setf(std::ios::fixed, std::ios::floatfield); + stream << " primitives = " << bvh->numPrimitives << ", vertices = " << bvh->numVertices << ", depth = " << stat.depth << std::endl; + size_t totalBytes = stat.bytes(bvh); + double totalSAH = stat.sah(bvh); + stream << " total : sah = " << std::setw(7) << std::setprecision(3) << totalSAH << " (100.00%), "; + stream << "#bytes = " << std::setw(7) << std::setprecision(2) << totalBytes/1E6 << " MB (100.00%), "; + stream << "#nodes = " << std::setw(7) << stat.size() << " (" << std::setw(6) << std::setprecision(2) << 100.0*stat.fillRate(bvh) << "% filled), "; + stream << "#bytes/prim = " << std::setw(6) << std::setprecision(2) << double(totalBytes)/double(bvh->numPrimitives) << std::endl; + if (stat.statAABBNodes.numNodes ) stream << " getAABBNodes : " << stat.statAABBNodes.toString(bvh,totalSAH,totalBytes) << std::endl; + if (stat.statOBBNodes.numNodes ) stream << " ungetAABBNodes : " << stat.statOBBNodes.toString(bvh,totalSAH,totalBytes) << std::endl; + if (stat.statAABBNodesMB.numNodes ) stream << " getAABBNodesMB : " << stat.statAABBNodesMB.toString(bvh,totalSAH,totalBytes) << std::endl; + if (stat.statAABBNodesMB4D.numNodes) stream << " getAABBNodesMB4D : " << stat.statAABBNodesMB4D.toString(bvh,totalSAH,totalBytes) << std::endl; + if (stat.statOBBNodesMB.numNodes) stream << " ungetAABBNodesMB : " << stat.statOBBNodesMB.toString(bvh,totalSAH,totalBytes) << std::endl; + if (stat.statQuantizedNodes.numNodes ) stream << " quantizedNodes : " << stat.statQuantizedNodes.toString(bvh,totalSAH,totalBytes) << std::endl; + if (true) stream << " leaves : " << stat.statLeaf.toString(bvh,totalSAH,totalBytes) << std::endl; + if (true) stream << " histogram : " << stat.statLeaf.histToString() << std::endl; + 
return stream.str(); + } + + template + typename BVHNStatistics::Statistics BVHNStatistics::statistics(NodeRef node, const double A, const BBox1f t0t1) + { + Statistics s; + assert(t0t1.size() > 0.0f); + double dt = max(0.0f,t0t1.size()); + if (node.isAABBNode()) + { + AABBNode* n = node.getAABBNode(); + s = s + parallel_reduce(0,N,Statistics(),[&] ( const int i ) { + if (n->child(i) == BVH::emptyNode) return Statistics(); + const double Ai = max(0.0f,halfArea(n->extend(i))); + Statistics s = statistics(n->child(i),Ai,t0t1); + s.statAABBNodes.numChildren++; + return s; + }, Statistics::add); + s.statAABBNodes.numNodes++; + s.statAABBNodes.nodeSAH += dt*A; + s.depth++; + } + else if (node.isOBBNode()) + { + OBBNode* n = node.ungetAABBNode(); + s = s + parallel_reduce(0,N,Statistics(),[&] ( const int i ) { + if (n->child(i) == BVH::emptyNode) return Statistics(); + const double Ai = max(0.0f,halfArea(n->extent(i))); + Statistics s = statistics(n->child(i),Ai,t0t1); + s.statOBBNodes.numChildren++; + return s; + }, Statistics::add); + s.statOBBNodes.numNodes++; + s.statOBBNodes.nodeSAH += dt*A; + s.depth++; + } + else if (node.isAABBNodeMB()) + { + AABBNodeMB* n = node.getAABBNodeMB(); + s = s + parallel_reduce(0,N,Statistics(),[&] ( const int i ) { + if (n->child(i) == BVH::emptyNode) return Statistics(); + const double Ai = max(0.0f,n->expectedHalfArea(i,t0t1)); + Statistics s = statistics(n->child(i),Ai,t0t1); + s.statAABBNodesMB.numChildren++; + return s; + }, Statistics::add); + s.statAABBNodesMB.numNodes++; + s.statAABBNodesMB.nodeSAH += dt*A; + s.depth++; + } + else if (node.isAABBNodeMB4D()) + { + AABBNodeMB4D* n = node.getAABBNodeMB4D(); + s = s + parallel_reduce(0,N,Statistics(),[&] ( const int i ) { + if (n->child(i) == BVH::emptyNode) return Statistics(); + const BBox1f t0t1i = intersect(t0t1,n->timeRange(i)); + assert(!t0t1i.empty()); + const double Ai = n->AABBNodeMB::expectedHalfArea(i,t0t1i); + Statistics s = statistics(n->child(i),Ai,t0t1i); + s.statAABBNodesMB4D.numChildren++; + return s; + }, Statistics::add); + s.statAABBNodesMB4D.numNodes++; + s.statAABBNodesMB4D.nodeSAH += dt*A; + s.depth++; + } + else if (node.isOBBNodeMB()) + { + OBBNodeMB* n = node.ungetAABBNodeMB(); + s = s + parallel_reduce(0,N,Statistics(),[&] ( const int i ) { + if (n->child(i) == BVH::emptyNode) return Statistics(); + const double Ai = max(0.0f,halfArea(n->extent0(i))); + Statistics s = statistics(n->child(i),Ai,t0t1); + s.statOBBNodesMB.numChildren++; + return s; + }, Statistics::add); + s.statOBBNodesMB.numNodes++; + s.statOBBNodesMB.nodeSAH += dt*A; + s.depth++; + } + else if (node.isQuantizedNode()) + { + QuantizedNode* n = node.quantizedNode(); + s = s + parallel_reduce(0,N,Statistics(),[&] ( const int i ) { + if (n->child(i) == BVH::emptyNode) return Statistics(); + const double Ai = max(0.0f,halfArea(n->extent(i))); + Statistics s = statistics(n->child(i),Ai,t0t1); + s.statQuantizedNodes.numChildren++; + return s; + }, Statistics::add); + s.statQuantizedNodes.numNodes++; + s.statQuantizedNodes.nodeSAH += dt*A; + s.depth++; + } + else if (node.isLeaf()) + { + size_t num; const char* tri = node.leaf(num); + if (num) + { + for (size_t i=0; iprimTy->getBytes(tri); + s.statLeaf.numPrimsActive += bvh->primTy->sizeActive(tri); + s.statLeaf.numPrimsTotal += bvh->primTy->sizeTotal(tri); + s.statLeaf.numBytes += bytes; + tri+=bytes; + } + s.statLeaf.numLeaves++; + s.statLeaf.numPrimBlocks += num; + s.statLeaf.leafSAH += dt*A*num; + if (num-1 < Statistics::LeafStat::NHIST) { + 
s.statLeaf.numPrimBlocksHistogram[num-1]++; + } + } + } + else { + throw std::runtime_error("not supported node type in bvh_statistics"); + } + return s; + } + +#if defined(__AVX__) + template class BVHNStatistics<8>; +#endif + +#if !defined(__AVX__) || !defined(EMBREE_TARGET_SSE2) && !defined(EMBREE_TARGET_SSE42) + template class BVHNStatistics<4>; +#endif +} diff --git a/thirdparty/embree/kernels/bvh/bvh_statistics.h b/thirdparty/embree/kernels/bvh/bvh_statistics.h new file mode 100644 index 000000000000..73dfc6fbcc0d --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_statistics.h @@ -0,0 +1,285 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include + +namespace embree +{ + template + class BVHNStatistics + { + typedef BVHN BVH; + typedef typename BVH::AABBNode AABBNode; + typedef typename BVH::OBBNode OBBNode; + typedef typename BVH::AABBNodeMB AABBNodeMB; + typedef typename BVH::AABBNodeMB4D AABBNodeMB4D; + typedef typename BVH::OBBNodeMB OBBNodeMB; + typedef typename BVH::QuantizedNode QuantizedNode; + + typedef typename BVH::NodeRef NodeRef; + + struct Statistics + { + template + struct NodeStat + { + NodeStat ( double nodeSAH = 0, + size_t numNodes = 0, + size_t numChildren = 0) + : nodeSAH(nodeSAH), + numNodes(numNodes), + numChildren(numChildren) {} + + double sah(BVH* bvh) const { + return nodeSAH/bvh->getLinearBounds().expectedHalfArea(); + } + + size_t bytes() const { + return numNodes*sizeof(Node); + } + + size_t size() const { + return numNodes; + } + + double fillRateNom () const { return double(numChildren); } + double fillRateDen () const { return double(numNodes*N); } + double fillRate () const { return fillRateNom()/fillRateDen(); } + + __forceinline friend NodeStat operator+ ( const NodeStat& a, const NodeStat& b) + { + return NodeStat(a.nodeSAH + b.nodeSAH, + a.numNodes+b.numNodes, + a.numChildren+b.numChildren); + } + + std::string toString(BVH* bvh, double sahTotal, size_t bytesTotal) const + { + std::ostringstream stream; + stream.setf(std::ios::fixed, std::ios::floatfield); + stream << "sah = " << std::setw(7) << std::setprecision(3) << sah(bvh); + stream << " (" << std::setw(6) << std::setprecision(2) << 100.0*sah(bvh)/sahTotal << "%), "; + stream << "#bytes = " << std::setw(7) << std::setprecision(2) << bytes()/1E6 << " MB "; + stream << "(" << std::setw(6) << std::setprecision(2) << 100.0*double(bytes())/double(bytesTotal) << "%), "; + stream << "#nodes = " << std::setw(7) << numNodes << " (" << std::setw(6) << std::setprecision(2) << 100.0*fillRate() << "% filled), "; + stream << "#bytes/prim = " << std::setw(6) << std::setprecision(2) << double(bytes())/double(bvh->numPrimitives); + return stream.str(); + } + + public: + double nodeSAH; + size_t numNodes; + size_t numChildren; + }; + + struct LeafStat + { + static const int NHIST = 8; + + LeafStat ( double leafSAH = 0.0f, + size_t numLeaves = 0, + size_t numPrimsActive = 0, + size_t numPrimsTotal = 0, + size_t numPrimBlocks = 0, + size_t numBytes = 0) + : leafSAH(leafSAH), + numLeaves(numLeaves), + numPrimsActive(numPrimsActive), + numPrimsTotal(numPrimsTotal), + numPrimBlocks(numPrimBlocks), + numBytes(numBytes) + { + for (size_t i=0; igetLinearBounds().expectedHalfArea(); + } + + size_t bytes(BVH* bvh) const { + return numBytes; + } + + size_t size() const { + return numLeaves; + } + + double fillRateNom (BVH* bvh) const { return double(numPrimsActive); } + double fillRateDen (BVH* bvh) const { return double(numPrimsTotal); } + 
double fillRate (BVH* bvh) const { return fillRateNom(bvh)/fillRateDen(bvh); } + + __forceinline friend LeafStat operator+ ( const LeafStat& a, const LeafStat& b) + { + LeafStat stat(a.leafSAH + b.leafSAH, + a.numLeaves+b.numLeaves, + a.numPrimsActive+b.numPrimsActive, + a.numPrimsTotal+b.numPrimsTotal, + a.numPrimBlocks+b.numPrimBlocks, + a.numBytes+b.numBytes); + for (size_t i=0; i statAABBNodes = NodeStat(), + NodeStat statOBBNodes = NodeStat(), + NodeStat statAABBNodesMB = NodeStat(), + NodeStat statAABBNodesMB4D = NodeStat(), + NodeStat statOBBNodesMB = NodeStat(), + NodeStat statQuantizedNodes = NodeStat()) + + : depth(depth), + statLeaf(statLeaf), + statAABBNodes(statAABBNodes), + statOBBNodes(statOBBNodes), + statAABBNodesMB(statAABBNodesMB), + statAABBNodesMB4D(statAABBNodesMB4D), + statOBBNodesMB(statOBBNodesMB), + statQuantizedNodes(statQuantizedNodes) {} + + double sah(BVH* bvh) const + { + return statLeaf.sah(bvh) + + statAABBNodes.sah(bvh) + + statOBBNodes.sah(bvh) + + statAABBNodesMB.sah(bvh) + + statAABBNodesMB4D.sah(bvh) + + statOBBNodesMB.sah(bvh) + + statQuantizedNodes.sah(bvh); + } + + size_t bytes(BVH* bvh) const { + return statLeaf.bytes(bvh) + + statAABBNodes.bytes() + + statOBBNodes.bytes() + + statAABBNodesMB.bytes() + + statAABBNodesMB4D.bytes() + + statOBBNodesMB.bytes() + + statQuantizedNodes.bytes(); + } + + size_t size() const + { + return statLeaf.size() + + statAABBNodes.size() + + statOBBNodes.size() + + statAABBNodesMB.size() + + statAABBNodesMB4D.size() + + statOBBNodesMB.size() + + statQuantizedNodes.size(); + } + + double fillRate (BVH* bvh) const + { + double nom = statLeaf.fillRateNom(bvh) + + statAABBNodes.fillRateNom() + + statOBBNodes.fillRateNom() + + statAABBNodesMB.fillRateNom() + + statAABBNodesMB4D.fillRateNom() + + statOBBNodesMB.fillRateNom() + + statQuantizedNodes.fillRateNom(); + double den = statLeaf.fillRateDen(bvh) + + statAABBNodes.fillRateDen() + + statOBBNodes.fillRateDen() + + statAABBNodesMB.fillRateDen() + + statAABBNodesMB4D.fillRateDen() + + statOBBNodesMB.fillRateDen() + + statQuantizedNodes.fillRateDen(); + return nom/den; + } + + friend Statistics operator+ ( const Statistics& a, const Statistics& b ) + { + return Statistics(max(a.depth,b.depth), + a.statLeaf + b.statLeaf, + a.statAABBNodes + b.statAABBNodes, + a.statOBBNodes + b.statOBBNodes, + a.statAABBNodesMB + b.statAABBNodesMB, + a.statAABBNodesMB4D + b.statAABBNodesMB4D, + a.statOBBNodesMB + b.statOBBNodesMB, + a.statQuantizedNodes + b.statQuantizedNodes); + } + + static Statistics add ( const Statistics& a, const Statistics& b ) { + return a+b; + } + + public: + size_t depth; + LeafStat statLeaf; + NodeStat statAABBNodes; + NodeStat statOBBNodes; + NodeStat statAABBNodesMB; + NodeStat statAABBNodesMB4D; + NodeStat statOBBNodesMB; + NodeStat statQuantizedNodes; + }; + + public: + + /* Constructor gathers statistics. */ + BVHNStatistics (BVH* bvh); + + /*! 
Convert statistics into a string */ + std::string str(); + + double sah() const { + return stat.sah(bvh); + } + + size_t bytesUsed() const { + return stat.bytes(bvh); + } + + private: + Statistics statistics(NodeRef node, const double A, const BBox1f dt); + + private: + BVH* bvh; + Statistics stat; + }; + + typedef BVHNStatistics<4> BVH4Statistics; + typedef BVHNStatistics<8> BVH8Statistics; +} diff --git a/thirdparty/embree/kernels/bvh/bvh_traverser1.h b/thirdparty/embree/kernels/bvh/bvh_traverser1.h new file mode 100644 index 000000000000..7f17084b81d6 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_traverser1.h @@ -0,0 +1,676 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include "node_intersector1.h" +#include "../common/stack_item.h" + +#define NEW_SORTING_CODE 1 + +namespace embree +{ + namespace isa + { + /*! BVH regular node traversal for single rays. */ + template + class BVHNNodeTraverser1Hit; + + /*! Helper functions for fast sorting using AVX512 instructions. */ +#if defined(__AVX512ER__) + + /* KNL code path */ + __forceinline void isort_update(vfloat16 &dist, vllong8 &ptr, const vfloat16 &d, const vllong8 &p) + { + const vfloat16 dist_shift = align_shift_right<15>(dist,dist); + const vllong8 ptr_shift = align_shift_right<7>(ptr,ptr); + const vbool16 m_geq = d >= dist; + const vbool16 m_geq_shift = m_geq << 1; + dist = select(m_geq,d,dist); + ptr = select(vboold8(m_geq),p,ptr); + dist = select(m_geq_shift,dist_shift,dist); + ptr = select(vboold8(m_geq_shift),ptr_shift,ptr); + } + + __forceinline void isort_quick_update(vfloat16 &dist, vllong8 &ptr, const vfloat16 &d, const vllong8 &p) + { + //dist = align_shift_right<15>(dist,d); + //ptr = align_shift_right<7>(ptr,p); + dist = align_shift_right<15>(dist,permute(d,vint16(zero))); + ptr = align_shift_right<7>(ptr,permute(p,vllong8(zero))); + } + + template + __forceinline void traverseClosestHitAVX512(NodeRef& cur, + size_t mask, + const vfloat& tNear, + StackItemT*& stackPtr, + StackItemT* stackEnd) + { + assert(mask != 0); + const BaseNode* node = cur.baseNode(); + + vllong8 children( vllong::loadu((void*)node->children) ); + children = vllong8::compact((int)mask,children); + vfloat16 distance = tNear; + distance = vfloat16::compact((int)mask,distance,tNear); + + cur = toScalar(children); + BVHN::prefetch(cur,types); + + mask &= mask-1; + if (likely(mask == 0)) return; + + /* 2 hits: order A0 B0 */ + const vllong8 c0(children); + const vfloat16 d0(distance); + children = align_shift_right<1>(children,children); + distance = align_shift_right<1>(distance,distance); + const vllong8 c1(children); + const vfloat16 d1(distance); + + cur = toScalar(children); + BVHN::prefetch(cur,types); + + /* a '<' keeps the order for equal distances, scenes like powerplant largely benefit from it */ + const vboolf16 m_dist = d0 < d1; + const vfloat16 dist_A0 = select(m_dist, d0, d1); + const vfloat16 dist_B0 = select(m_dist, d1, d0); + const vllong8 ptr_A0 = select(vboold8(m_dist), c0, c1); + const vllong8 ptr_B0 = select(vboold8(m_dist), c1, c0); + + mask &= mask-1; + if (likely(mask == 0)) { + cur = toScalar(ptr_A0); + stackPtr[0].ptr = toScalar(ptr_B0); + *(float*)&stackPtr[0].dist = toScalar(dist_B0); + stackPtr++; + return; + } + + /* 3 hits: order A1 B1 C1 */ + + children = align_shift_right<1>(children,children); + distance = align_shift_right<1>(distance,distance); + + const vllong8 c2(children); + const vfloat16 d2(distance); + + cur = toScalar(children); + 
BVHN::prefetch(cur,types); + + const vboolf16 m_dist1 = dist_A0 <= d2; + const vfloat16 dist_tmp_B1 = select(m_dist1, d2, dist_A0); + const vllong8 ptr_A1 = select(vboold8(m_dist1), ptr_A0, c2); + const vllong8 ptr_tmp_B1 = select(vboold8(m_dist1), c2, ptr_A0); + + const vboolf16 m_dist2 = dist_B0 <= dist_tmp_B1; + const vfloat16 dist_B1 = select(m_dist2, dist_B0 , dist_tmp_B1); + const vfloat16 dist_C1 = select(m_dist2, dist_tmp_B1, dist_B0); + const vllong8 ptr_B1 = select(vboold8(m_dist2), ptr_B0, ptr_tmp_B1); + const vllong8 ptr_C1 = select(vboold8(m_dist2), ptr_tmp_B1, ptr_B0); + + mask &= mask-1; + if (likely(mask == 0)) { + cur = toScalar(ptr_A1); + stackPtr[0].ptr = toScalar(ptr_C1); + *(float*)&stackPtr[0].dist = toScalar(dist_C1); + stackPtr[1].ptr = toScalar(ptr_B1); + *(float*)&stackPtr[1].dist = toScalar(dist_B1); + stackPtr+=2; + return; + } + + /* 4 hits: order A2 B2 C2 D2 */ + + const vfloat16 dist_A1 = select(m_dist1, dist_A0, d2); + + children = align_shift_right<1>(children,children); + distance = align_shift_right<1>(distance,distance); + + const vllong8 c3(children); + const vfloat16 d3(distance); + + cur = toScalar(children); + BVHN::prefetch(cur,types); + + const vboolf16 m_dist3 = dist_A1 <= d3; + const vfloat16 dist_tmp_B2 = select(m_dist3, d3, dist_A1); + const vllong8 ptr_A2 = select(vboold8(m_dist3), ptr_A1, c3); + const vllong8 ptr_tmp_B2 = select(vboold8(m_dist3), c3, ptr_A1); + + const vboolf16 m_dist4 = dist_B1 <= dist_tmp_B2; + const vfloat16 dist_B2 = select(m_dist4, dist_B1 , dist_tmp_B2); + const vfloat16 dist_tmp_C2 = select(m_dist4, dist_tmp_B2, dist_B1); + const vllong8 ptr_B2 = select(vboold8(m_dist4), ptr_B1, ptr_tmp_B2); + const vllong8 ptr_tmp_C2 = select(vboold8(m_dist4), ptr_tmp_B2, ptr_B1); + + const vboolf16 m_dist5 = dist_C1 <= dist_tmp_C2; + const vfloat16 dist_C2 = select(m_dist5, dist_C1 , dist_tmp_C2); + const vfloat16 dist_D2 = select(m_dist5, dist_tmp_C2, dist_C1); + const vllong8 ptr_C2 = select(vboold8(m_dist5), ptr_C1, ptr_tmp_C2); + const vllong8 ptr_D2 = select(vboold8(m_dist5), ptr_tmp_C2, ptr_C1); + + mask &= mask-1; + if (likely(mask == 0)) { + cur = toScalar(ptr_A2); + stackPtr[0].ptr = toScalar(ptr_D2); + *(float*)&stackPtr[0].dist = toScalar(dist_D2); + stackPtr[1].ptr = toScalar(ptr_C2); + *(float*)&stackPtr[1].dist = toScalar(dist_C2); + stackPtr[2].ptr = toScalar(ptr_B2); + *(float*)&stackPtr[2].dist = toScalar(dist_B2); + stackPtr+=3; + return; + } + + /* >=5 hits: reverse to descending order for writing to stack */ + + const size_t hits = 4 + popcnt(mask); + const vfloat16 dist_A2 = select(m_dist3, dist_A1, d3); + vfloat16 dist(neg_inf); + vllong8 ptr(zero); + + + isort_quick_update(dist,ptr,dist_A2,ptr_A2); + isort_quick_update(dist,ptr,dist_B2,ptr_B2); + isort_quick_update(dist,ptr,dist_C2,ptr_C2); + isort_quick_update(dist,ptr,dist_D2,ptr_D2); + + do { + + children = align_shift_right<1>(children,children); + distance = align_shift_right<1>(distance,distance); + + cur = toScalar(children); + BVHN::prefetch(cur,types); + + const vfloat16 new_dist(permute(distance,vint16(zero))); + const vllong8 new_ptr(permute(children,vllong8(zero))); + + mask &= mask-1; + isort_update(dist,ptr,new_dist,new_ptr); + + } while(mask); + + const vboold8 m_stack_ptr(0x55); // 10101010 (lsb -> msb) + const vboolf16 m_stack_dist(0x4444); // 0010001000100010 (lsb -> msb) + + /* extract current noderef */ + cur = toScalar(permute(ptr,vllong8(hits-1))); + /* rearrange pointers to beginning of 16 bytes block */ + vllong8 stackElementA0; + 
stackElementA0 = vllong8::expand(m_stack_ptr,ptr,stackElementA0); + /* put distances in between */ + vuint16 stackElementA1((__m512i)stackElementA0); + stackElementA1 = vuint16::expand(m_stack_dist,asUInt(dist),stackElementA1); + /* write out first 4 x 16 bytes block to stack */ + vuint16::storeu(stackPtr,stackElementA1); + /* get upper half of dist and ptr */ + dist = align_shift_right<4>(dist,dist); + ptr = align_shift_right<4>(ptr,ptr); + /* assemble and write out second block */ + vllong8 stackElementB0; + stackElementB0 = vllong8::expand(m_stack_ptr,ptr,stackElementB0); + vuint16 stackElementB1((__m512i)stackElementB0); + stackElementB1 = vuint16::expand(m_stack_dist,asUInt(dist),stackElementB1); + vuint16::storeu(stackPtr + 4,stackElementB1); + /* increase stack pointer */ + stackPtr += hits-1; + } +#endif + +#if defined(__AVX512VL__) // SKX + + template + __forceinline void isort_update(vint &dist, const vint &d) + { + const vint dist_shift = align_shift_right(dist,dist); + const vboolf m_geq = d >= dist; + const vboolf m_geq_shift = m_geq << 1; + dist = select(m_geq,d,dist); + dist = select(m_geq_shift,dist_shift,dist); + } + + template + __forceinline void isort_quick_update(vint &dist, const vint &d) { + dist = align_shift_right(dist,permute(d,vint(zero))); + } + + __forceinline size_t permuteExtract(const vint8& index, const vllong4& n0, const vllong4& n1) { + return toScalar(permutex2var((__m256i)index,n0,n1)); + } + + __forceinline float permuteExtract(const vint8& index, const vfloat8& n) { + return toScalar(permute(n,index)); + } + +#endif + + /* Specialization for BVH4. */ + template + class BVHNNodeTraverser1Hit<4, Nx, types> + { + typedef BVH4 BVH; + typedef BVH4::NodeRef NodeRef; + typedef BVH4::BaseNode BaseNode; + + + public: + /* Traverses a node with at least one hit child. Optimized for finding the closest hit (intersection). */ + static __forceinline void traverseClosestHit(NodeRef& cur, + size_t mask, + const vfloat& tNear, + StackItemT*& stackPtr, + StackItemT* stackEnd) + { + assert(mask != 0); +#if defined(__AVX512ER__) + traverseClosestHitAVX512<4,Nx,types,NodeRef,BaseNode>(cur,mask,tNear,stackPtr,stackEnd); +#else + const BaseNode* node = cur.baseNode(); + + /*! one child is hit, continue with that child */ + size_t r = bscf(mask); + cur = node->child(r); + BVH::prefetch(cur,types); + if (likely(mask == 0)) { + assert(cur != BVH::emptyNode); + return; + } + + /*! 
two children are hit, push far child, and continue with closer child */ + NodeRef c0 = cur; + const unsigned int d0 = ((unsigned int*)&tNear)[r]; + r = bscf(mask); + NodeRef c1 = node->child(r); + BVH::prefetch(c1,types); + const unsigned int d1 = ((unsigned int*)&tNear)[r]; + assert(c0 != BVH::emptyNode); + assert(c1 != BVH::emptyNode); + if (likely(mask == 0)) { + assert(stackPtr < stackEnd); + if (d0 < d1) { stackPtr->ptr = c1; stackPtr->dist = d1; stackPtr++; cur = c0; return; } + else { stackPtr->ptr = c0; stackPtr->dist = d0; stackPtr++; cur = c1; return; } + } + +#if NEW_SORTING_CODE == 1 + vint4 s0((size_t)c0,(size_t)d0); + vint4 s1((size_t)c1,(size_t)d1); + r = bscf(mask); + NodeRef c2 = node->child(r); BVH::prefetch(c2,types); unsigned int d2 = ((unsigned int*)&tNear)[r]; + vint4 s2((size_t)c2,(size_t)d2); + /* 3 hits */ + if (likely(mask == 0)) { + StackItemT::sort3(s0,s1,s2); + *(vint4*)&stackPtr[0] = s0; *(vint4*)&stackPtr[1] = s1; + cur = toSizeT(s2); + stackPtr+=2; + return; + } + r = bscf(mask); + NodeRef c3 = node->child(r); BVH::prefetch(c3,types); unsigned int d3 = ((unsigned int*)&tNear)[r]; + vint4 s3((size_t)c3,(size_t)d3); + /* 4 hits */ + StackItemT::sort4(s0,s1,s2,s3); + *(vint4*)&stackPtr[0] = s0; *(vint4*)&stackPtr[1] = s1; *(vint4*)&stackPtr[2] = s2; + cur = toSizeT(s3); + stackPtr+=3; +#else + /*! Here starts the slow path for 3 or 4 hit children. We push + * all nodes onto the stack to sort them there. */ + assert(stackPtr < stackEnd); + stackPtr->ptr = c0; stackPtr->dist = d0; stackPtr++; + assert(stackPtr < stackEnd); + stackPtr->ptr = c1; stackPtr->dist = d1; stackPtr++; + + /*! three children are hit, push all onto stack and sort 3 stack items, continue with closest child */ + assert(stackPtr < stackEnd); + r = bscf(mask); + NodeRef c = node->child(r); BVH::prefetch(c,types); unsigned int d = ((unsigned int*)&tNear)[r]; stackPtr->ptr = c; stackPtr->dist = d; stackPtr++; + assert(c != BVH::emptyNode); + if (likely(mask == 0)) { + sort(stackPtr[-1],stackPtr[-2],stackPtr[-3]); + cur = (NodeRef) stackPtr[-1].ptr; stackPtr--; + return; + } + + /*! four children are hit, push all onto stack and sort 4 stack items, continue with closest child */ + assert(stackPtr < stackEnd); + r = bscf(mask); + c = node->child(r); BVH::prefetch(c,types); d = *(unsigned int*)&tNear[r]; stackPtr->ptr = c; stackPtr->dist = d; stackPtr++; + assert(c != BVH::emptyNode); + sort(stackPtr[-1],stackPtr[-2],stackPtr[-3],stackPtr[-4]); + cur = (NodeRef) stackPtr[-1].ptr; stackPtr--; +#endif +#endif + } + + /* Traverses a node with at least one hit child. Optimized for finding any hit (occlusion). */ + static __forceinline void traverseAnyHit(NodeRef& cur, + size_t mask, + const vfloat& tNear, + NodeRef*& stackPtr, + NodeRef* stackEnd) + { + const BaseNode* node = cur.baseNode(); + + /*! one child is hit, continue with that child */ + size_t r = bscf(mask); + cur = node->child(r); + BVH::prefetch(cur,types); + + /* simpler in sequence traversal order */ + assert(cur != BVH::emptyNode); + if (likely(mask == 0)) return; + assert(stackPtr < stackEnd); + *stackPtr = cur; stackPtr++; + + for (; ;) + { + r = bscf(mask); + cur = node->child(r); BVH::prefetch(cur,types); + assert(cur != BVH::emptyNode); + if (likely(mask == 0)) return; + assert(stackPtr < stackEnd); + *stackPtr = cur; stackPtr++; + } + } + }; + + /* Specialization for BVH8. 
*/ + template + class BVHNNodeTraverser1Hit<8, Nx, types> + { + typedef BVH8 BVH; + typedef BVH8::NodeRef NodeRef; + typedef BVH8::BaseNode BaseNode; + +#if defined(__AVX512VL__) + template + static __forceinline void traverseClosestHitAVX512VL8(NodeRef& cur, + size_t mask, + const vfloat8& tNear, + StackItemT*& stackPtr, + StackItemT* stackEnd) + { + assert(mask != 0); + const BaseNode* node = cur.baseNode(); + const vllong4 n0 = vllong4::loadu((vllong4*)&node->children[0]); + const vllong4 n1 = vllong4::loadu((vllong4*)&node->children[4]); + vint8 distance_i = (asInt(tNear) & 0xfffffff8) | vint8(step); + distance_i = vint8::compact((int)mask,distance_i,distance_i); + cur = permuteExtract(distance_i,n0,n1); + BVH::prefetch(cur,types); + + mask &= mask-1; + if (likely(mask == 0)) return; + + /* 2 hits: order A0 B0 */ + const vint8 d0(distance_i); + const vint8 d1(shuffle<1>(distance_i)); + cur = permuteExtract(d1,n0,n1); + BVH::prefetch(cur,types); + + const vint8 dist_A0 = min(d0, d1); + const vint8 dist_B0 = max(d0, d1); + assert(dist_A0[0] < dist_B0[0]); + + mask &= mask-1; + if (likely(mask == 0)) { + cur = permuteExtract(dist_A0,n0,n1); + stackPtr[0].ptr = permuteExtract(dist_B0,n0,n1); + *(float*)&stackPtr[0].dist = permuteExtract(dist_B0,tNear); + stackPtr++; + return; + } + + /* 3 hits: order A1 B1 C1 */ + + const vint8 d2(shuffle<2>(distance_i)); + cur = permuteExtract(d2,n0,n1); + BVH::prefetch(cur,types); + + const vint8 dist_A1 = min(dist_A0,d2); + const vint8 dist_tmp_B1 = max(dist_A0,d2); + const vint8 dist_B1 = min(dist_B0,dist_tmp_B1); + const vint8 dist_C1 = max(dist_B0,dist_tmp_B1); + assert(dist_A1[0] < dist_B1[0]); + assert(dist_B1[0] < dist_C1[0]); + + mask &= mask-1; + if (likely(mask == 0)) { + cur = permuteExtract(dist_A1,n0,n1); + stackPtr[0].ptr = permuteExtract(dist_C1,n0,n1); + *(float*)&stackPtr[0].dist = permuteExtract(dist_C1,tNear); + stackPtr[1].ptr = permuteExtract(dist_B1,n0,n1); + *(float*)&stackPtr[1].dist = permuteExtract(dist_B1,tNear); + stackPtr+=2; + return; + } + + /* 4 hits: order A2 B2 C2 D2 */ + + const vint8 d3(shuffle<3>(distance_i)); + cur = permuteExtract(d3,n0,n1); + BVH::prefetch(cur,types); + + const vint8 dist_A2 = min(dist_A1,d3); + const vint8 dist_tmp_B2 = max(dist_A1,d3); + const vint8 dist_B2 = min(dist_B1,dist_tmp_B2); + const vint8 dist_tmp_C2 = max(dist_B1,dist_tmp_B2); + const vint8 dist_C2 = min(dist_C1,dist_tmp_C2); + const vint8 dist_D2 = max(dist_C1,dist_tmp_C2); + assert(dist_A2[0] < dist_B2[0]); + assert(dist_B2[0] < dist_C2[0]); + assert(dist_C2[0] < dist_D2[0]); + + mask &= mask-1; + if (likely(mask == 0)) { + cur = permuteExtract(dist_A2,n0,n1); + stackPtr[0].ptr = permuteExtract(dist_D2,n0,n1); + *(float*)&stackPtr[0].dist = permuteExtract(dist_D2,tNear); + stackPtr[1].ptr = permuteExtract(dist_C2,n0,n1); + *(float*)&stackPtr[1].dist = permuteExtract(dist_C2,tNear); + stackPtr[2].ptr = permuteExtract(dist_B2,n0,n1); + *(float*)&stackPtr[2].dist = permuteExtract(dist_B2,tNear); + stackPtr+=3; + return; + } + + /* >=5 hits: reverse to descending order for writing to stack */ + + distance_i = align_shift_right<3>(distance_i,distance_i); + const size_t hits = 4 + popcnt(mask); + vint8 dist(INT_MIN); // this will work with -0.0f (0x80000000) as distance, isort_update uses >= to insert + + isort_quick_update(dist,dist_A2); + isort_quick_update(dist,dist_B2); + isort_quick_update(dist,dist_C2); + isort_quick_update(dist,dist_D2); + + do { + + distance_i = align_shift_right<1>(distance_i,distance_i); + cur = 
permuteExtract(distance_i,n0,n1); + BVH::prefetch(cur,types); + const vint8 new_dist(permute(distance_i,vint8(zero))); + mask &= mask-1; + isort_update(dist,new_dist); + + } while(mask); + + for (size_t i=0; i<7; i++) + assert(dist[i+0]>=dist[i+1]); + + for (size_t i=0;iptr = permuteExtract(dist,n0,n1); + *(float*)&stackPtr->dist = permuteExtract(dist,tNear); + dist = align_shift_right<1>(dist,dist); + stackPtr++; + } + cur = permuteExtract(dist,n0,n1); + } +#endif + + public: + static __forceinline void traverseClosestHit(NodeRef& cur, + size_t mask, + const vfloat& tNear, + StackItemT*& stackPtr, + StackItemT* stackEnd) + { + assert(mask != 0); +#if defined(__AVX512ER__) + traverseClosestHitAVX512<8,Nx,types,NodeRef,BaseNode>(cur,mask,tNear,stackPtr,stackEnd); +#elif defined(__AVX512VL__) + traverseClosestHitAVX512VL8(cur,mask,tNear,stackPtr,stackEnd); +#else + + const BaseNode* node = cur.baseNode(); + + /*! one child is hit, continue with that child */ + size_t r = bscf(mask); + cur = node->child(r); + BVH::prefetch(cur,types); + if (likely(mask == 0)) { + assert(cur != BVH::emptyNode); + return; + } + + /*! two children are hit, push far child, and continue with closer child */ + NodeRef c0 = cur; + const unsigned int d0 = ((unsigned int*)&tNear)[r]; + r = bscf(mask); + NodeRef c1 = node->child(r); + BVH::prefetch(c1,types); + const unsigned int d1 = ((unsigned int*)&tNear)[r]; + + assert(c0 != BVH::emptyNode); + assert(c1 != BVH::emptyNode); + if (likely(mask == 0)) { + assert(stackPtr < stackEnd); + if (d0 < d1) { stackPtr->ptr = c1; stackPtr->dist = d1; stackPtr++; cur = c0; return; } + else { stackPtr->ptr = c0; stackPtr->dist = d0; stackPtr++; cur = c1; return; } + } +#if NEW_SORTING_CODE == 1 + vint4 s0((size_t)c0,(size_t)d0); + vint4 s1((size_t)c1,(size_t)d1); + + r = bscf(mask); + NodeRef c2 = node->child(r); BVH::prefetch(c2,types); unsigned int d2 = ((unsigned int*)&tNear)[r]; + vint4 s2((size_t)c2,(size_t)d2); + /* 3 hits */ + if (likely(mask == 0)) { + StackItemT::sort3(s0,s1,s2); + *(vint4*)&stackPtr[0] = s0; *(vint4*)&stackPtr[1] = s1; + cur = toSizeT(s2); + stackPtr+=2; + return; + } + r = bscf(mask); + NodeRef c3 = node->child(r); BVH::prefetch(c3,types); unsigned int d3 = ((unsigned int*)&tNear)[r]; + vint4 s3((size_t)c3,(size_t)d3); + /* 4 hits */ + if (likely(mask == 0)) { + StackItemT::sort4(s0,s1,s2,s3); + *(vint4*)&stackPtr[0] = s0; *(vint4*)&stackPtr[1] = s1; *(vint4*)&stackPtr[2] = s2; + cur = toSizeT(s3); + stackPtr+=3; + return; + } + *(vint4*)&stackPtr[0] = s0; *(vint4*)&stackPtr[1] = s1; *(vint4*)&stackPtr[2] = s2; *(vint4*)&stackPtr[3] = s3; + /*! fallback case if more than 4 children are hit */ + StackItemT* stackFirst = stackPtr; + stackPtr+=4; + while (1) + { + assert(stackPtr < stackEnd); + r = bscf(mask); + NodeRef c = node->child(r); BVH::prefetch(c,types); unsigned int d = *(unsigned int*)&tNear[r]; + const vint4 s((size_t)c,(size_t)d); + *(vint4*)stackPtr++ = s; + assert(c != BVH::emptyNode); + if (unlikely(mask == 0)) break; + } + sort(stackFirst,stackPtr); + cur = (NodeRef) stackPtr[-1].ptr; stackPtr--; +#else + /*! Here starts the slow path for 3 or 4 hit children. We push + * all nodes onto the stack to sort them there. */ + assert(stackPtr < stackEnd); + stackPtr->ptr = c0; stackPtr->dist = d0; stackPtr++; + assert(stackPtr < stackEnd); + stackPtr->ptr = c1; stackPtr->dist = d1; stackPtr++; + + /*! 
three children are hit, push all onto stack and sort 3 stack items, continue with closest child */ + assert(stackPtr < stackEnd); + r = bscf(mask); + NodeRef c = node->child(r); BVH::prefetch(c,types); unsigned int d = ((unsigned int*)&tNear)[r]; stackPtr->ptr = c; stackPtr->dist = d; stackPtr++; + assert(c != BVH::emptyNode); + if (likely(mask == 0)) { + sort(stackPtr[-1],stackPtr[-2],stackPtr[-3]); + cur = (NodeRef) stackPtr[-1].ptr; stackPtr--; + return; + } + + /*! four children are hit, push all onto stack and sort 4 stack items, continue with closest child */ + assert(stackPtr < stackEnd); + r = bscf(mask); + c = node->child(r); BVH::prefetch(c,types); d = *(unsigned int*)&tNear[r]; stackPtr->ptr = c; stackPtr->dist = d; stackPtr++; + assert(c != BVH::emptyNode); + if (likely(mask == 0)) { + sort(stackPtr[-1],stackPtr[-2],stackPtr[-3],stackPtr[-4]); + cur = (NodeRef) stackPtr[-1].ptr; stackPtr--; + return; + } + /*! fallback case if more than 4 children are hit */ + StackItemT* stackFirst = stackPtr-4; + while (1) + { + assert(stackPtr < stackEnd); + r = bscf(mask); + c = node->child(r); BVH::prefetch(c,types); d = *(unsigned int*)&tNear[r]; stackPtr->ptr = c; stackPtr->dist = d; stackPtr++; + assert(c != BVH::emptyNode); + if (unlikely(mask == 0)) break; + } + sort(stackFirst,stackPtr); + cur = (NodeRef) stackPtr[-1].ptr; stackPtr--; +#endif +#endif + } + + static __forceinline void traverseAnyHit(NodeRef& cur, + size_t mask, + const vfloat& tNear, + NodeRef*& stackPtr, + NodeRef* stackEnd) + { + const BaseNode* node = cur.baseNode(); + + /*! one child is hit, continue with that child */ + size_t r = bscf(mask); + cur = node->child(r); + BVH::prefetch(cur,types); + + /* simpler in sequence traversal order */ + assert(cur != BVH::emptyNode); + if (likely(mask == 0)) return; + assert(stackPtr < stackEnd); + *stackPtr = cur; stackPtr++; + + for (; ;) + { + r = bscf(mask); + cur = node->child(r); BVH::prefetch(cur,types); + assert(cur != BVH::emptyNode); + if (likely(mask == 0)) return; + assert(stackPtr < stackEnd); + *stackPtr = cur; stackPtr++; + } + } + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/bvh_traverser_stream.h b/thirdparty/embree/kernels/bvh/bvh_traverser_stream.h new file mode 100644 index 000000000000..9c603babf0a5 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/bvh_traverser_stream.h @@ -0,0 +1,154 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" +#include "../common/ray.h" +#include "../common/stack_item.h" + +namespace embree +{ + namespace isa + { + template + class BVHNNodeTraverserStreamHitCoherent + { + typedef BVHN BVH; + typedef typename BVH::NodeRef NodeRef; + typedef typename BVH::BaseNode BaseNode; + + public: + template + static __forceinline void traverseClosestHit(NodeRef& cur, + size_t& m_trav_active, + const vbool& vmask, + const vfloat& tNear, + const T* const tMask, + StackItemMaskCoherent*& stackPtr) + { + const NodeRef parent = cur; + size_t mask = movemask(vmask); + assert(mask != 0); + const BaseNode* node = cur.baseNode(); + + /*! one child is hit, continue with that child */ + const size_t r0 = bscf(mask); + assert(r0 < 8); + cur = node->child(r0); + BVHN::prefetch(cur,types); + m_trav_active = tMask[r0]; + assert(cur != BVH::emptyNode); + if (unlikely(mask == 0)) return; + + const unsigned int* const tNear_i = (unsigned int*)&tNear; + + /*! 
two children are hit, push far child, and continue with closer child */ + NodeRef c0 = cur; + unsigned int d0 = tNear_i[r0]; + const size_t r1 = bscf(mask); + assert(r1 < 8); + NodeRef c1 = node->child(r1); + BVHN::prefetch(c1,types); + unsigned int d1 = tNear_i[r1]; + + assert(c0 != BVH::emptyNode); + assert(c1 != BVH::emptyNode); + if (likely(mask == 0)) { + if (d0 < d1) { + assert(tNear[r1] >= 0.0f); + stackPtr->mask = tMask[r1]; + stackPtr->parent = parent; + stackPtr->child = c1; + stackPtr++; + cur = c0; + m_trav_active = tMask[r0]; + return; + } + else { + assert(tNear[r0] >= 0.0f); + stackPtr->mask = tMask[r0]; + stackPtr->parent = parent; + stackPtr->child = c0; + stackPtr++; + cur = c1; + m_trav_active = tMask[r1]; + return; + } + } + + /*! slow path for more than two hits */ + size_t hits = movemask(vmask); + const vint dist_i = select(vmask, (asInt(tNear) & 0xfffffff8) | vint(step), 0); + #if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL + const vint tmp = extractN(dist_i); + const vint dist_i_sorted = usort_descending(tmp); + #else + const vint dist_i_sorted = usort_descending(dist_i); + #endif + const vint sorted_index = dist_i_sorted & 7; + + size_t i = 0; + for (;;) + { + const unsigned int index = sorted_index[i]; + assert(index < 8); + cur = node->child(index); + m_trav_active = tMask[index]; + assert(m_trav_active); + BVHN::prefetch(cur,types); + bscf(hits); + if (unlikely(hits==0)) break; + i++; + assert(cur != BVH::emptyNode); + assert(tNear[index] >= 0.0f); + stackPtr->mask = m_trav_active; + stackPtr->parent = parent; + stackPtr->child = cur; + stackPtr++; + } + } + + template + static __forceinline void traverseAnyHit(NodeRef& cur, + size_t& m_trav_active, + const vbool& vmask, + const T* const tMask, + StackItemMaskCoherent*& stackPtr) + { + const NodeRef parent = cur; + size_t mask = movemask(vmask); + assert(mask != 0); + const BaseNode* node = cur.baseNode(); + + /*! one child is hit, continue with that child */ + size_t r = bscf(mask); + cur = node->child(r); + BVHN::prefetch(cur,types); + m_trav_active = tMask[r]; + + /* simple in order sequence */ + assert(cur != BVH::emptyNode); + if (likely(mask == 0)) return; + stackPtr->mask = m_trav_active; + stackPtr->parent = parent; + stackPtr->child = cur; + stackPtr++; + + for (; ;) + { + r = bscf(mask); + cur = node->child(r); + BVHN::prefetch(cur,types); + m_trav_active = tMask[r]; + assert(cur != BVH::emptyNode); + if (likely(mask == 0)) return; + stackPtr->mask = m_trav_active; + stackPtr->parent = parent; + stackPtr->child = cur; + stackPtr++; + } + } + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/node_intersector.h b/thirdparty/embree/kernels/bvh/node_intersector.h new file mode 100644 index 000000000000..a978c0c45907 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/node_intersector.h @@ -0,0 +1,31 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bvh.h" + +namespace embree +{ + namespace isa + { + struct NearFarPrecalculations + { + size_t nearX, nearY, nearZ; + size_t farX, farY, farZ; + + __forceinline NearFarPrecalculations() {} + + __forceinline NearFarPrecalculations(const Vec3fa& dir, size_t N) + { + const size_t size = sizeof(float)*N; + nearX = (dir.x < 0.0f) ? 1*size : 0*size; + nearY = (dir.y < 0.0f) ? 3*size : 2*size; + nearZ = (dir.z < 0.0f) ? 
5*size : 4*size; + farX = nearX ^ size; + farY = nearY ^ size; + farZ = nearZ ^ size; + } + }; + } +} diff --git a/thirdparty/embree/kernels/bvh/node_intersector1.h b/thirdparty/embree/kernels/bvh/node_intersector1.h new file mode 100644 index 000000000000..b1e63ce345ce --- /dev/null +++ b/thirdparty/embree/kernels/bvh/node_intersector1.h @@ -0,0 +1,1695 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "node_intersector.h" + +namespace embree +{ + namespace isa + { + ////////////////////////////////////////////////////////////////////////////////////// + // Ray structure used in single-ray traversal + ////////////////////////////////////////////////////////////////////////////////////// + + template + struct TravRayBase; + + /* Base (without tnear and tfar) */ + template + struct TravRayBase + { + __forceinline TravRayBase() {} + + __forceinline TravRayBase(const Vec3fa& ray_org, const Vec3fa& ray_dir) + : org_xyz(ray_org), dir_xyz(ray_dir) + { + const Vec3fa ray_rdir = rcp_safe(ray_dir); + org = Vec3vf(ray_org.x,ray_org.y,ray_org.z); + dir = Vec3vf(ray_dir.x,ray_dir.y,ray_dir.z); + rdir = Vec3vf(ray_rdir.x,ray_rdir.y,ray_rdir.z); +#if defined(__AVX2__) + const Vec3fa ray_org_rdir = ray_org*ray_rdir; + org_rdir = Vec3vf(ray_org_rdir.x,ray_org_rdir.y,ray_org_rdir.z); +#endif + nearX = ray_rdir.x >= 0.0f ? 0*sizeof(vfloat) : 1*sizeof(vfloat); + nearY = ray_rdir.y >= 0.0f ? 2*sizeof(vfloat) : 3*sizeof(vfloat); + nearZ = ray_rdir.z >= 0.0f ? 4*sizeof(vfloat) : 5*sizeof(vfloat); + farX = nearX ^ sizeof(vfloat); + farY = nearY ^ sizeof(vfloat); + farZ = nearZ ^ sizeof(vfloat); + +#if defined(__AVX512ER__) // KNL+ + /* optimization works only for 8-wide BVHs with 16-wide SIMD */ + const vint<16> id(step); + const vint<16> id2 = align_shift_right<16/2>(id, id); + permX = select(vfloat<16>(dir.x) >= 0.0f, id, id2); + permY = select(vfloat<16>(dir.y) >= 0.0f, id, id2); + permZ = select(vfloat<16>(dir.z) >= 0.0f, id, id2); +#endif + + } + + template + __forceinline TravRayBase(size_t k, const Vec3vf& ray_org, const Vec3vf& ray_dir, + const Vec3vf& ray_rdir, const Vec3vi& nearXYZ, + size_t flip = sizeof(vfloat)) + { + org = Vec3vf(ray_org.x[k], ray_org.y[k], ray_org.z[k]); + dir = Vec3vf(ray_dir.x[k], ray_dir.y[k], ray_dir.z[k]); + rdir = Vec3vf(ray_rdir.x[k], ray_rdir.y[k], ray_rdir.z[k]); +#if defined(__AVX2__) + org_rdir = org*rdir; +#endif + nearX = nearXYZ.x[k]; + nearY = nearXYZ.y[k]; + nearZ = nearXYZ.z[k]; + farX = nearX ^ flip; + farY = nearY ^ flip; + farZ = nearZ ^ flip; + +#if defined(__AVX512ER__) // KNL+ + /* optimization works only for 8-wide BVHs with 16-wide SIMD */ + const vint<16> id(step); + const vint<16> id2 = align_shift_right<16/2>(id, id); + permX = select(vfloat<16>(dir.x) >= 0.0f, id, id2); + permY = select(vfloat<16>(dir.y) >= 0.0f, id, id2); + permZ = select(vfloat<16>(dir.z) >= 0.0f, id, id2); +#endif + } + + Vec3fa org_xyz, dir_xyz; + Vec3vf org, dir, rdir; +#if defined(__AVX2__) + Vec3vf org_rdir; +#endif +#if defined(__AVX512ER__) // KNL+ + vint16 permX, permY, permZ; +#endif + + size_t nearX, nearY, nearZ; + size_t farX, farY, farZ; + }; + + /* Base (without tnear and tfar) */ + template + struct TravRayBase + { + __forceinline TravRayBase() {} + + __forceinline TravRayBase(const Vec3fa& ray_org, const Vec3fa& ray_dir) + : org_xyz(ray_org), dir_xyz(ray_dir) + { + const float round_down = 1.0f-3.0f*float(ulp); + const float round_up = 1.0f+3.0f*float(ulp); + const Vec3fa ray_rdir = 1.0f/zero_fix(ray_dir); 
+ const Vec3fa ray_rdir_near = round_down*ray_rdir; + const Vec3fa ray_rdir_far = round_up *ray_rdir; + org = Vec3vf(ray_org.x,ray_org.y,ray_org.z); + dir = Vec3vf(ray_dir.x,ray_dir.y,ray_dir.z); + rdir_near = Vec3vf(ray_rdir_near.x,ray_rdir_near.y,ray_rdir_near.z); + rdir_far = Vec3vf(ray_rdir_far .x,ray_rdir_far .y,ray_rdir_far .z); + + nearX = ray_rdir_near.x >= 0.0f ? 0*sizeof(vfloat) : 1*sizeof(vfloat); + nearY = ray_rdir_near.y >= 0.0f ? 2*sizeof(vfloat) : 3*sizeof(vfloat); + nearZ = ray_rdir_near.z >= 0.0f ? 4*sizeof(vfloat) : 5*sizeof(vfloat); + farX = nearX ^ sizeof(vfloat); + farY = nearY ^ sizeof(vfloat); + farZ = nearZ ^ sizeof(vfloat); + +#if defined(__AVX512ER__) // KNL+ + /* optimization works only for 8-wide BVHs with 16-wide SIMD */ + const vint<16> id(step); + const vint<16> id2 = align_shift_right<16/2>(id, id); + permX = select(vfloat<16>(dir.x) >= 0.0f, id, id2); + permY = select(vfloat<16>(dir.y) >= 0.0f, id, id2); + permZ = select(vfloat<16>(dir.z) >= 0.0f, id, id2); +#endif + } + + template + __forceinline TravRayBase(size_t k, const Vec3vf& ray_org, const Vec3vf& ray_dir, + const Vec3vf& ray_rdir, const Vec3vi& nearXYZ, + size_t flip = sizeof(vfloat)) + { + const vfloat round_down = 1.0f-3.0f*float(ulp); + const vfloat round_up = 1.0f+3.0f*float(ulp); + org = Vec3vf(ray_org.x[k], ray_org.y[k], ray_org.z[k]); + dir = Vec3vf(ray_dir.x[k], ray_dir.y[k], ray_dir.z[k]); + rdir_near = round_down*Vec3vf(ray_rdir.x[k], ray_rdir.y[k], ray_rdir.z[k]); + rdir_far = round_up *Vec3vf(ray_rdir.x[k], ray_rdir.y[k], ray_rdir.z[k]); + + nearX = nearXYZ.x[k]; + nearY = nearXYZ.y[k]; + nearZ = nearXYZ.z[k]; + farX = nearX ^ flip; + farY = nearY ^ flip; + farZ = nearZ ^ flip; + +#if defined(__AVX512ER__) // KNL+ + /* optimization works only for 8-wide BVHs with 16-wide SIMD */ + const vint<16> id(step); + const vint<16> id2 = align_shift_right<16/2>(id, id); + permX = select(vfloat<16>(dir.x) >= 0.0f, id, id2); + permY = select(vfloat<16>(dir.y) >= 0.0f, id, id2); + permZ = select(vfloat<16>(dir.z) >= 0.0f, id, id2); +#endif + } + + Vec3fa org_xyz, dir_xyz; + Vec3vf org, dir, rdir_near, rdir_far; +#if defined(__AVX512ER__) // KNL+ + vint16 permX, permY, permZ; +#endif + + size_t nearX, nearY, nearZ; + size_t farX, farY, farZ; + }; + + /* Full (with tnear and tfar) */ + template + struct TravRay : TravRayBase + { + __forceinline TravRay() {} + + __forceinline TravRay(const Vec3fa& ray_org, const Vec3fa& ray_dir, float ray_tnear, float ray_tfar) + : TravRayBase(ray_org, ray_dir), + tnear(ray_tnear), tfar(ray_tfar) {} + + template + __forceinline TravRay(size_t k, const Vec3vf& ray_org, const Vec3vf& ray_dir, + const Vec3vf& ray_rdir, const Vec3vi& nearXYZ, + float ray_tnear, float ray_tfar, + size_t flip = sizeof(vfloat)) + : TravRayBase(k, ray_org, ray_dir, ray_rdir, nearXYZ, flip), + tnear(ray_tnear), tfar(ray_tfar) {} + + vfloat tnear; + vfloat tfar; + }; + + ////////////////////////////////////////////////////////////////////////////////////// + // Point Query structure used in single-ray traversal + ////////////////////////////////////////////////////////////////////////////////////// + + template + struct TravPointQuery + { + __forceinline TravPointQuery() {} + + __forceinline TravPointQuery(const Vec3fa& query_org, const Vec3fa& query_rad) + { + org = Vec3vf(query_org.x, query_org.y, query_org.z); + rad = Vec3vf(query_rad.x, query_rad.y, query_rad.z); + } + + __forceinline vfloat const& tfar() const { + return rad.x; + } + + Vec3vf org, rad; + }; + + 
////////////////////////////////////////////////////////////////////////////////////// + // point query + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t pointQuerySphereDistAndMask( + const TravPointQuery& query, vfloat& dist, vfloat const& minX, vfloat const& maxX, + vfloat const& minY, vfloat const& maxY, vfloat const& minZ, vfloat const& maxZ) + { + const vfloat vX = min(max(query.org.x, minX), maxX) - query.org.x; + const vfloat vY = min(max(query.org.y, minY), maxY) - query.org.y; + const vfloat vZ = min(max(query.org.z, minZ), maxZ) - query.org.z; + dist = vX * vX + vY * vY + vZ * vZ; + const vbool vmask = dist <= query.tfar()*query.tfar(); + const vbool valid = minX <= maxX; + return movemask(vmask) & movemask(valid); + } + + template + __forceinline size_t pointQueryNodeSphere(const typename BVHN::AABBNode* node, const TravPointQuery& query, vfloat& dist) + { + const vfloat minX = vfloat::load((float*)((const char*)&node->lower_x)); + const vfloat minY = vfloat::load((float*)((const char*)&node->lower_y)); + const vfloat minZ = vfloat::load((float*)((const char*)&node->lower_z)); + const vfloat maxX = vfloat::load((float*)((const char*)&node->upper_x)); + const vfloat maxY = vfloat::load((float*)((const char*)&node->upper_y)); + const vfloat maxZ = vfloat::load((float*)((const char*)&node->upper_z)); + return pointQuerySphereDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ); + } + + template + __forceinline size_t pointQueryNodeSphere(const typename BVHN::AABBNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + const vfloat* pMinX = (const vfloat*)((const char*)&node->lower_x); + const vfloat* pMinY = (const vfloat*)((const char*)&node->lower_y); + const vfloat* pMinZ = (const vfloat*)((const char*)&node->lower_z); + const vfloat* pMaxX = (const vfloat*)((const char*)&node->upper_x); + const vfloat* pMaxY = (const vfloat*)((const char*)&node->upper_y); + const vfloat* pMaxZ = (const vfloat*)((const char*)&node->upper_z); + const vfloat minX = madd(time,pMinX[6],vfloat(pMinX[0])); + const vfloat minY = madd(time,pMinY[6],vfloat(pMinY[0])); + const vfloat minZ = madd(time,pMinZ[6],vfloat(pMinZ[0])); + const vfloat maxX = madd(time,pMaxX[6],vfloat(pMaxX[0])); + const vfloat maxY = madd(time,pMaxY[6],vfloat(pMaxY[0])); + const vfloat maxZ = madd(time,pMaxZ[6],vfloat(pMaxZ[0])); + return pointQuerySphereDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ); + } + + template + __forceinline size_t pointQueryNodeSphereMB4D(const typename BVHN::NodeRef ref, const TravPointQuery& query, const float time, vfloat& dist) + { + const typename BVHN::AABBNodeMB* node = ref.getAABBNodeMB(); + size_t mask = pointQueryNodeSphere(node, query, time, dist); + + if (unlikely(ref.isAABBNodeMB4D())) { + const typename BVHN::AABBNodeMB4D* node1 = (const typename BVHN::AABBNodeMB4D*) node; + const vbool vmask = (node1->lower_t <= time) & (time < node1->upper_t); + mask &= movemask(vmask); + } + + return mask; + } + + template + __forceinline size_t pointQueryNodeSphere(const typename BVHN::QuantizedBaseNode* node, const TravPointQuery& query, vfloat& dist) + { + const vfloat start_x(node->start.x); + const vfloat scale_x(node->scale.x); + const vfloat minX = madd(node->template dequantize((0*sizeof(vfloat)) >> 2),scale_x,start_x); + const vfloat maxX = madd(node->template dequantize((1*sizeof(vfloat)) >> 2),scale_x,start_x); + const vfloat start_y(node->start.y); + const vfloat 
scale_y(node->scale.y); + const vfloat minY = madd(node->template dequantize((2*sizeof(vfloat)) >> 2),scale_y,start_y); + const vfloat maxY = madd(node->template dequantize((3*sizeof(vfloat)) >> 2),scale_y,start_y); + const vfloat start_z(node->start.z); + const vfloat scale_z(node->scale.z); + const vfloat minZ = madd(node->template dequantize((4*sizeof(vfloat)) >> 2),scale_z,start_z); + const vfloat maxZ = madd(node->template dequantize((5*sizeof(vfloat)) >> 2),scale_z,start_z); + return pointQuerySphereDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ) & movemask(node->validMask()); + } + + template + __forceinline size_t pointQueryNodeSphere(const typename BVHN::QuantizedBaseNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + const vfloat minX = node->dequantizeLowerX(time); + const vfloat maxX = node->dequantizeUpperX(time); + const vfloat minY = node->dequantizeLowerY(time); + const vfloat maxY = node->dequantizeUpperY(time); + const vfloat minZ = node->dequantizeLowerZ(time); + const vfloat maxZ = node->dequantizeUpperZ(time); + return pointQuerySphereDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ) & movemask(node->validMask()); + } + + template + __forceinline size_t pointQueryNodeSphere(const typename BVHN::OBBNode* node, const TravPointQuery& query, vfloat& dist) + { + // TODO: point query - implement + const vbool vmask = vbool(true); + const size_t mask = movemask(vmask) & ((1<(0.0f); + return mask; + } + + template + __forceinline size_t pointQueryNodeSphere(const typename BVHN::OBBNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + // TODO: point query - implement + const vbool vmask = vbool(true); + const size_t mask = movemask(vmask) & ((1<(0.0f); + return mask; + } + + template + __forceinline size_t pointQueryAABBDistAndMask( + const TravPointQuery& query, vfloat& dist, vfloat const& minX, vfloat const& maxX, + vfloat const& minY, vfloat const& maxY, vfloat const& minZ, vfloat const& maxZ) + { + const vfloat vX = min(max(query.org.x, minX), maxX) - query.org.x; + const vfloat vY = min(max(query.org.y, minY), maxY) - query.org.y; + const vfloat vZ = min(max(query.org.z, minZ), maxZ) - query.org.z; + dist = vX * vX + vY * vY + vZ * vZ; + const vbool valid = minX <= maxX; + const vbool vmask = !((maxX < query.org.x - query.rad.x) | (minX > query.org.x + query.rad.x) | + (maxY < query.org.y - query.rad.y) | (minY > query.org.y + query.rad.y) | + (maxZ < query.org.z - query.rad.z) | (minZ > query.org.z + query.rad.z)); + return movemask(vmask) & movemask(valid); + } + + template + __forceinline size_t pointQueryNodeAABB(const typename BVHN::AABBNode* node, const TravPointQuery& query, vfloat& dist) + { + const vfloat minX = vfloat::load((float*)((const char*)&node->lower_x)); + const vfloat minY = vfloat::load((float*)((const char*)&node->lower_y)); + const vfloat minZ = vfloat::load((float*)((const char*)&node->lower_z)); + const vfloat maxX = vfloat::load((float*)((const char*)&node->upper_x)); + const vfloat maxY = vfloat::load((float*)((const char*)&node->upper_y)); + const vfloat maxZ = vfloat::load((float*)((const char*)&node->upper_z)); + return pointQueryAABBDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ); + } + + template + __forceinline size_t pointQueryNodeAABB(const typename BVHN::AABBNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + const vfloat* pMinX = (const vfloat*)((const char*)&node->lower_x); + const vfloat* pMinY = (const 
vfloat*)((const char*)&node->lower_y); + const vfloat* pMinZ = (const vfloat*)((const char*)&node->lower_z); + const vfloat* pMaxX = (const vfloat*)((const char*)&node->upper_x); + const vfloat* pMaxY = (const vfloat*)((const char*)&node->upper_y); + const vfloat* pMaxZ = (const vfloat*)((const char*)&node->upper_z); + const vfloat minX = madd(time,pMinX[6],vfloat(pMinX[0])); + const vfloat minY = madd(time,pMinY[6],vfloat(pMinY[0])); + const vfloat minZ = madd(time,pMinZ[6],vfloat(pMinZ[0])); + const vfloat maxX = madd(time,pMaxX[6],vfloat(pMaxX[0])); + const vfloat maxY = madd(time,pMaxY[6],vfloat(pMaxY[0])); + const vfloat maxZ = madd(time,pMaxZ[6],vfloat(pMaxZ[0])); + return pointQueryAABBDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ); + } + + template + __forceinline size_t pointQueryNodeAABBMB4D(const typename BVHN::NodeRef ref, const TravPointQuery& query, const float time, vfloat& dist) + { + const typename BVHN::AABBNodeMB* node = ref.getAABBNodeMB(); + size_t mask = pointQueryNodeAABB(node, query, time, dist); + + if (unlikely(ref.isAABBNodeMB4D())) { + const typename BVHN::AABBNodeMB4D* node1 = (const typename BVHN::AABBNodeMB4D*) node; + const vbool vmask = (node1->lower_t <= time) & (time < node1->upper_t); + mask &= movemask(vmask); + } + + return mask; + } + + template + __forceinline size_t pointQueryNodeAABB(const typename BVHN::QuantizedBaseNode* node, const TravPointQuery& query, vfloat& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat start_x(node->start.x); + const vfloat scale_x(node->scale.x); + const vfloat minX = madd(node->template dequantize((0*sizeof(vfloat)) >> 2),scale_x,start_x); + const vfloat maxX = madd(node->template dequantize((1*sizeof(vfloat)) >> 2),scale_x,start_x); + const vfloat start_y(node->start.y); + const vfloat scale_y(node->scale.y); + const vfloat minY = madd(node->template dequantize((2*sizeof(vfloat)) >> 2),scale_y,start_y); + const vfloat maxY = madd(node->template dequantize((3*sizeof(vfloat)) >> 2),scale_y,start_y); + const vfloat start_z(node->start.z); + const vfloat scale_z(node->scale.z); + const vfloat minZ = madd(node->template dequantize((4*sizeof(vfloat)) >> 2),scale_z,start_z); + const vfloat maxZ = madd(node->template dequantize((5*sizeof(vfloat)) >> 2),scale_z,start_z); + return pointQueryAABBDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ) & mvalid; + } + + template + __forceinline size_t pointQueryNodeAABB(const typename BVHN::QuantizedBaseNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat minX = node->dequantizeLowerX(time); + const vfloat maxX = node->dequantizeUpperX(time); + const vfloat minY = node->dequantizeLowerY(time); + const vfloat maxY = node->dequantizeUpperY(time); + const vfloat minZ = node->dequantizeLowerZ(time); + const vfloat maxZ = node->dequantizeUpperZ(time); + return pointQueryAABBDistAndMask(query, dist, minX, maxX, minY, maxY, minZ, maxZ) & mvalid; + } + + template + __forceinline size_t pointQueryNodeAABB(const typename BVHN::OBBNode* node, const TravPointQuery& query, vfloat& dist) + { + // TODO: point query - implement + const vbool vmask = vbool(true); + const size_t mask = movemask(vmask) & ((1<(0.0f); + return mask; + } + + template + __forceinline size_t pointQueryNodeAABB(const typename BVHN::OBBNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + // TODO: point query - implement + const vbool vmask = vbool(true); + 
const size_t mask = movemask(vmask) & ((1<(0.0f); + return mask; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode(const typename BVHN::AABBNode* node, const TravRay& ray, vfloat& dist); + + template<> + __forceinline size_t intersectNode<4,4>(const typename BVH4::AABBNode* node, const TravRay<4,4,false>& ray, vfloat4& dist) + { +#if defined(__AVX2__) + const vfloat4 tNearX = msub(vfloat4::load((float*)((const char*)&node->lower_x+ray.nearX)), ray.rdir.x, ray.org_rdir.x); + const vfloat4 tNearY = msub(vfloat4::load((float*)((const char*)&node->lower_x+ray.nearY)), ray.rdir.y, ray.org_rdir.y); + const vfloat4 tNearZ = msub(vfloat4::load((float*)((const char*)&node->lower_x+ray.nearZ)), ray.rdir.z, ray.org_rdir.z); + const vfloat4 tFarX = msub(vfloat4::load((float*)((const char*)&node->lower_x+ray.farX )), ray.rdir.x, ray.org_rdir.x); + const vfloat4 tFarY = msub(vfloat4::load((float*)((const char*)&node->lower_x+ray.farY )), ray.rdir.y, ray.org_rdir.y); + const vfloat4 tFarZ = msub(vfloat4::load((float*)((const char*)&node->lower_x+ray.farZ )), ray.rdir.z, ray.org_rdir.z); +#else + const vfloat4 tNearX = (vfloat4::load((float*)((const char*)&node->lower_x+ray.nearX)) - ray.org.x) * ray.rdir.x; + const vfloat4 tNearY = (vfloat4::load((float*)((const char*)&node->lower_x+ray.nearY)) - ray.org.y) * ray.rdir.y; + const vfloat4 tNearZ = (vfloat4::load((float*)((const char*)&node->lower_x+ray.nearZ)) - ray.org.z) * ray.rdir.z; + const vfloat4 tFarX = (vfloat4::load((float*)((const char*)&node->lower_x+ray.farX )) - ray.org.x) * ray.rdir.x; + const vfloat4 tFarY = (vfloat4::load((float*)((const char*)&node->lower_x+ray.farY )) - ray.org.y) * ray.rdir.y; + const vfloat4 tFarZ = (vfloat4::load((float*)((const char*)&node->lower_x+ray.farZ )) - ray.org.z) * ray.rdir.z; +#endif + +#if defined(__SSE4_1__) && !defined(__AVX512F__) // up to HSW + const vfloat4 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = asInt(tNear) > asInt(tFar); + const size_t mask = movemask(vmask) ^ ((1<<4)-1); +#elif defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vfloat4 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = asInt(tNear) <= asInt(tFar); + const size_t mask = movemask(vmask); +#else + const vfloat4 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = tNear <= tFar; + const size_t mask = movemask(vmask); +#endif + dist = tNear; + return mask; + } + +#if defined(__AVX__) + + template<> + __forceinline size_t intersectNode<8,8>(const typename BVH8::AABBNode* node, const TravRay<8,8,false>& ray, vfloat8& dist) + { +#if defined(__AVX2__) + const vfloat8 tNearX = msub(vfloat8::load((float*)((const char*)&node->lower_x+ray.nearX)), ray.rdir.x, ray.org_rdir.x); + const vfloat8 tNearY = msub(vfloat8::load((float*)((const char*)&node->lower_x+ray.nearY)), ray.rdir.y, ray.org_rdir.y); + const vfloat8 tNearZ = msub(vfloat8::load((float*)((const char*)&node->lower_x+ray.nearZ)), ray.rdir.z, ray.org_rdir.z); + const vfloat8 tFarX = msub(vfloat8::load((float*)((const char*)&node->lower_x+ray.farX )), ray.rdir.x, ray.org_rdir.x); + const vfloat8 tFarY = 
msub(vfloat8::load((float*)((const char*)&node->lower_x+ray.farY )), ray.rdir.y, ray.org_rdir.y); + const vfloat8 tFarZ = msub(vfloat8::load((float*)((const char*)&node->lower_x+ray.farZ )), ray.rdir.z, ray.org_rdir.z); +#else + const vfloat8 tNearX = (vfloat8::load((float*)((const char*)&node->lower_x+ray.nearX)) - ray.org.x) * ray.rdir.x; + const vfloat8 tNearY = (vfloat8::load((float*)((const char*)&node->lower_x+ray.nearY)) - ray.org.y) * ray.rdir.y; + const vfloat8 tNearZ = (vfloat8::load((float*)((const char*)&node->lower_x+ray.nearZ)) - ray.org.z) * ray.rdir.z; + const vfloat8 tFarX = (vfloat8::load((float*)((const char*)&node->lower_x+ray.farX )) - ray.org.x) * ray.rdir.x; + const vfloat8 tFarY = (vfloat8::load((float*)((const char*)&node->lower_x+ray.farY )) - ray.org.y) * ray.rdir.y; + const vfloat8 tFarZ = (vfloat8::load((float*)((const char*)&node->lower_x+ray.farZ )) - ray.org.z) * ray.rdir.z; +#endif + +#if defined(__AVX2__) && !defined(__AVX512F__) // HSW + const vfloat8 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = asInt(tNear) > asInt(tFar); + const size_t mask = movemask(vmask) ^ ((1<<8)-1); +#elif defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vfloat8 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = asInt(tNear) <= asInt(tFar); + const size_t mask = movemask(vmask); +#else + const vfloat8 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = tNear <= tFar; + const size_t mask = movemask(vmask); +#endif + dist = tNear; + return mask; + } + +#endif + +#if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL + + template<> + __forceinline size_t intersectNode<4,16>(const typename BVH4::AABBNode* node, const TravRay<4,16,false>& ray, vfloat16& dist) + { + const vfloat16 tNearX = msub(vfloat16(*(vfloat4*)((const char*)&node->lower_x+ray.nearX)), ray.rdir.x, ray.org_rdir.x); + const vfloat16 tNearY = msub(vfloat16(*(vfloat4*)((const char*)&node->lower_x+ray.nearY)), ray.rdir.y, ray.org_rdir.y); + const vfloat16 tNearZ = msub(vfloat16(*(vfloat4*)((const char*)&node->lower_x+ray.nearZ)), ray.rdir.z, ray.org_rdir.z); + const vfloat16 tFarX = msub(vfloat16(*(vfloat4*)((const char*)&node->lower_x+ray.farX )), ray.rdir.x, ray.org_rdir.x); + const vfloat16 tFarY = msub(vfloat16(*(vfloat4*)((const char*)&node->lower_x+ray.farY )), ray.rdir.y, ray.org_rdir.y); + const vfloat16 tFarZ = msub(vfloat16(*(vfloat4*)((const char*)&node->lower_x+ray.farZ )), ray.rdir.z, ray.org_rdir.z); + const vfloat16 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat16 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool16 vmask = le(vbool16(0xf),tNear,tFar); + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + template<> + __forceinline size_t intersectNode<8,16>(const typename BVH8::AABBNode* node, const TravRay<8,16,false>& ray, vfloat16& dist) + { + const vllong8 invalid((size_t)BVH8::emptyNode); + const vboold8 m_valid(invalid != vllong8::loadu(node->children)); + const vfloat16 bminmaxX = permute(vfloat16::load((const float*)&node->lower_x), ray.permX); + const vfloat16 bminmaxY = permute(vfloat16::load((const float*)&node->lower_y), ray.permY); + const vfloat16 bminmaxZ = permute(vfloat16::load((const float*)&node->lower_z), ray.permZ); + const vfloat16 tNearFarX = msub(bminmaxX, ray.rdir.x, ray.org_rdir.x); + const 
vfloat16 tNearFarY = msub(bminmaxY, ray.rdir.y, ray.org_rdir.y); + const vfloat16 tNearFarZ = msub(bminmaxZ, ray.rdir.z, ray.org_rdir.z); + const vfloat16 tNear = max(tNearFarX, tNearFarY, tNearFarZ, ray.tnear); + const vfloat16 tFar = min(tNearFarX, tNearFarY, tNearFarZ, ray.tfar); + const vbool16 vmask = le(vboolf16(m_valid),tNear,align_shift_right<8>(tFar, tFar)); + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + +#endif + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNodeRobust(const typename BVHN::AABBNode* node, const TravRay& ray, vfloat& dist) + { + const vfloat tNearX = (vfloat::load((float*)((const char*)&node->lower_x+ray.nearX)) - ray.org.x) * ray.rdir_near.x; + const vfloat tNearY = (vfloat::load((float*)((const char*)&node->lower_x+ray.nearY)) - ray.org.y) * ray.rdir_near.y; + const vfloat tNearZ = (vfloat::load((float*)((const char*)&node->lower_x+ray.nearZ)) - ray.org.z) * ray.rdir_near.z; + const vfloat tFarX = (vfloat::load((float*)((const char*)&node->lower_x+ray.farX )) - ray.org.x) * ray.rdir_far.x; + const vfloat tFarY = (vfloat::load((float*)((const char*)&node->lower_x+ray.farY )) - ray.org.y) * ray.rdir_far.y; + const vfloat tFarZ = (vfloat::load((float*)((const char*)&node->lower_x+ray.farZ )) - ray.org.z) * ray.rdir_far.z; + const vfloat tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool vmask = tNear <= tFar; + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + +#if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL + + template<> + __forceinline size_t intersectNodeRobust<4,16>(const typename BVHN<4>::AABBNode* node, const TravRay<4,16,true>& ray, vfloat<16>& dist) + { + const vfloat16 tNearX = (vfloat16(*(vfloat<4>*)((const char*)&node->lower_x+ray.nearX)) - ray.org.x) * ray.rdir_near.x; + const vfloat16 tNearY = (vfloat16(*(vfloat<4>*)((const char*)&node->lower_x+ray.nearY)) - ray.org.y) * ray.rdir_near.y; + const vfloat16 tNearZ = (vfloat16(*(vfloat<4>*)((const char*)&node->lower_x+ray.nearZ)) - ray.org.z) * ray.rdir_near.z; + const vfloat16 tFarX = (vfloat16(*(vfloat<4>*)((const char*)&node->lower_x+ray.farX )) - ray.org.x) * ray.rdir_far.x; + const vfloat16 tFarY = (vfloat16(*(vfloat<4>*)((const char*)&node->lower_x+ray.farY )) - ray.org.y) * ray.rdir_far.y; + const vfloat16 tFarZ = (vfloat16(*(vfloat<4>*)((const char*)&node->lower_x+ray.farZ )) - ray.org.z) * ray.rdir_far.z; + const vfloat16 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat16 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool16 vmask = le((1 << 4)-1,tNear,tFar); + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + template<> + __forceinline size_t intersectNodeRobust<8,16>(const typename BVHN<8>::AABBNode* node, const TravRay<8,16,true>& ray, vfloat<16>& dist) + { + const vfloat16 tNearX = (vfloat16(*(vfloat<8>*)((const char*)&node->lower_x+ray.nearX)) - ray.org.x) * ray.rdir_near.x; + const vfloat16 tNearY = (vfloat16(*(vfloat<8>*)((const char*)&node->lower_x+ray.nearY)) - ray.org.y) * ray.rdir_near.y; + const vfloat16 tNearZ = (vfloat16(*(vfloat<8>*)((const char*)&node->lower_x+ray.nearZ)) - ray.org.z) * ray.rdir_near.z; + const vfloat16 tFarX = (vfloat16(*(vfloat<8>*)((const char*)&node->lower_x+ray.farX )) - ray.org.x) * 
ray.rdir_far.x; + const vfloat16 tFarY = (vfloat16(*(vfloat<8>*)((const char*)&node->lower_x+ray.farY )) - ray.org.y) * ray.rdir_far.y; + const vfloat16 tFarZ = (vfloat16(*(vfloat<8>*)((const char*)&node->lower_x+ray.farZ )) - ray.org.z) * ray.rdir_far.z; + const vfloat16 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat16 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool16 vmask = le((1 << 8)-1,tNear,tFar); + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + +#endif + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNodeMB intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode(const typename BVHN::AABBNodeMB* node, const TravRay& ray, const float time, vfloat& dist) + { + const vfloat* pNearX = (const vfloat*)((const char*)&node->lower_x+ray.nearX); + const vfloat* pNearY = (const vfloat*)((const char*)&node->lower_x+ray.nearY); + const vfloat* pNearZ = (const vfloat*)((const char*)&node->lower_x+ray.nearZ); + const vfloat* pFarX = (const vfloat*)((const char*)&node->lower_x+ray.farX); + const vfloat* pFarY = (const vfloat*)((const char*)&node->lower_x+ray.farY); + const vfloat* pFarZ = (const vfloat*)((const char*)&node->lower_x+ray.farZ); +#if defined(__AVX2__) + const vfloat tNearX = msub(madd(time,pNearX[6],vfloat(pNearX[0])), ray.rdir.x, ray.org_rdir.x); + const vfloat tNearY = msub(madd(time,pNearY[6],vfloat(pNearY[0])), ray.rdir.y, ray.org_rdir.y); + const vfloat tNearZ = msub(madd(time,pNearZ[6],vfloat(pNearZ[0])), ray.rdir.z, ray.org_rdir.z); + const vfloat tFarX = msub(madd(time,pFarX [6],vfloat(pFarX [0])), ray.rdir.x, ray.org_rdir.x); + const vfloat tFarY = msub(madd(time,pFarY [6],vfloat(pFarY [0])), ray.rdir.y, ray.org_rdir.y); + const vfloat tFarZ = msub(madd(time,pFarZ [6],vfloat(pFarZ [0])), ray.rdir.z, ray.org_rdir.z); +#else + const vfloat tNearX = (madd(time,pNearX[6],vfloat(pNearX[0])) - ray.org.x) * ray.rdir.x; + const vfloat tNearY = (madd(time,pNearY[6],vfloat(pNearY[0])) - ray.org.y) * ray.rdir.y; + const vfloat tNearZ = (madd(time,pNearZ[6],vfloat(pNearZ[0])) - ray.org.z) * ray.rdir.z; + const vfloat tFarX = (madd(time,pFarX [6],vfloat(pFarX [0])) - ray.org.x) * ray.rdir.x; + const vfloat tFarY = (madd(time,pFarY [6],vfloat(pFarY [0])) - ray.org.y) * ray.rdir.y; + const vfloat tFarZ = (madd(time,pFarZ [6],vfloat(pFarZ [0])) - ray.org.z) * ray.rdir.z; +#endif +#if defined(__AVX2__) && !defined(__AVX512F__) // HSW + const vfloat tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool vmask = asInt(tNear) > asInt(tFar); + const size_t mask = movemask(vmask) ^ ((1< tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool vmask = asInt(tNear) <= asInt(tFar); + const size_t mask = movemask(vmask); +#else + const vfloat tNear = max(ray.tnear,tNearX,tNearY,tNearZ); + const vfloat tFar = min(ray.tfar, tFarX ,tFarY ,tFarZ ); + const vbool vmask = tNear <= tFar; + const size_t mask = movemask(vmask); +#endif + dist = tNear; + return mask; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNodeMB intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNodeRobust(const typename BVHN::AABBNodeMB* node, const TravRay& 
ray, const float time, vfloat& dist) + { + const vfloat* pNearX = (const vfloat*)((const char*)&node->lower_x+ray.nearX); + const vfloat* pNearY = (const vfloat*)((const char*)&node->lower_x+ray.nearY); + const vfloat* pNearZ = (const vfloat*)((const char*)&node->lower_x+ray.nearZ); + const vfloat tNearX = (madd(time,pNearX[6],vfloat(pNearX[0])) - ray.org.x) * ray.rdir_near.x; + const vfloat tNearY = (madd(time,pNearY[6],vfloat(pNearY[0])) - ray.org.y) * ray.rdir_near.y; + const vfloat tNearZ = (madd(time,pNearZ[6],vfloat(pNearZ[0])) - ray.org.z) * ray.rdir_near.z; + const vfloat tNear = max(ray.tnear,tNearX,tNearY,tNearZ); + const vfloat* pFarX = (const vfloat*)((const char*)&node->lower_x+ray.farX); + const vfloat* pFarY = (const vfloat*)((const char*)&node->lower_x+ray.farY); + const vfloat* pFarZ = (const vfloat*)((const char*)&node->lower_x+ray.farZ); + const vfloat tFarX = (madd(time,pFarX[6],vfloat(pFarX[0])) - ray.org.x) * ray.rdir_far.x; + const vfloat tFarY = (madd(time,pFarY[6],vfloat(pFarY[0])) - ray.org.y) * ray.rdir_far.y; + const vfloat tFarZ = (madd(time,pFarZ[6],vfloat(pFarZ[0])) - ray.org.z) * ray.rdir_far.z; + const vfloat tFar = min(ray.tfar,tFarX,tFarY,tFarZ); + const size_t mask = movemask(tNear <= tFar); + dist = tNear; + return mask; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNodeMB4D intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNodeMB4D(const typename BVHN::NodeRef ref, const TravRay& ray, const float time, vfloat& dist) + { + const typename BVHN::AABBNodeMB* node = ref.getAABBNodeMB(); + + const vfloat* pNearX = (const vfloat*)((const char*)&node->lower_x+ray.nearX); + const vfloat* pNearY = (const vfloat*)((const char*)&node->lower_x+ray.nearY); + const vfloat* pNearZ = (const vfloat*)((const char*)&node->lower_x+ray.nearZ); + const vfloat* pFarX = (const vfloat*)((const char*)&node->lower_x+ray.farX); + const vfloat* pFarY = (const vfloat*)((const char*)&node->lower_x+ray.farY); + const vfloat* pFarZ = (const vfloat*)((const char*)&node->lower_x+ray.farZ); +#if defined (__AVX2__) + const vfloat tNearX = msub(madd(time,pNearX[6],vfloat(pNearX[0])), ray.rdir.x, ray.org_rdir.x); + const vfloat tNearY = msub(madd(time,pNearY[6],vfloat(pNearY[0])), ray.rdir.y, ray.org_rdir.y); + const vfloat tNearZ = msub(madd(time,pNearZ[6],vfloat(pNearZ[0])), ray.rdir.z, ray.org_rdir.z); + const vfloat tFarX = msub(madd(time,pFarX [6],vfloat(pFarX [0])), ray.rdir.x, ray.org_rdir.x); + const vfloat tFarY = msub(madd(time,pFarY [6],vfloat(pFarY [0])), ray.rdir.y, ray.org_rdir.y); + const vfloat tFarZ = msub(madd(time,pFarZ [6],vfloat(pFarZ [0])), ray.rdir.z, ray.org_rdir.z); +#else + const vfloat tNearX = (madd(time,pNearX[6],vfloat(pNearX[0])) - ray.org.x) * ray.rdir.x; + const vfloat tNearY = (madd(time,pNearY[6],vfloat(pNearY[0])) - ray.org.y) * ray.rdir.y; + const vfloat tNearZ = (madd(time,pNearZ[6],vfloat(pNearZ[0])) - ray.org.z) * ray.rdir.z; + const vfloat tFarX = (madd(time,pFarX [6],vfloat(pFarX [0])) - ray.org.x) * ray.rdir.x; + const vfloat tFarY = (madd(time,pFarY [6],vfloat(pFarY [0])) - ray.org.y) * ray.rdir.y; + const vfloat tFarZ = (madd(time,pFarZ [6],vfloat(pFarZ [0])) - ray.org.z) * ray.rdir.z; +#endif +#if defined(__AVX2__) && !defined(__AVX512F__) + const vfloat tNear = maxi(maxi(tNearX,tNearY),maxi(tNearZ,ray.tnear)); + const vfloat tFar = mini(mini(tFarX ,tFarY ),mini(tFarZ ,ray.tfar )); +#else 
+ const vfloat tNear = max(ray.tnear,tNearX,tNearY,tNearZ); + const vfloat tFar = min(ray.tfar, tFarX ,tFarY ,tFarZ ); +#endif + vbool vmask = tNear <= tFar; + if (unlikely(ref.isAABBNodeMB4D())) { + const typename BVHN::AABBNodeMB4D* node1 = (const typename BVHN::AABBNodeMB4D*) node; + vmask &= (node1->lower_t <= time) & (time < node1->upper_t); + } + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNodeMB4D intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNodeMB4DRobust(const typename BVHN::NodeRef ref, const TravRay& ray, const float time, vfloat& dist) + { + const typename BVHN::AABBNodeMB* node = ref.getAABBNodeMB(); + + const vfloat* pNearX = (const vfloat*)((const char*)&node->lower_x+ray.nearX); + const vfloat* pNearY = (const vfloat*)((const char*)&node->lower_x+ray.nearY); + const vfloat* pNearZ = (const vfloat*)((const char*)&node->lower_x+ray.nearZ); + const vfloat tNearX = (madd(time,pNearX[6],vfloat(pNearX[0])) - ray.org.x) * ray.rdir_near.x; + const vfloat tNearY = (madd(time,pNearY[6],vfloat(pNearY[0])) - ray.org.y) * ray.rdir_near.y; + const vfloat tNearZ = (madd(time,pNearZ[6],vfloat(pNearZ[0])) - ray.org.z) * ray.rdir_near.z; + const vfloat tNear = max(ray.tnear,tNearX,tNearY,tNearZ); + const vfloat* pFarX = (const vfloat*)((const char*)&node->lower_x+ray.farX); + const vfloat* pFarY = (const vfloat*)((const char*)&node->lower_x+ray.farY); + const vfloat* pFarZ = (const vfloat*)((const char*)&node->lower_x+ray.farZ); + const vfloat tFarX = (madd(time,pFarX[6],vfloat(pFarX[0])) - ray.org.x) * ray.rdir_far.x; + const vfloat tFarY = (madd(time,pFarY[6],vfloat(pFarY[0])) - ray.org.y) * ray.rdir_far.y; + const vfloat tFarZ = (madd(time,pFarZ[6],vfloat(pFarZ[0])) - ray.org.z) * ray.rdir_far.z; + const vfloat tFar = min(ray.tfar,tFarX,tFarY,tFarZ); + vbool vmask = tNear <= tFar; + if (unlikely(ref.isAABBNodeMB4D())) { + const typename BVHN::AABBNodeMB4D* node1 = (const typename BVHN::AABBNodeMB4D*) node; + vmask &= (node1->lower_t <= time) & (time < node1->upper_t); + } + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast QuantizedBaseNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode(const typename BVHN::QuantizedBaseNode* node, const TravRay& ray, vfloat& dist); + + template<> + __forceinline size_t intersectNode<4,4>(const typename BVH4::QuantizedBaseNode* node, const TravRay<4,4,false>& ray, vfloat4& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat4 start_x(node->start.x); + const vfloat4 scale_x(node->scale.x); + const vfloat4 lower_x = madd(node->dequantize<4>(ray.nearX >> 2),scale_x,start_x); + const vfloat4 upper_x = madd(node->dequantize<4>(ray.farX >> 2),scale_x,start_x); + const vfloat4 start_y(node->start.y); + const vfloat4 scale_y(node->scale.y); + const vfloat4 lower_y = madd(node->dequantize<4>(ray.nearY >> 2),scale_y,start_y); + const vfloat4 upper_y = madd(node->dequantize<4>(ray.farY >> 2),scale_y,start_y); + const vfloat4 start_z(node->start.z); + const vfloat4 scale_z(node->scale.z); + const vfloat4 lower_z = madd(node->dequantize<4>(ray.nearZ >> 2),scale_z,start_z); + const 
vfloat4 upper_z = madd(node->dequantize<4>(ray.farZ >> 2),scale_z,start_z); + +#if defined(__AVX2__) + const vfloat4 tNearX = msub(lower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat4 tNearY = msub(lower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat4 tNearZ = msub(lower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat4 tFarX = msub(upper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat4 tFarY = msub(upper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat4 tFarZ = msub(upper_z, ray.rdir.z, ray.org_rdir.z); +#else + const vfloat4 tNearX = (lower_x - ray.org.x) * ray.rdir.x; + const vfloat4 tNearY = (lower_y - ray.org.y) * ray.rdir.y; + const vfloat4 tNearZ = (lower_z - ray.org.z) * ray.rdir.z; + const vfloat4 tFarX = (upper_x - ray.org.x) * ray.rdir.x; + const vfloat4 tFarY = (upper_y - ray.org.y) * ray.rdir.y; + const vfloat4 tFarZ = (upper_z - ray.org.z) * ray.rdir.z; +#endif + +#if defined(__SSE4_1__) && !defined(__AVX512F__) // up to HSW + const vfloat4 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = asInt(tNear) > asInt(tFar); + const size_t mask = movemask(vmask) ^ ((1<<4)-1); +#elif defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vfloat4 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = asInt(tNear) <= asInt(tFar); + const size_t mask = movemask(vmask); +#else + const vfloat4 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = tNear <= tFar; + const size_t mask = movemask(vmask); +#endif + dist = tNear; + return mask & mvalid; + } + + template<> + __forceinline size_t intersectNode<4,4>(const typename BVH4::QuantizedBaseNode* node, const TravRay<4,4,true>& ray, vfloat4& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat4 start_x(node->start.x); + const vfloat4 scale_x(node->scale.x); + const vfloat4 lower_x = madd(node->dequantize<4>(ray.nearX >> 2),scale_x,start_x); + const vfloat4 upper_x = madd(node->dequantize<4>(ray.farX >> 2),scale_x,start_x); + const vfloat4 start_y(node->start.y); + const vfloat4 scale_y(node->scale.y); + const vfloat4 lower_y = madd(node->dequantize<4>(ray.nearY >> 2),scale_y,start_y); + const vfloat4 upper_y = madd(node->dequantize<4>(ray.farY >> 2),scale_y,start_y); + const vfloat4 start_z(node->start.z); + const vfloat4 scale_z(node->scale.z); + const vfloat4 lower_z = madd(node->dequantize<4>(ray.nearZ >> 2),scale_z,start_z); + const vfloat4 upper_z = madd(node->dequantize<4>(ray.farZ >> 2),scale_z,start_z); + + const vfloat4 tNearX = (lower_x - ray.org.x) * ray.rdir_near.x; + const vfloat4 tNearY = (lower_y - ray.org.y) * ray.rdir_near.y; + const vfloat4 tNearZ = (lower_z - ray.org.z) * ray.rdir_near.z; + const vfloat4 tFarX = (upper_x - ray.org.x) * ray.rdir_far.x; + const vfloat4 tFarY = (upper_y - ray.org.y) * ray.rdir_far.y; + const vfloat4 tFarZ = (upper_z - ray.org.z) * ray.rdir_far.z; + + const vfloat4 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat4 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool4 vmask = tNear <= tFar; + const size_t mask = movemask(vmask); + dist = tNear; + return mask & mvalid; + } + + +#if defined(__AVX__) + + template<> + __forceinline size_t intersectNode<8,8>(const typename BVH8::QuantizedBaseNode* node, const TravRay<8,8,false>& ray, vfloat8& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat8 
start_x(node->start.x); + const vfloat8 scale_x(node->scale.x); + const vfloat8 lower_x = madd(node->dequantize<8>(ray.nearX >> 2),scale_x,start_x); + const vfloat8 upper_x = madd(node->dequantize<8>(ray.farX >> 2),scale_x,start_x); + const vfloat8 start_y(node->start.y); + const vfloat8 scale_y(node->scale.y); + const vfloat8 lower_y = madd(node->dequantize<8>(ray.nearY >> 2),scale_y,start_y); + const vfloat8 upper_y = madd(node->dequantize<8>(ray.farY >> 2),scale_y,start_y); + const vfloat8 start_z(node->start.z); + const vfloat8 scale_z(node->scale.z); + const vfloat8 lower_z = madd(node->dequantize<8>(ray.nearZ >> 2),scale_z,start_z); + const vfloat8 upper_z = madd(node->dequantize<8>(ray.farZ >> 2),scale_z,start_z); + +#if defined(__AVX2__) + const vfloat8 tNearX = msub(lower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat8 tNearY = msub(lower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat8 tNearZ = msub(lower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat8 tFarX = msub(upper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat8 tFarY = msub(upper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat8 tFarZ = msub(upper_z, ray.rdir.z, ray.org_rdir.z); +#else + const vfloat8 tNearX = (lower_x - ray.org.x) * ray.rdir.x; + const vfloat8 tNearY = (lower_y - ray.org.y) * ray.rdir.y; + const vfloat8 tNearZ = (lower_z - ray.org.z) * ray.rdir.z; + const vfloat8 tFarX = (upper_x - ray.org.x) * ray.rdir.x; + const vfloat8 tFarY = (upper_y - ray.org.y) * ray.rdir.y; + const vfloat8 tFarZ = (upper_z - ray.org.z) * ray.rdir.z; +#endif + +#if defined(__AVX2__) && !defined(__AVX512F__) // HSW + const vfloat8 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = asInt(tNear) > asInt(tFar); + const size_t mask = movemask(vmask) ^ ((1<<8)-1); +#elif defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vfloat8 tNear = maxi(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = mini(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = asInt(tNear) <= asInt(tFar); + const size_t mask = movemask(vmask); +#else + const vfloat8 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = tNear <= tFar; + const size_t mask = movemask(vmask); +#endif + dist = tNear; + return mask & mvalid; + } + + template<> + __forceinline size_t intersectNode<8,8>(const typename BVH8::QuantizedBaseNode* node, const TravRay<8,8,true>& ray, vfloat8& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat8 start_x(node->start.x); + const vfloat8 scale_x(node->scale.x); + const vfloat8 lower_x = madd(node->dequantize<8>(ray.nearX >> 2),scale_x,start_x); + const vfloat8 upper_x = madd(node->dequantize<8>(ray.farX >> 2),scale_x,start_x); + const vfloat8 start_y(node->start.y); + const vfloat8 scale_y(node->scale.y); + const vfloat8 lower_y = madd(node->dequantize<8>(ray.nearY >> 2),scale_y,start_y); + const vfloat8 upper_y = madd(node->dequantize<8>(ray.farY >> 2),scale_y,start_y); + const vfloat8 start_z(node->start.z); + const vfloat8 scale_z(node->scale.z); + const vfloat8 lower_z = madd(node->dequantize<8>(ray.nearZ >> 2),scale_z,start_z); + const vfloat8 upper_z = madd(node->dequantize<8>(ray.farZ >> 2),scale_z,start_z); + + const vfloat8 tNearX = (lower_x - ray.org.x) * ray.rdir_near.x; + const vfloat8 tNearY = (lower_y - ray.org.y) * ray.rdir_near.y; + const vfloat8 tNearZ = (lower_z - ray.org.z) * ray.rdir_near.z; + const vfloat8 tFarX = (upper_x - ray.org.x) * 
ray.rdir_far.x; + const vfloat8 tFarY = (upper_y - ray.org.y) * ray.rdir_far.y; + const vfloat8 tFarZ = (upper_z - ray.org.z) * ray.rdir_far.z; + + const vfloat8 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat8 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool8 vmask = tNear <= tFar; + const size_t mask = movemask(vmask); + + dist = tNear; + return mask & mvalid; + } + + +#endif + +#if defined(__AVX512F__) && !defined(__AVX512VL__) // KNL + + template<> + __forceinline size_t intersectNode<4,16>(const typename BVH4::QuantizedBaseNode* node, const TravRay<4,16,false>& ray, vfloat16& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat16 start_x(node->start.x); + const vfloat16 scale_x(node->scale.x); + const vfloat16 lower_x = madd(vfloat16(node->dequantize<4>(ray.nearX >> 2)),scale_x,start_x); + const vfloat16 upper_x = madd(vfloat16(node->dequantize<4>(ray.farX >> 2)),scale_x,start_x); + const vfloat16 start_y(node->start.y); + const vfloat16 scale_y(node->scale.y); + const vfloat16 lower_y = madd(vfloat16(node->dequantize<4>(ray.nearY >> 2)),scale_y,start_y); + const vfloat16 upper_y = madd(vfloat16(node->dequantize<4>(ray.farY >> 2)),scale_y,start_y); + const vfloat16 start_z(node->start.z); + const vfloat16 scale_z(node->scale.z); + const vfloat16 lower_z = madd(vfloat16(node->dequantize<4>(ray.nearZ >> 2)),scale_z,start_z); + const vfloat16 upper_z = madd(vfloat16(node->dequantize<4>(ray.farZ >> 2)),scale_z,start_z); + + const vfloat16 tNearX = msub(lower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat16 tNearY = msub(lower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat16 tNearZ = msub(lower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat16 tFarX = msub(upper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat16 tFarY = msub(upper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat16 tFarZ = msub(upper_z, ray.rdir.z, ray.org_rdir.z); + const vfloat16 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat16 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool16 vmask = le(vbool16(0xf),tNear,tFar); + const size_t mask = movemask(vmask) & mvalid; + dist = tNear; + return mask; + } + + template<> + __forceinline size_t intersectNode<4,16>(const typename BVH4::QuantizedBaseNode* node, const TravRay<4,16,true>& ray, vfloat16& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat16 start_x(node->start.x); + const vfloat16 scale_x(node->scale.x); + const vfloat16 lower_x = madd(vfloat16(node->dequantize<4>(ray.nearX >> 2)),scale_x,start_x); + const vfloat16 upper_x = madd(vfloat16(node->dequantize<4>(ray.farX >> 2)),scale_x,start_x); + const vfloat16 start_y(node->start.y); + const vfloat16 scale_y(node->scale.y); + const vfloat16 lower_y = madd(vfloat16(node->dequantize<4>(ray.nearY >> 2)),scale_y,start_y); + const vfloat16 upper_y = madd(vfloat16(node->dequantize<4>(ray.farY >> 2)),scale_y,start_y); + const vfloat16 start_z(node->start.z); + const vfloat16 scale_z(node->scale.z); + const vfloat16 lower_z = madd(vfloat16(node->dequantize<4>(ray.nearZ >> 2)),scale_z,start_z); + const vfloat16 upper_z = madd(vfloat16(node->dequantize<4>(ray.farZ >> 2)),scale_z,start_z); + + const vfloat16 tNearX = (lower_x - ray.org.x) * ray.rdir_near.x; + const vfloat16 tNearY = (lower_y - ray.org.y) * ray.rdir_near.y; + const vfloat16 tNearZ = (lower_z - ray.org.z) * ray.rdir_near.z; + const vfloat16 tFarX = (upper_x - ray.org.x) * ray.rdir_far.x; + const vfloat16 tFarY = (upper_y - ray.org.y) * ray.rdir_far.y; + const vfloat16 tFarZ = (upper_z 
- ray.org.z) * ray.rdir_far.z; + + const vfloat16 tNear = max(tNearX,tNearY,tNearZ,ray.tnear); + const vfloat16 tFar = min(tFarX ,tFarY ,tFarZ ,ray.tfar); + const vbool16 vmask = le(vbool16(0xf),tNear,tFar); + const size_t mask = movemask(vmask) & mvalid; + dist = tNear; + return mask; + } + + template<> + __forceinline size_t intersectNode<8,16>(const typename BVH8::QuantizedBaseNode* node, const TravRay<8,16,false>& ray, vfloat16& dist) + { + const vbool16 m_valid(node->validMask16()); + const vfloat16 bminmaxX = node->dequantizeLowerUpperX(ray.permX); + const vfloat16 bminmaxY = node->dequantizeLowerUpperY(ray.permY); + const vfloat16 bminmaxZ = node->dequantizeLowerUpperZ(ray.permZ); + const vfloat16 tNearFarX = msub(bminmaxX, ray.rdir.x, ray.org_rdir.x); + const vfloat16 tNearFarY = msub(bminmaxY, ray.rdir.y, ray.org_rdir.y); + const vfloat16 tNearFarZ = msub(bminmaxZ, ray.rdir.z, ray.org_rdir.z); + const vfloat16 tNear = max(tNearFarX, tNearFarY, tNearFarZ, ray.tnear); + const vfloat16 tFar = min(tNearFarX, tNearFarY, tNearFarZ, ray.tfar); + const vbool16 vmask = le(m_valid,tNear,align_shift_right<8>(tFar, tFar)); + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + template<> + __forceinline size_t intersectNode<8,16>(const typename BVH8::QuantizedBaseNode* node, const TravRay<8,16,true>& ray, vfloat16& dist) + { + const vbool16 m_valid(node->validMask16()); + const vfloat16 bminmaxX = node->dequantizeLowerUpperX(ray.permX); + const vfloat16 bminmaxY = node->dequantizeLowerUpperY(ray.permY); + const vfloat16 bminmaxZ = node->dequantizeLowerUpperZ(ray.permZ); + const vfloat16 tNearFarX = (bminmaxX - ray.org.x) * ray.rdir_far.x; // FIXME: this is not conservative !!!!!!!!! + const vfloat16 tNearFarY = (bminmaxY - ray.org.y) * ray.rdir_far.y; + const vfloat16 tNearFarZ = (bminmaxZ - ray.org.z) * ray.rdir_far.z; + const vfloat16 tNear = max(tNearFarX, tNearFarY, tNearFarZ, ray.tnear); + const vfloat16 tFar = min(tNearFarX, tNearFarY, tNearFarZ, ray.tfar); + const vbool16 vmask = le(m_valid,tNear,align_shift_right<8>(tFar, tFar)); + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + +#endif + + + template + __forceinline size_t intersectNode(const typename BVHN::QuantizedBaseNodeMB* node, const TravRay& ray, const float time, vfloat& dist) + { + const vboolf mvalid = node->validMask(); + const vfloat lower_x = node->dequantizeLowerX(time); + const vfloat upper_x = node->dequantizeUpperX(time); + const vfloat lower_y = node->dequantizeLowerY(time); + const vfloat upper_y = node->dequantizeUpperY(time); + const vfloat lower_z = node->dequantizeLowerZ(time); + const vfloat upper_z = node->dequantizeUpperZ(time); +#if defined(__AVX2__) + const vfloat tNearX = msub(lower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat tNearY = msub(lower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat tNearZ = msub(lower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat tFarX = msub(upper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat tFarY = msub(upper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat tFarZ = msub(upper_z, ray.rdir.z, ray.org_rdir.z); +#else + const vfloat tNearX = (lower_x - ray.org.x) * ray.rdir.x; + const vfloat tNearY = (lower_y - ray.org.y) * ray.rdir.y; + const vfloat tNearZ = (lower_z - ray.org.z) * ray.rdir.z; + const vfloat tFarX = (upper_x - ray.org.x) * ray.rdir.x; + const vfloat tFarY = (upper_y - ray.org.y) * ray.rdir.y; + const vfloat tFarZ = (upper_z - ray.org.z) * ray.rdir.z; +#endif + + const vfloat tminX = mini(tNearX,tFarX); + 
const vfloat tmaxX = maxi(tNearX,tFarX); + const vfloat tminY = mini(tNearY,tFarY); + const vfloat tmaxY = maxi(tNearY,tFarY); + const vfloat tminZ = mini(tNearZ,tFarZ); + const vfloat tmaxZ = maxi(tNearZ,tFarZ); + const vfloat tNear = maxi(tminX,tminY,tminZ,ray.tnear); + const vfloat tFar = mini(tmaxX,tmaxY,tmaxZ,ray.tfar); +#if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vbool vmask = le(mvalid,asInt(tNear),asInt(tFar)); +#else + const vbool vmask = (asInt(tNear) <= asInt(tFar)) & mvalid; +#endif + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + template + __forceinline size_t intersectNode(const typename BVHN::QuantizedBaseNodeMB* node, const TravRay& ray, const float time, vfloat& dist) + { + const vboolf mvalid = node->validMask(); + const vfloat lower_x = node->dequantizeLowerX(time); + const vfloat upper_x = node->dequantizeUpperX(time); + const vfloat lower_y = node->dequantizeLowerY(time); + const vfloat upper_y = node->dequantizeUpperY(time); + const vfloat lower_z = node->dequantizeLowerZ(time); + const vfloat upper_z = node->dequantizeUpperZ(time); + const vfloat tNearX = (lower_x - ray.org.x) * ray.rdir_near.x; + const vfloat tNearY = (lower_y - ray.org.y) * ray.rdir_near.y; + const vfloat tNearZ = (lower_z - ray.org.z) * ray.rdir_near.z; + const vfloat tFarX = (upper_x - ray.org.x) * ray.rdir_far.x; + const vfloat tFarY = (upper_y - ray.org.y) * ray.rdir_far.y; + const vfloat tFarZ = (upper_z - ray.org.z) * ray.rdir_far.z; + + const vfloat tminX = mini(tNearX,tFarX); + const vfloat tmaxX = maxi(tNearX,tFarX); + const vfloat tminY = mini(tNearY,tFarY); + const vfloat tmaxY = maxi(tNearY,tFarY); + const vfloat tminZ = mini(tNearZ,tFarZ); + const vfloat tmaxZ = maxi(tNearZ,tFarZ); + const vfloat tNear = maxi(tminX,tminY,tminZ,ray.tnear); + const vfloat tFar = mini(tmaxX,tmaxY,tmaxZ,ray.tfar); +#if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vbool vmask = le(mvalid,asInt(tNear),asInt(tFar)); +#else + const vbool vmask = (asInt(tNear) <= asInt(tFar)) & mvalid; +#endif + const size_t mask = movemask(vmask); + dist = tNear; + return mask; + } + + +#if defined(__AVX512ER__) + // for KNL + template<> + __forceinline size_t intersectNode<4,16>(const typename BVHN<4>::QuantizedBaseNodeMB* node, const TravRay<4,16,false>& ray, const float time, vfloat<4>& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat16 lower_x = node->dequantizeLowerX(time); + const vfloat16 upper_x = node->dequantizeUpperX(time); + const vfloat16 lower_y = node->dequantizeLowerY(time); + const vfloat16 upper_y = node->dequantizeUpperY(time); + const vfloat16 lower_z = node->dequantizeLowerZ(time); + const vfloat16 upper_z = node->dequantizeUpperZ(time); + + const vfloat16 tNearX = msub(lower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat16 tNearY = msub(lower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat16 tNearZ = msub(lower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat16 tFarX = msub(upper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat16 tFarY = msub(upper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat16 tFarZ = msub(upper_z, ray.rdir.z, ray.org_rdir.z); + + const vfloat16 tminX = min(tNearX,tFarX); + const vfloat16 tmaxX = max(tNearX,tFarX); + const vfloat16 tminY = min(tNearY,tFarY); + const vfloat16 tmaxY = max(tNearY,tFarY); + const vfloat16 tminZ = min(tNearZ,tFarZ); + const vfloat16 tmaxZ = max(tNearZ,tFarZ); + const vfloat16 tNear = max(tminX,tminY,tminZ,ray.tnear); + const vfloat16 tFar = 
min(tmaxX,tmaxY,tmaxZ,ray.tfar ); + const vbool16 vmask = tNear <= tFar; + const size_t mask = movemask(vmask) & mvalid; + dist = extractN<4,0>(tNear); + return mask; + } + + + // for KNL + template<> + __forceinline size_t intersectNode<4,16>(const typename BVHN<4>::QuantizedBaseNodeMB* node, const TravRay<4,16,true>& ray, const float time, vfloat<4>& dist) + { + const size_t mvalid = movemask(node->validMask()); + const vfloat16 lower_x = node->dequantizeLowerX(time); + const vfloat16 upper_x = node->dequantizeUpperX(time); + const vfloat16 lower_y = node->dequantizeLowerY(time); + const vfloat16 upper_y = node->dequantizeUpperY(time); + const vfloat16 lower_z = node->dequantizeLowerZ(time); + const vfloat16 upper_z = node->dequantizeUpperZ(time); + + const vfloat16 tNearX = (lower_x - ray.org.x) * ray.rdir_near.x; + const vfloat16 tNearY = (lower_y - ray.org.y) * ray.rdir_near.y; + const vfloat16 tNearZ = (lower_z - ray.org.z) * ray.rdir_near.z; + const vfloat16 tFarX = (upper_x - ray.org.x) * ray.rdir_far.x; + const vfloat16 tFarY = (upper_y - ray.org.y) * ray.rdir_far.y; + const vfloat16 tFarZ = (upper_z - ray.org.z) * ray.rdir_far.z; + + const vfloat16 tminX = min(tNearX,tFarX); + const vfloat16 tmaxX = max(tNearX,tFarX); + const vfloat16 tminY = min(tNearY,tFarY); + const vfloat16 tmaxY = max(tNearY,tFarY); + const vfloat16 tminZ = min(tNearZ,tFarZ); + const vfloat16 tmaxZ = max(tNearZ,tFarZ); + const vfloat16 tNear = max(tminX,tminY,tminZ,ray.tnear); + const vfloat16 tFar = min(tmaxX,tmaxY,tmaxZ,ray.tfar ); + const vbool16 vmask = tNear <= tFar; + const size_t mask = movemask(vmask) & mvalid; + dist = extractN<4,0>(tNear); + return mask; + } + +#endif + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast OBBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode(const typename BVHN::OBBNode* node, const TravRay& ray, vfloat& dist) + { + const Vec3vf dir = xfmVector(node->naabb,ray.dir); + //const Vec3vf nrdir = Vec3vf(vfloat(-1.0f))/dir; + const Vec3vf nrdir = Vec3vf(vfloat(-1.0f))*rcp_safe(dir); + const Vec3vf org = xfmPoint(node->naabb,ray.org); + const Vec3vf tLowerXYZ = org * nrdir; // (Vec3fa(zero) - org) * rdir; + const Vec3vf tUpperXYZ = tLowerXYZ - nrdir; // (Vec3fa(one ) - org) * rdir; + + const vfloat tNearX = mini(tLowerXYZ.x,tUpperXYZ.x); + const vfloat tNearY = mini(tLowerXYZ.y,tUpperXYZ.y); + const vfloat tNearZ = mini(tLowerXYZ.z,tUpperXYZ.z); + const vfloat tFarX = maxi(tLowerXYZ.x,tUpperXYZ.x); + const vfloat tFarY = maxi(tLowerXYZ.y,tUpperXYZ.y); + const vfloat tFarZ = maxi(tLowerXYZ.z,tUpperXYZ.z); + vfloat tNear = max(ray.tnear, tNearX,tNearY,tNearZ); + vfloat tFar = min(ray.tfar, tFarX ,tFarY ,tFarZ ); + if (robust) { + tNear = tNear*vfloat(1.0f-3.0f*float(ulp)); + tFar = tFar *vfloat(1.0f+3.0f*float(ulp)); + } + const vbool vmask = tNear <= tFar; + dist = tNear; + return movemask(vmask); + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast OBBNodeMB intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode(const typename BVHN::OBBNodeMB* node, const TravRay& ray, const float time, vfloat& dist) + { + const AffineSpace3vf xfm = node->space0; + const Vec3vf b0_lower = zero; + const Vec3vf b0_upper = one; + const Vec3vf lower = lerp(b0_lower,node->b1.lower,vfloat(time)); + 
const Vec3vf upper = lerp(b0_upper,node->b1.upper,vfloat(time)); + + const BBox3vf bounds(lower,upper); + const Vec3vf dir = xfmVector(xfm,ray.dir); + const Vec3vf rdir = rcp_safe(dir); + const Vec3vf org = xfmPoint(xfm,ray.org); + + const Vec3vf tLowerXYZ = (bounds.lower - org) * rdir; + const Vec3vf tUpperXYZ = (bounds.upper - org) * rdir; + + const vfloat tNearX = mini(tLowerXYZ.x,tUpperXYZ.x); + const vfloat tNearY = mini(tLowerXYZ.y,tUpperXYZ.y); + const vfloat tNearZ = mini(tLowerXYZ.z,tUpperXYZ.z); + const vfloat tFarX = maxi(tLowerXYZ.x,tUpperXYZ.x); + const vfloat tFarY = maxi(tLowerXYZ.y,tUpperXYZ.y); + const vfloat tFarZ = maxi(tLowerXYZ.z,tUpperXYZ.z); + vfloat tNear = max(ray.tnear, tNearX,tNearY,tNearZ); + vfloat tFar = min(ray.tfar, tFarX ,tFarY ,tFarZ ); + if (robust) { + tNear = tNear*vfloat(1.0f-3.0f*float(ulp)); + tFar = tFar *vfloat(1.0f+3.0f*float(ulp)); + } + const vbool vmask = tNear <= tFar; + dist = tNear; + return movemask(vmask); + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Node intersectors used in point query traversal + ////////////////////////////////////////////////////////////////////////////////////// + + /*! Computes traversal information for N nodes with 1 point query */ + template + struct BVHNNodePointQuerySphere1; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeSphere(node.getAABBNode(), query, dist); + return true; + } + }; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeSphere(node.getAABBNodeMB(), query, time, dist); + return true; + } + }; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeSphereMB4D(node, query, time, dist); + return true; + } + }; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNode())) mask = pointQueryNodeSphere(node.getAABBNode(), query, dist); + else if (unlikely(node.isOBBNode())) mask = pointQueryNodeSphere(node.ungetAABBNode(), query, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNodeMB())) mask = pointQueryNodeSphere(node.getAABBNodeMB(), query, time, dist); + else if (unlikely(node.isOBBNodeMB())) mask = pointQueryNodeSphere(node.ungetAABBNodeMB(), query, time, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + if (unlikely(node.isOBBNodeMB())) mask = 
pointQueryNodeSphere(node.ungetAABBNodeMB(), query, time, dist); + else mask = pointQueryNodeSphereMB4D(node, query, time, dist); + return true; + } + }; + + template + struct BVHNNodePointQuerySphere1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeSphere((const typename BVHN::QuantizedNode*)node.quantizedNode(), query, dist); + return true; + } + }; + + template + struct BVHNQuantizedBaseNodePointQuerySphere1 + { + static __forceinline size_t pointQuery(const typename BVHN::QuantizedBaseNode* node, const TravPointQuery& query, vfloat& dist) + { + return pointQueryNodeSphere(node,query,dist); + } + + static __forceinline size_t pointQuery(const typename BVHN::QuantizedBaseNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + return pointQueryNodeSphere(node,query,time,dist); + } + }; + + /*! Computes traversal information for N nodes with 1 point query */ + template + struct BVHNNodePointQueryAABB1; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeAABB(node.getAABBNode(), query, dist); + return true; + } + }; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeAABB(node.getAABBNodeMB(), query, time, dist); + return true; + } + }; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeAABBMB4D(node, query, time, dist); + return true; + } + }; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNode())) mask = pointQueryNodeAABB(node.getAABBNode(), query, dist); + else if (unlikely(node.isOBBNode())) mask = pointQueryNodeAABB(node.ungetAABBNode(), query, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNodeMB())) mask = pointQueryNodeAABB(node.getAABBNodeMB(), query, time, dist); + else if (unlikely(node.isOBBNodeMB())) mask = pointQueryNodeAABB(node.ungetAABBNodeMB(), query, time, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + if (unlikely(node.isOBBNodeMB())) mask = pointQueryNodeAABB(node.ungetAABBNodeMB(), query, time, dist); + else mask = pointQueryNodeAABBMB4D(node, query, time, dist); + return true; + } + }; + + template + struct BVHNNodePointQueryAABB1 + { + static __forceinline bool pointQuery(const typename BVHN::NodeRef& 
node, const TravPointQuery& query, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = pointQueryNodeAABB((const typename BVHN::QuantizedNode*)node.quantizedNode(), query, dist); + return true; + } + }; + + template + struct BVHNQuantizedBaseNodePointQueryAABB1 + { + static __forceinline size_t pointQuery(const typename BVHN::QuantizedBaseNode* node, const TravPointQuery& query, vfloat& dist) + { + return pointQueryNodeAABB(node,query,dist); + } + + static __forceinline size_t pointQuery(const typename BVHN::QuantizedBaseNodeMB* node, const TravPointQuery& query, const float time, vfloat& dist) + { + return pointQueryNodeAABB(node,query,time,dist); + } + }; + + + ////////////////////////////////////////////////////////////////////////////////////// + // Node intersectors used in ray traversal + ////////////////////////////////////////////////////////////////////////////////////// + + /*! Intersects N nodes with 1 ray */ + template + struct BVHNNodeIntersector1; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNode(node.getAABBNode(), ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNodeRobust(node.getAABBNode(), ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNode(node.getAABBNodeMB(), ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNodeRobust(node.getAABBNodeMB(), ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNodeMB4D(node, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNodeMB4DRobust(node, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNode())) mask = intersectNode(node.getAABBNode(), ray, dist); + else if (unlikely(node.isOBBNode())) mask = intersectNode(node.ungetAABBNode(), ray, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNode())) mask = 
intersectNodeRobust(node.getAABBNode(), ray, dist); + else if (unlikely(node.isOBBNode())) mask = intersectNode(node.ungetAABBNode(), ray, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNodeMB())) mask = intersectNode(node.getAABBNodeMB(), ray, time, dist); + else if (unlikely(node.isOBBNodeMB())) mask = intersectNode(node.ungetAABBNodeMB(), ray, time, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (likely(node.isAABBNodeMB())) mask = intersectNodeRobust(node.getAABBNodeMB(), ray, time, dist); + else if (unlikely(node.isOBBNodeMB())) mask = intersectNode(node.ungetAABBNodeMB(), ray, time, dist); + else return false; + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + if (unlikely(node.isOBBNodeMB())) mask = intersectNode(node.ungetAABBNodeMB(), ray, time, dist); + else mask = intersectNodeMB4D(node, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + if (unlikely(node.isOBBNodeMB())) mask = intersectNode(node.ungetAABBNodeMB(), ray, time, dist); + else mask = intersectNodeMB4DRobust(node, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNode((const typename BVHN::QuantizedNode*)node.quantizedNode(), ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersector1 + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, const TravRay& ray, float time, vfloat& dist, size_t& mask) + { + if (unlikely(node.isLeaf())) return false; + mask = intersectNodeRobust((const typename BVHN::QuantizedNode*)node.quantizedNode(), ray, dist); + return true; + } + }; + + /*! 
Intersects N nodes with K rays */ + template + struct BVHNQuantizedBaseNodeIntersector1; + + template + struct BVHNQuantizedBaseNodeIntersector1 + { + static __forceinline size_t intersect(const typename BVHN::QuantizedBaseNode* node, const TravRay& ray, vfloat& dist) + { + return intersectNode(node,ray,dist); + } + + static __forceinline size_t intersect(const typename BVHN::QuantizedBaseNodeMB* node, const TravRay& ray, const float time, vfloat& dist) + { + return intersectNode(node,ray,time,dist); + } + + }; + + template + struct BVHNQuantizedBaseNodeIntersector1 + { + static __forceinline size_t intersect(const typename BVHN::QuantizedBaseNode* node, const TravRay& ray, vfloat& dist) + { + return intersectNode(node,ray,dist); + } + + static __forceinline size_t intersect(const typename BVHN::QuantizedBaseNodeMB* node, const TravRay& ray, const float time, vfloat& dist) + { + return intersectNode(node,ray,time,dist); + } + + }; + + + } +} diff --git a/thirdparty/embree/kernels/bvh/node_intersector_frustum.h b/thirdparty/embree/kernels/bvh/node_intersector_frustum.h new file mode 100644 index 000000000000..dbce46932495 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/node_intersector_frustum.h @@ -0,0 +1,253 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "node_intersector.h" + +namespace embree +{ + namespace isa + { + ////////////////////////////////////////////////////////////////////////////////////// + // Frustum structure used in hybrid and stream traversal + ////////////////////////////////////////////////////////////////////////////////////// + + /* + Optimized frustum test. We calculate t=(p-org)/dir in ray/box + intersection. We assume the rays are split by octant, thus + dir intervals are either positive or negative in each + dimension. 
+ + Case 1: dir.min >= 0 && dir.max >= 0: + t_min = (p_min - org_max) / dir_max = (p_min - org_max)*rdir_min = p_min*rdir_min - org_max*rdir_min + t_max = (p_max - org_min) / dir_min = (p_max - org_min)*rdir_max = p_max*rdir_max - org_min*rdir_max + + Case 2: dir.min < 0 && dir.max < 0: + t_min = (p_max - org_min) / dir_min = (p_max - org_min)*rdir_max = p_max*rdir_max - org_min*rdir_max + t_max = (p_min - org_max) / dir_max = (p_min - org_max)*rdir_min = p_min*rdir_min - org_max*rdir_min + */ + + template + struct Frustum; + + /* Fast variant */ + template<> + struct Frustum + { + __forceinline Frustum() {} + + template + __forceinline Frustum(const vbool& valid, const Vec3vf& org, const Vec3vf& rdir, const vfloat& ray_tnear, const vfloat& ray_tfar, int N) + { + init(valid, org, rdir, ray_tnear, ray_tfar, N); + } + + template + __forceinline void init(const vbool& valid, const Vec3vf& org, const Vec3vf& rdir, const vfloat& ray_tnear, const vfloat& ray_tfar, int N) + { + const Vec3fa reduced_min_org(reduce_min(select(valid, org.x, pos_inf)), + reduce_min(select(valid, org.y, pos_inf)), + reduce_min(select(valid, org.z, pos_inf))); + + const Vec3fa reduced_max_org(reduce_max(select(valid, org.x, neg_inf)), + reduce_max(select(valid, org.y, neg_inf)), + reduce_max(select(valid, org.z, neg_inf))); + + const Vec3fa reduced_min_rdir(reduce_min(select(valid, rdir.x, pos_inf)), + reduce_min(select(valid, rdir.y, pos_inf)), + reduce_min(select(valid, rdir.z, pos_inf))); + + const Vec3fa reduced_max_rdir(reduce_max(select(valid, rdir.x, neg_inf)), + reduce_max(select(valid, rdir.y, neg_inf)), + reduce_max(select(valid, rdir.z, neg_inf))); + + const float reduced_min_dist = reduce_min(select(valid, ray_tnear, vfloat(pos_inf))); + const float reduced_max_dist = reduce_max(select(valid, ray_tfar , vfloat(neg_inf))); + + init(reduced_min_org, reduced_max_org, reduced_min_rdir, reduced_max_rdir, reduced_min_dist, reduced_max_dist, N); + } + + __forceinline void init(const Vec3fa& reduced_min_org, + const Vec3fa& reduced_max_org, + const Vec3fa& reduced_min_rdir, + const Vec3fa& reduced_max_rdir, + float reduced_min_dist, + float reduced_max_dist, + int N) + { + const Vec3ba pos_rdir = ge_mask(reduced_min_rdir, Vec3fa(zero)); + + min_rdir = select(pos_rdir, reduced_min_rdir, reduced_max_rdir); + max_rdir = select(pos_rdir, reduced_max_rdir, reduced_min_rdir); + + min_org_rdir = min_rdir * select(pos_rdir, reduced_max_org, reduced_min_org); + max_org_rdir = max_rdir * select(pos_rdir, reduced_min_org, reduced_max_org); + + min_dist = reduced_min_dist; + max_dist = reduced_max_dist; + + nf = NearFarPrecalculations(min_rdir, N); + } + + template + __forceinline void updateMaxDist(const vfloat& ray_tfar) + { + max_dist = reduce_max(ray_tfar); + } + + NearFarPrecalculations nf; + + Vec3fa min_rdir; + Vec3fa max_rdir; + + Vec3fa min_org_rdir; + Vec3fa max_org_rdir; + + float min_dist; + float max_dist; + }; + + typedef Frustum FrustumFast; + + /* Robust variant */ + template<> + struct Frustum + { + __forceinline Frustum() {} + + template + __forceinline Frustum(const vbool& valid, const Vec3vf& org, const Vec3vf& rdir, const vfloat& ray_tnear, const vfloat& ray_tfar, int N) + { + init(valid, org, rdir, ray_tnear, ray_tfar, N); + } + + template + __forceinline void init(const vbool& valid, const Vec3vf& org, const Vec3vf& rdir, const vfloat& ray_tnear, const vfloat& ray_tfar, int N) + { + const Vec3fa reduced_min_org(reduce_min(select(valid, org.x, pos_inf)), + reduce_min(select(valid, org.y, pos_inf)), + 
reduce_min(select(valid, org.z, pos_inf))); + + const Vec3fa reduced_max_org(reduce_max(select(valid, org.x, neg_inf)), + reduce_max(select(valid, org.y, neg_inf)), + reduce_max(select(valid, org.z, neg_inf))); + + const Vec3fa reduced_min_rdir(reduce_min(select(valid, rdir.x, pos_inf)), + reduce_min(select(valid, rdir.y, pos_inf)), + reduce_min(select(valid, rdir.z, pos_inf))); + + const Vec3fa reduced_max_rdir(reduce_max(select(valid, rdir.x, neg_inf)), + reduce_max(select(valid, rdir.y, neg_inf)), + reduce_max(select(valid, rdir.z, neg_inf))); + + const float reduced_min_dist = reduce_min(select(valid, ray_tnear, vfloat(pos_inf))); + const float reduced_max_dist = reduce_max(select(valid, ray_tfar , vfloat(neg_inf))); + + init(reduced_min_org, reduced_max_org, reduced_min_rdir, reduced_max_rdir, reduced_min_dist, reduced_max_dist, N); + } + + __forceinline void init(const Vec3fa& reduced_min_org, + const Vec3fa& reduced_max_org, + const Vec3fa& reduced_min_rdir, + const Vec3fa& reduced_max_rdir, + float reduced_min_dist, + float reduced_max_dist, + int N) + { + const Vec3ba pos_rdir = ge_mask(reduced_min_rdir, Vec3fa(zero)); + min_rdir = select(pos_rdir, reduced_min_rdir, reduced_max_rdir); + max_rdir = select(pos_rdir, reduced_max_rdir, reduced_min_rdir); + + min_org = select(pos_rdir, reduced_max_org, reduced_min_org); + max_org = select(pos_rdir, reduced_min_org, reduced_max_org); + + min_dist = reduced_min_dist; + max_dist = reduced_max_dist; + + nf = NearFarPrecalculations(min_rdir, N); + } + + template + __forceinline void updateMaxDist(const vfloat& ray_tfar) + { + max_dist = reduce_max(ray_tfar); + } + + NearFarPrecalculations nf; + + Vec3fa min_rdir; + Vec3fa max_rdir; + + Vec3fa min_org; + Vec3fa max_org; + + float min_dist; + float max_dist; + }; + + typedef Frustum FrustumRobust; + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNodeFrustum(const typename BVHN::AABBNode* __restrict__ node, + const FrustumFast& frustum, vfloat& dist) + { + const vfloat bminX = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.nearX); + const vfloat bminY = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.nearY); + const vfloat bminZ = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.nearZ); + const vfloat bmaxX = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.farX); + const vfloat bmaxY = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.farY); + const vfloat bmaxZ = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.farZ); + + const vfloat fminX = msub(bminX, vfloat(frustum.min_rdir.x), vfloat(frustum.min_org_rdir.x)); + const vfloat fminY = msub(bminY, vfloat(frustum.min_rdir.y), vfloat(frustum.min_org_rdir.y)); + const vfloat fminZ = msub(bminZ, vfloat(frustum.min_rdir.z), vfloat(frustum.min_org_rdir.z)); + const vfloat fmaxX = msub(bmaxX, vfloat(frustum.max_rdir.x), vfloat(frustum.max_org_rdir.x)); + const vfloat fmaxY = msub(bmaxY, vfloat(frustum.max_rdir.y), vfloat(frustum.max_org_rdir.y)); + const vfloat fmaxZ = msub(bmaxZ, vfloat(frustum.max_rdir.z), vfloat(frustum.max_org_rdir.z)); + + const vfloat fmin = maxi(fminX, fminY, fminZ, vfloat(frustum.min_dist)); + dist = fmin; + const vfloat fmax = mini(fmaxX, fmaxY, fmaxZ, vfloat(frustum.max_dist)); + const vbool vmask_node_hit = fmin <= fmax; + size_t m_node = 
movemask(vmask_node_hit) & (((size_t)1 << N)-1); + return m_node; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNodeFrustum(const typename BVHN::AABBNode* __restrict__ node, + const FrustumRobust& frustum, vfloat& dist) + { + const vfloat bminX = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.nearX); + const vfloat bminY = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.nearY); + const vfloat bminZ = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.nearZ); + const vfloat bmaxX = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.farX); + const vfloat bmaxY = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.farY); + const vfloat bmaxZ = *(const vfloat*)((const char*)&node->lower_x + frustum.nf.farZ); + + const vfloat fminX = (bminX - vfloat(frustum.min_org.x)) * vfloat(frustum.min_rdir.x); + const vfloat fminY = (bminY - vfloat(frustum.min_org.y)) * vfloat(frustum.min_rdir.y); + const vfloat fminZ = (bminZ - vfloat(frustum.min_org.z)) * vfloat(frustum.min_rdir.z); + const vfloat fmaxX = (bmaxX - vfloat(frustum.max_org.x)) * vfloat(frustum.max_rdir.x); + const vfloat fmaxY = (bmaxY - vfloat(frustum.max_org.y)) * vfloat(frustum.max_rdir.y); + const vfloat fmaxZ = (bmaxZ - vfloat(frustum.max_org.z)) * vfloat(frustum.max_rdir.z); + + const float round_down = 1.0f-2.0f*float(ulp); // FIXME: use per instruction rounding for AVX512 + const float round_up = 1.0f+2.0f*float(ulp); + const vfloat fmin = max(fminX, fminY, fminZ, vfloat(frustum.min_dist)); + dist = fmin; + const vfloat fmax = min(fmaxX, fmaxY, fmaxZ, vfloat(frustum.max_dist)); + const vbool vmask_node_hit = (round_down*fmin <= round_up*fmax); + size_t m_node = movemask(vmask_node_hit) & (((size_t)1 << N)-1); + return m_node; + } + } +} diff --git a/thirdparty/embree/kernels/bvh/node_intersector_packet.h b/thirdparty/embree/kernels/bvh/node_intersector_packet.h new file mode 100644 index 000000000000..1cc0d47fabd1 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/node_intersector_packet.h @@ -0,0 +1,805 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "node_intersector.h" + +namespace embree +{ + namespace isa + { + ////////////////////////////////////////////////////////////////////////////////////// + // Ray packet structure used in hybrid traversal + ////////////////////////////////////////////////////////////////////////////////////// + + template + struct TravRayK; + + /* Fast variant */ + template + struct TravRayK + { + __forceinline TravRayK() {} + + __forceinline TravRayK(const Vec3vf& ray_org, const Vec3vf& ray_dir, int N) + { + init(ray_org, ray_dir, N); + } + + __forceinline TravRayK(const Vec3vf& ray_org, const Vec3vf& ray_dir, const vfloat& ray_tnear, const vfloat& ray_tfar, int N) + { + init(ray_org, ray_dir, N); + tnear = ray_tnear; + tfar = ray_tfar; + } + + __forceinline void init(const Vec3vf& ray_org, const Vec3vf& ray_dir, int N) + { + org = ray_org; + dir = ray_dir; + rdir = rcp_safe(ray_dir); +#if defined(__AVX2__) + org_rdir = org * rdir; +#endif + + if (N) + { + const int size = sizeof(float)*N; + nearXYZ.x = select(rdir.x >= 0.0f, vint(0*size), vint(1*size)); + nearXYZ.y = select(rdir.y >= 0.0f, vint(2*size), vint(3*size)); + nearXYZ.z = select(rdir.z >= 0.0f, vint(4*size), vint(5*size)); + 
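+          // The offsets above address the node bounds relative to &node->lower_x:
+          // the node stores lower_x, upper_x, lower_y, upper_y, lower_z, upper_z as
+          // consecutive arrays of N floats, so a non-negative direction component
+          // selects the lower (near) plane and a negative one the upper plane;
+          // the matching far plane is the complementary offset.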
} + } + + Vec3vf org; + Vec3vf dir; + Vec3vf rdir; +#if defined(__AVX2__) + Vec3vf org_rdir; +#endif + Vec3vi nearXYZ; + vfloat tnear; + vfloat tfar; + }; + + template + using TravRayKFast = TravRayK; + + /* Robust variant */ + template + struct TravRayK + { + __forceinline TravRayK() {} + + __forceinline TravRayK(const Vec3vf& ray_org, const Vec3vf& ray_dir, int N) + { + init(ray_org, ray_dir, N); + } + + __forceinline TravRayK(const Vec3vf& ray_org, const Vec3vf& ray_dir, const vfloat& ray_tnear, const vfloat& ray_tfar, int N) + { + init(ray_org, ray_dir, N); + tnear = ray_tnear; + tfar = ray_tfar; + } + + __forceinline void init(const Vec3vf& ray_org, const Vec3vf& ray_dir, int N) + { + org = ray_org; + dir = ray_dir; + rdir = vfloat(1.0f)/(zero_fix(ray_dir)); + + if (N) + { + const int size = sizeof(float)*N; + nearXYZ.x = select(rdir.x >= 0.0f, vint(0*size), vint(1*size)); + nearXYZ.y = select(rdir.y >= 0.0f, vint(2*size), vint(3*size)); + nearXYZ.z = select(rdir.z >= 0.0f, vint(4*size), vint(5*size)); + } + } + + Vec3vf org; + Vec3vf dir; + Vec3vf rdir; + Vec3vi nearXYZ; + vfloat tnear; + vfloat tfar; + }; + + template + using TravRayKRobust = TravRayK; + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeK(const typename BVHN::AABBNode* node, size_t i, + const TravRayKFast& ray, vfloat& dist) + + { + #if defined(__AVX2__) + const vfloat lclipMinX = msub(node->lower_x[i], ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMinY = msub(node->lower_y[i], ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMinZ = msub(node->lower_z[i], ray.rdir.z, ray.org_rdir.z); + const vfloat lclipMaxX = msub(node->upper_x[i], ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMaxY = msub(node->upper_y[i], ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMaxZ = msub(node->upper_z[i], ray.rdir.z, ray.org_rdir.z); + #else + const vfloat lclipMinX = (node->lower_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (node->lower_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (node->lower_z[i] - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (node->upper_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (node->upper_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (node->upper_z[i] - ray.org.z) * ray.rdir.z; + #endif + + #if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + if (K == 16) + { + /* use mixed float/int min/max */ + const vfloat lnearP = maxi(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = asInt(maxi(lnearP, ray.tnear)) <= asInt(mini(lfarP, ray.tfar)); + dist = lnearP; + return lhit; + } + else + #endif + { + const vfloat lnearP = maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY), mini(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY), maxi(lclipMinZ, lclipMaxZ)); + #if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vbool lhit = asInt(maxi(lnearP, ray.tnear)) <= asInt(mini(lfarP, ray.tfar)); + #else + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + #endif + dist = lnearP; + return lhit; + } + } + + ////////////////////////////////////////////////////////////////////////////////////// + 
// Robust AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeKRobust(const typename BVHN::AABBNode* node, size_t i, + const TravRayKRobust& ray, vfloat& dist) + { + // FIXME: use per instruction rounding for AVX512 + const vfloat lclipMinX = (node->lower_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (node->lower_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (node->lower_z[i] - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (node->upper_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (node->upper_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (node->upper_z[i] - ray.org.z) * ray.rdir.z; + const float round_up = 1.0f+3.0f*float(ulp); + const float round_down = 1.0f-3.0f*float(ulp); + const vfloat lnearP = round_down*max(max(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY)), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = round_up *min(min(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY)), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = max(lnearP, ray.tnear) <= min(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNodeMB intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeK(const typename BVHN::AABBNodeMB* node, const size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist) + { + const vfloat vlower_x = madd(time, vfloat(node->lower_dx[i]), vfloat(node->lower_x[i])); + const vfloat vlower_y = madd(time, vfloat(node->lower_dy[i]), vfloat(node->lower_y[i])); + const vfloat vlower_z = madd(time, vfloat(node->lower_dz[i]), vfloat(node->lower_z[i])); + const vfloat vupper_x = madd(time, vfloat(node->upper_dx[i]), vfloat(node->upper_x[i])); + const vfloat vupper_y = madd(time, vfloat(node->upper_dy[i]), vfloat(node->upper_y[i])); + const vfloat vupper_z = madd(time, vfloat(node->upper_dz[i]), vfloat(node->upper_z[i])); + +#if defined(__AVX2__) + const vfloat lclipMinX = msub(vlower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMinY = msub(vlower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMinZ = msub(vlower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat lclipMaxX = msub(vupper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMaxY = msub(vupper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMaxZ = msub(vupper_z, ray.rdir.z, ray.org_rdir.z); +#else + const vfloat lclipMinX = (vlower_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (vlower_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (vlower_z - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (vupper_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (vupper_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (vupper_z - ray.org.z) * ray.rdir.z; +#endif + +#if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + if (K == 16) + { + /* use mixed float/int min/max */ + const vfloat lnearP = maxi(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = asInt(maxi(lnearP, ray.tnear)) <= asInt(mini(lfarP, ray.tfar)); + dist = lnearP; + return lhit; + } + else +#endif + { + const vfloat lnearP = maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY), 
mini(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY), maxi(lclipMinZ, lclipMaxZ)); +#if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vbool lhit = asInt(maxi(lnearP, ray.tnear)) <= asInt(mini(lfarP, ray.tfar)); +#else + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); +#endif + dist = lnearP; + return lhit; + } + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNodeMB intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeKRobust(const typename BVHN::AABBNodeMB* node, const size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist) + { + const vfloat vlower_x = madd(time, vfloat(node->lower_dx[i]), vfloat(node->lower_x[i])); + const vfloat vlower_y = madd(time, vfloat(node->lower_dy[i]), vfloat(node->lower_y[i])); + const vfloat vlower_z = madd(time, vfloat(node->lower_dz[i]), vfloat(node->lower_z[i])); + const vfloat vupper_x = madd(time, vfloat(node->upper_dx[i]), vfloat(node->upper_x[i])); + const vfloat vupper_y = madd(time, vfloat(node->upper_dy[i]), vfloat(node->upper_y[i])); + const vfloat vupper_z = madd(time, vfloat(node->upper_dz[i]), vfloat(node->upper_z[i])); + + const vfloat lclipMinX = (vlower_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (vlower_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (vlower_z - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (vupper_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (vupper_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (vupper_z - ray.org.z) * ray.rdir.z; + + const float round_up = 1.0f+3.0f*float(ulp); + const float round_down = 1.0f-3.0f*float(ulp); + +#if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + if (K == 16) + { + const vfloat lnearP = round_down*maxi(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = round_up *mini(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + else +#endif + { + const vfloat lnearP = round_down*maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY), mini(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = round_up *mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY), maxi(lclipMinZ, lclipMaxZ)); + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNodeMB4D intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeKMB4D(const typename BVHN::NodeRef ref, const size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist) + { + const typename BVHN::AABBNodeMB* node = ref.getAABBNodeMB(); + + const vfloat vlower_x = madd(time, vfloat(node->lower_dx[i]), vfloat(node->lower_x[i])); + const vfloat vlower_y = madd(time, vfloat(node->lower_dy[i]), vfloat(node->lower_y[i])); + const vfloat vlower_z = madd(time, vfloat(node->lower_dz[i]), vfloat(node->lower_z[i])); + const vfloat vupper_x = madd(time, vfloat(node->upper_dx[i]), vfloat(node->upper_x[i])); + const vfloat vupper_y = madd(time, vfloat(node->upper_dy[i]), vfloat(node->upper_y[i])); + 
const vfloat vupper_z = madd(time, vfloat(node->upper_dz[i]), vfloat(node->upper_z[i])); + +#if defined(__AVX2__) + const vfloat lclipMinX = msub(vlower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMinY = msub(vlower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMinZ = msub(vlower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat lclipMaxX = msub(vupper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMaxY = msub(vupper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMaxZ = msub(vupper_z, ray.rdir.z, ray.org_rdir.z); +#else + const vfloat lclipMinX = (vlower_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (vlower_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (vlower_z - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (vupper_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (vupper_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (vupper_z - ray.org.z) * ray.rdir.z; +#endif + + const vfloat lnearP = maxi(maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY)), mini(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY)), maxi(lclipMinZ, lclipMaxZ)); + vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + if (unlikely(ref.isAABBNodeMB4D())) { + const typename BVHN::AABBNodeMB4D* node1 = (const typename BVHN::AABBNodeMB4D*) node; + lhit = lhit & (vfloat(node1->lower_t[i]) <= time) & (time < vfloat(node1->upper_t[i])); + } + dist = lnearP; + return lhit; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNodeMB4D intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeKMB4DRobust(const typename BVHN::NodeRef ref, const size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist) + { + const typename BVHN::AABBNodeMB* node = ref.getAABBNodeMB(); + + const vfloat vlower_x = madd(time, vfloat(node->lower_dx[i]), vfloat(node->lower_x[i])); + const vfloat vlower_y = madd(time, vfloat(node->lower_dy[i]), vfloat(node->lower_y[i])); + const vfloat vlower_z = madd(time, vfloat(node->lower_dz[i]), vfloat(node->lower_z[i])); + const vfloat vupper_x = madd(time, vfloat(node->upper_dx[i]), vfloat(node->upper_x[i])); + const vfloat vupper_y = madd(time, vfloat(node->upper_dy[i]), vfloat(node->upper_y[i])); + const vfloat vupper_z = madd(time, vfloat(node->upper_dz[i]), vfloat(node->upper_z[i])); + + const vfloat lclipMinX = (vlower_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (vlower_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (vlower_z - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (vupper_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (vupper_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (vupper_z - ray.org.z) * ray.rdir.z; + + const float round_up = 1.0f+3.0f*float(ulp); + const float round_down = 1.0f-3.0f*float(ulp); + const vfloat lnearP = round_down*maxi(maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY)), mini(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = round_up *mini(mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY)), maxi(lclipMinZ, lclipMaxZ)); + vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + + if (unlikely(ref.isAABBNodeMB4D())) { + const typename BVHN::AABBNodeMB4D* node1 = (const typename BVHN::AABBNodeMB4D*) node; + lhit = lhit & (vfloat(node1->lower_t[i]) <= time) & (time < vfloat(node1->upper_t[i])); + } + dist = 
lnearP; + return lhit; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast OBBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeK(const typename BVHN::OBBNode* node, const size_t i, + const TravRayK& ray, vfloat& dist) + { + const AffineSpace3vf naabb(Vec3f(node->naabb.l.vx.x[i], node->naabb.l.vx.y[i], node->naabb.l.vx.z[i]), + Vec3f(node->naabb.l.vy.x[i], node->naabb.l.vy.y[i], node->naabb.l.vy.z[i]), + Vec3f(node->naabb.l.vz.x[i], node->naabb.l.vz.y[i], node->naabb.l.vz.z[i]), + Vec3f(node->naabb.p .x[i], node->naabb.p .y[i], node->naabb.p .z[i])); + + const Vec3vf dir = xfmVector(naabb, ray.dir); + const Vec3vf nrdir = Vec3vf(vfloat(-1.0f)) * rcp_safe(dir); // FIXME: negate instead of mul with -1? + const Vec3vf org = xfmPoint(naabb, ray.org); + + const vfloat lclipMinX = org.x * nrdir.x; // (Vec3fa(zero) - org) * rdir; + const vfloat lclipMinY = org.y * nrdir.y; + const vfloat lclipMinZ = org.z * nrdir.z; + const vfloat lclipMaxX = lclipMinX - nrdir.x; // (Vec3fa(one) - org) * rdir; + const vfloat lclipMaxY = lclipMinY - nrdir.y; + const vfloat lclipMaxZ = lclipMinZ - nrdir.z; + + vfloat lnearP = maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY), mini(lclipMinZ, lclipMaxZ)); + vfloat lfarP = mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY), maxi(lclipMinZ, lclipMaxZ)); + if (robust) { + lnearP = lnearP*vfloat(1.0f-3.0f*float(ulp)); + lfarP = lfarP *vfloat(1.0f+3.0f*float(ulp)); + } + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast OBBNodeMB intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectNodeK(const typename BVHN::OBBNodeMB* node, const size_t i, + const TravRayK& ray, const vfloat& time, vfloat& dist) + { + const AffineSpace3vf xfm(Vec3f(node->space0.l.vx.x[i], node->space0.l.vx.y[i], node->space0.l.vx.z[i]), + Vec3f(node->space0.l.vy.x[i], node->space0.l.vy.y[i], node->space0.l.vy.z[i]), + Vec3f(node->space0.l.vz.x[i], node->space0.l.vz.y[i], node->space0.l.vz.z[i]), + Vec3f(node->space0.p .x[i], node->space0.p .y[i], node->space0.p .z[i])); + + const Vec3vf b0_lower = zero; + const Vec3vf b0_upper = one; + const Vec3vf b1_lower(node->b1.lower.x[i], node->b1.lower.y[i], node->b1.lower.z[i]); + const Vec3vf b1_upper(node->b1.upper.x[i], node->b1.upper.y[i], node->b1.upper.z[i]); + const Vec3vf lower = lerp(b0_lower, b1_lower, time); + const Vec3vf upper = lerp(b0_upper, b1_upper, time); + + const Vec3vf dir = xfmVector(xfm, ray.dir); + const Vec3vf rdir = rcp_safe(dir); + const Vec3vf org = xfmPoint(xfm, ray.org); + + const vfloat lclipMinX = (lower.x - org.x) * rdir.x; + const vfloat lclipMinY = (lower.y - org.y) * rdir.y; + const vfloat lclipMinZ = (lower.z - org.z) * rdir.z; + const vfloat lclipMaxX = (upper.x - org.x) * rdir.x; + const vfloat lclipMaxY = (upper.y - org.y) * rdir.y; + const vfloat lclipMaxZ = (upper.z - org.z) * rdir.z; + + vfloat lnearP = maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY), mini(lclipMinZ, lclipMaxZ)); + vfloat lfarP = mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY), maxi(lclipMinZ, lclipMaxZ)); + if (robust) { + lnearP = lnearP*vfloat(1.0f-3.0f*float(ulp)); + lfarP = lfarP 
*vfloat(1.0f+3.0f*float(ulp)); + } + + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + + + + ////////////////////////////////////////////////////////////////////////////////////// + // QuantizedBaseNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline vbool intersectQuantizedNodeK(const typename BVHN::QuantizedBaseNode* node, size_t i, + const TravRayK& ray, vfloat& dist) + + { + assert(movemask(node->validMask()) & ((size_t)1 << i)); + const vfloat lower_x = node->dequantizeLowerX(); + const vfloat upper_x = node->dequantizeUpperX(); + const vfloat lower_y = node->dequantizeLowerY(); + const vfloat upper_y = node->dequantizeUpperY(); + const vfloat lower_z = node->dequantizeLowerZ(); + const vfloat upper_z = node->dequantizeUpperZ(); + + #if defined(__AVX2__) + const vfloat lclipMinX = msub(lower_x[i], ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMinY = msub(lower_y[i], ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMinZ = msub(lower_z[i], ray.rdir.z, ray.org_rdir.z); + const vfloat lclipMaxX = msub(upper_x[i], ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMaxY = msub(upper_y[i], ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMaxZ = msub(upper_z[i], ray.rdir.z, ray.org_rdir.z); + #else + const vfloat lclipMinX = (lower_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (lower_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (lower_z[i] - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (upper_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (upper_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (upper_z[i] - ray.org.z) * ray.rdir.z; + #endif + + #if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + if (K == 16) + { + /* use mixed float/int min/max */ + const vfloat lnearP = maxi(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = asInt(maxi(lnearP, ray.tnear)) <= asInt(mini(lfarP, ray.tfar)); + dist = lnearP; + return lhit; + } + else + #endif + { + const vfloat lnearP = maxi(mini(lclipMinX, lclipMaxX), mini(lclipMinY, lclipMaxY), mini(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = mini(maxi(lclipMinX, lclipMaxX), maxi(lclipMinY, lclipMaxY), maxi(lclipMinZ, lclipMaxZ)); + #if defined(__AVX512F__) && !defined(__AVX512ER__) // SKX + const vbool lhit = asInt(maxi(lnearP, ray.tnear)) <= asInt(mini(lfarP, ray.tfar)); + #else + const vbool lhit = maxi(lnearP, ray.tnear) <= mini(lfarP, ray.tfar); + #endif + dist = lnearP; + return lhit; + } + } + + template + __forceinline vbool intersectQuantizedNodeK(const typename BVHN::QuantizedBaseNode* node, size_t i, + const TravRayK& ray, vfloat& dist) + + { + assert(movemask(node->validMask()) & ((size_t)1 << i)); + const vfloat lower_x = node->dequantizeLowerX(); + const vfloat upper_x = node->dequantizeUpperX(); + const vfloat lower_y = node->dequantizeLowerY(); + const vfloat upper_y = node->dequantizeUpperY(); + const vfloat lower_z = node->dequantizeLowerZ(); + const vfloat upper_z = node->dequantizeUpperZ(); + + const vfloat lclipMinX = (lower_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (lower_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (lower_z[i] - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (upper_x[i] - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY 
= (upper_y[i] - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (upper_z[i] - ray.org.z) * ray.rdir.z; + + const float round_up = 1.0f+3.0f*float(ulp); + const float round_down = 1.0f-3.0f*float(ulp); + + const vfloat lnearP = round_down*max(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = round_up *min(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = max(lnearP, ray.tnear) <= min(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + + template + __forceinline vbool intersectQuantizedNodeMBK(const typename BVHN::QuantizedBaseNodeMB* node, const size_t i, + const TravRayK& ray, const vfloat& time, vfloat& dist) + + { + assert(movemask(node->validMask()) & ((size_t)1 << i)); + + const vfloat lower_x = node->dequantizeLowerX(i,time); + const vfloat upper_x = node->dequantizeUpperX(i,time); + const vfloat lower_y = node->dequantizeLowerY(i,time); + const vfloat upper_y = node->dequantizeUpperY(i,time); + const vfloat lower_z = node->dequantizeLowerZ(i,time); + const vfloat upper_z = node->dequantizeUpperZ(i,time); + +#if defined(__AVX2__) + const vfloat lclipMinX = msub(lower_x, ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMinY = msub(lower_y, ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMinZ = msub(lower_z, ray.rdir.z, ray.org_rdir.z); + const vfloat lclipMaxX = msub(upper_x, ray.rdir.x, ray.org_rdir.x); + const vfloat lclipMaxY = msub(upper_y, ray.rdir.y, ray.org_rdir.y); + const vfloat lclipMaxZ = msub(upper_z, ray.rdir.z, ray.org_rdir.z); +#else + const vfloat lclipMinX = (lower_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (lower_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (lower_z - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (upper_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (upper_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (upper_z - ray.org.z) * ray.rdir.z; + #endif + const vfloat lnearP = max(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = min(max(lclipMinX, lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = max(lnearP, ray.tnear) <= min(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + + + template + __forceinline vbool intersectQuantizedNodeMBK(const typename BVHN::QuantizedBaseNodeMB* node, const size_t i, + const TravRayK& ray, const vfloat& time, vfloat& dist) + + { + assert(movemask(node->validMask()) & ((size_t)1 << i)); + + const vfloat lower_x = node->dequantizeLowerX(i,time); + const vfloat upper_x = node->dequantizeUpperX(i,time); + const vfloat lower_y = node->dequantizeLowerY(i,time); + const vfloat upper_y = node->dequantizeUpperY(i,time); + const vfloat lower_z = node->dequantizeLowerZ(i,time); + const vfloat upper_z = node->dequantizeUpperZ(i,time); + + const vfloat lclipMinX = (lower_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMinY = (lower_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMinZ = (lower_z - ray.org.z) * ray.rdir.z; + const vfloat lclipMaxX = (upper_x - ray.org.x) * ray.rdir.x; + const vfloat lclipMaxY = (upper_y - ray.org.y) * ray.rdir.y; + const vfloat lclipMaxZ = (upper_z - ray.org.z) * ray.rdir.z; + + const float round_up = 1.0f+3.0f*float(ulp); + const float round_down = 1.0f-3.0f*float(ulp); + + const vfloat lnearP = round_down*max(min(lclipMinX, lclipMaxX), min(lclipMinY, lclipMaxY), min(lclipMinZ, lclipMaxZ)); + const vfloat lfarP = round_up *min(max(lclipMinX, 
lclipMaxX), max(lclipMinY, lclipMaxY), max(lclipMinZ, lclipMaxZ)); + const vbool lhit = max(lnearP, ray.tnear) <= min(lfarP, ray.tfar); + dist = lnearP; + return lhit; + } + + + ////////////////////////////////////////////////////////////////////////////////////// + // Node intersectors used in hybrid traversal + ////////////////////////////////////////////////////////////////////////////////////// + + /*! Intersects N nodes with K rays */ + template + struct BVHNNodeIntersectorK; + + template + struct BVHNNodeIntersectorK + { + /* vmask is both an input and an output parameter! Its initial value should be the parent node + hit mask, which is used for correctly computing the current hit mask. The parent hit mask + is actually required only for motion blur node intersections (because different rays may + have different times), so for regular nodes vmask is simply overwritten. */ + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + vmask = intersectNodeK(node.getAABBNode(), i, ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + vmask = intersectNodeKRobust(node.getAABBNode(), i, ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + vmask = intersectNodeK(node.getAABBNodeMB(), i, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + vmask = intersectNodeKRobust(node.getAABBNodeMB(), i, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + if (likely(node.isAABBNode())) vmask = intersectNodeK(node.getAABBNode(), i, ray, dist); + else /*if (unlikely(node.isOBBNode()))*/ vmask = intersectNodeK(node.ungetAABBNode(), i, ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + if (likely(node.isAABBNode())) vmask = intersectNodeKRobust(node.getAABBNode(), i, ray, dist); + else /*if (unlikely(node.isOBBNode()))*/ vmask = intersectNodeK(node.ungetAABBNode(), i, ray, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + if (likely(node.isAABBNodeMB())) vmask = intersectNodeK(node.getAABBNodeMB(), i, ray, time, dist); + else /*if (unlikely(node.isOBBNodeMB()))*/ vmask = intersectNodeK(node.ungetAABBNodeMB(), i, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& 
dist, vbool& vmask) + { + if (likely(node.isAABBNodeMB())) vmask = intersectNodeKRobust(node.getAABBNodeMB(), i, ray, time, dist); + else /*if (unlikely(node.isOBBNodeMB()))*/ vmask = intersectNodeK(node.ungetAABBNodeMB(), i, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + vmask &= intersectNodeKMB4D(node, i, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + vmask &= intersectNodeKMB4DRobust(node, i, ray, time, dist); + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKFast& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + if (likely(node.isAABBNodeMB() || node.isAABBNodeMB4D())) { + vmask &= intersectNodeKMB4D(node, i, ray, time, dist); + } else /*if (unlikely(node.isOBBNodeMB()))*/ { + assert(node.isOBBNodeMB()); + vmask &= intersectNodeK(node.ungetAABBNodeMB(), i, ray, time, dist); + } + return true; + } + }; + + template + struct BVHNNodeIntersectorK + { + static __forceinline bool intersect(const typename BVHN::NodeRef& node, size_t i, + const TravRayKRobust& ray, const vfloat& time, vfloat& dist, vbool& vmask) + { + if (likely(node.isAABBNodeMB() || node.isAABBNodeMB4D())) { + vmask &= intersectNodeKMB4DRobust(node, i, ray, time, dist); + } else /*if (unlikely(node.isOBBNodeMB()))*/ { + assert(node.isOBBNodeMB()); + vmask &= intersectNodeK(node.ungetAABBNodeMB(), i, ray, time, dist); + } + return true; + } + }; + + + /*! 
Intersects N nodes with K rays */ + template + struct BVHNQuantizedBaseNodeIntersectorK; + + template + struct BVHNQuantizedBaseNodeIntersectorK + { + static __forceinline vbool intersectK(const typename BVHN::QuantizedBaseNode* node, const size_t i, + const TravRayK& ray, vfloat& dist) + { + return intersectQuantizedNodeK(node,i,ray,dist); + } + + static __forceinline vbool intersectK(const typename BVHN::QuantizedBaseNodeMB* node, const size_t i, + const TravRayK& ray, const vfloat& time, vfloat& dist) + { + return intersectQuantizedNodeMBK(node,i,ray,time,dist); + } + + }; + + template + struct BVHNQuantizedBaseNodeIntersectorK + { + static __forceinline vbool intersectK(const typename BVHN::QuantizedBaseNode* node, const size_t i, + const TravRayK& ray, vfloat& dist) + { + return intersectQuantizedNodeK(node,i,ray,dist); + } + + static __forceinline vbool intersectK(const typename BVHN::QuantizedBaseNodeMB* node, const size_t i, + const TravRayK& ray, const vfloat& time, vfloat& dist) + { + return intersectQuantizedNodeMBK(node,i,ray,time,dist); + } + }; + + + } +} diff --git a/thirdparty/embree/kernels/bvh/node_intersector_packet_stream.h b/thirdparty/embree/kernels/bvh/node_intersector_packet_stream.h new file mode 100644 index 000000000000..c2b5b0cb7a47 --- /dev/null +++ b/thirdparty/embree/kernels/bvh/node_intersector_packet_stream.h @@ -0,0 +1,189 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "node_intersector.h" + +namespace embree +{ + namespace isa + { + ////////////////////////////////////////////////////////////////////////////////////// + // Ray packet structure used in stream traversal + ////////////////////////////////////////////////////////////////////////////////////// + + template + struct TravRayKStream; + + /* Fast variant */ + template + struct TravRayKStream + { + __forceinline TravRayKStream() {} + + __forceinline TravRayKStream(const Vec3vf& ray_org, const Vec3vf& ray_dir, const vfloat& ray_tnear, const vfloat& ray_tfar) + { + init(ray_org, ray_dir); + tnear = ray_tnear; + tfar = ray_tfar; + } + + __forceinline void init(const Vec3vf& ray_org, const Vec3vf& ray_dir) + { + rdir = rcp_safe(ray_dir); + org_rdir = ray_org * rdir; + } + + Vec3vf rdir; + Vec3vf org_rdir; + vfloat tnear; + vfloat tfar; + }; + + template + using TravRayKStreamFast = TravRayKStream; + + /* Robust variant */ + template + struct TravRayKStream + { + __forceinline TravRayKStream() {} + + __forceinline TravRayKStream(const Vec3vf& ray_org, const Vec3vf& ray_dir, const vfloat& ray_tnear, const vfloat& ray_tfar) + { + init(ray_org, ray_dir); + tnear = ray_tnear; + tfar = ray_tfar; + } + + __forceinline void init(const Vec3vf& ray_org, const Vec3vf& ray_dir) + { + rdir = vfloat(1.0f)/(zero_fix(ray_dir)); + org = ray_org; + } + + Vec3vf rdir; + Vec3vf org; + vfloat tnear; + vfloat tfar; + }; + + template + using TravRayKStreamRobust = TravRayKStream; + + ////////////////////////////////////////////////////////////////////////////////////// + // Fast AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode1(const typename BVHN::AABBNode* __restrict__ node, + const TravRayKStreamFast& ray, size_t k, const NearFarPrecalculations& nf) + { + const vfloat bminX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearX)); + const vfloat bminY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearY)); + const vfloat 
bminZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearZ)); + const vfloat bmaxX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farX)); + const vfloat bmaxY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farY)); + const vfloat bmaxZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farZ)); + + const vfloat rminX = msub(bminX, vfloat(ray.rdir.x[k]), vfloat(ray.org_rdir.x[k])); + const vfloat rminY = msub(bminY, vfloat(ray.rdir.y[k]), vfloat(ray.org_rdir.y[k])); + const vfloat rminZ = msub(bminZ, vfloat(ray.rdir.z[k]), vfloat(ray.org_rdir.z[k])); + const vfloat rmaxX = msub(bmaxX, vfloat(ray.rdir.x[k]), vfloat(ray.org_rdir.x[k])); + const vfloat rmaxY = msub(bmaxY, vfloat(ray.rdir.y[k]), vfloat(ray.org_rdir.y[k])); + const vfloat rmaxZ = msub(bmaxZ, vfloat(ray.rdir.z[k]), vfloat(ray.org_rdir.z[k])); + const vfloat rmin = maxi(rminX, rminY, rminZ, vfloat(ray.tnear[k])); + const vfloat rmax = mini(rmaxX, rmaxY, rmaxZ, vfloat(ray.tfar[k])); + + const vbool vmask_first_hit = rmin <= rmax; + + return movemask(vmask_first_hit) & (((size_t)1 << N)-1); + } + + template + __forceinline size_t intersectNodeK(const typename BVHN::AABBNode* __restrict__ node, size_t i, + const TravRayKStreamFast& ray, const NearFarPrecalculations& nf) + { + char* ptr = (char*)&node->lower_x + i*sizeof(float); + const vfloat bminX = *(const float*)(ptr + nf.nearX); + const vfloat bminY = *(const float*)(ptr + nf.nearY); + const vfloat bminZ = *(const float*)(ptr + nf.nearZ); + const vfloat bmaxX = *(const float*)(ptr + nf.farX); + const vfloat bmaxY = *(const float*)(ptr + nf.farY); + const vfloat bmaxZ = *(const float*)(ptr + nf.farZ); + + const vfloat rminX = msub(bminX, ray.rdir.x, ray.org_rdir.x); + const vfloat rminY = msub(bminY, ray.rdir.y, ray.org_rdir.y); + const vfloat rminZ = msub(bminZ, ray.rdir.z, ray.org_rdir.z); + const vfloat rmaxX = msub(bmaxX, ray.rdir.x, ray.org_rdir.x); + const vfloat rmaxY = msub(bmaxY, ray.rdir.y, ray.org_rdir.y); + const vfloat rmaxZ = msub(bmaxZ, ray.rdir.z, ray.org_rdir.z); + + const vfloat rmin = maxi(rminX, rminY, rminZ, ray.tnear); + const vfloat rmax = mini(rmaxX, rmaxY, rmaxZ, ray.tfar); + + const vbool vmask_first_hit = rmin <= rmax; + + return movemask(vmask_first_hit); + } + + ////////////////////////////////////////////////////////////////////////////////////// + // Robust AABBNode intersection + ////////////////////////////////////////////////////////////////////////////////////// + + template + __forceinline size_t intersectNode1(const typename BVHN::AABBNode* __restrict__ node, + const TravRayKStreamRobust& ray, size_t k, const NearFarPrecalculations& nf) + { + const vfloat bminX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearX)); + const vfloat bminY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearY)); + const vfloat bminZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.nearZ)); + const vfloat bmaxX = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farX)); + const vfloat bmaxY = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farY)); + const vfloat bmaxZ = vfloat(*(const vfloat*)((const char*)&node->lower_x + nf.farZ)); + + const vfloat rminX = (bminX - vfloat(ray.org.x[k])) * vfloat(ray.rdir.x[k]); + const vfloat rminY = (bminY - vfloat(ray.org.y[k])) * vfloat(ray.rdir.y[k]); + const vfloat rminZ = (bminZ - vfloat(ray.org.z[k])) * vfloat(ray.rdir.z[k]); + const vfloat rmaxX = (bmaxX - vfloat(ray.org.x[k])) * vfloat(ray.rdir.x[k]); + const vfloat rmaxY = (bmaxY 
- vfloat(ray.org.y[k])) * vfloat(ray.rdir.y[k]); + const vfloat rmaxZ = (bmaxZ - vfloat(ray.org.z[k])) * vfloat(ray.rdir.z[k]); + const float round_up = 1.0f+3.0f*float(ulp); // FIXME: use per instruction rounding for AVX512 + const vfloat rmin = max(rminX, rminY, rminZ, vfloat(ray.tnear[k])); + const vfloat rmax = round_up *min(rmaxX, rmaxY, rmaxZ, vfloat(ray.tfar[k])); + + const vbool vmask_first_hit = rmin <= rmax; + + return movemask(vmask_first_hit) & (((size_t)1 << N)-1); + } + + template + __forceinline size_t intersectNodeK(const typename BVHN::AABBNode* __restrict__ node, size_t i, + const TravRayKStreamRobust& ray, const NearFarPrecalculations& nf) + { + char *ptr = (char*)&node->lower_x + i*sizeof(float); + const vfloat bminX = *(const float*)(ptr + nf.nearX); + const vfloat bminY = *(const float*)(ptr + nf.nearY); + const vfloat bminZ = *(const float*)(ptr + nf.nearZ); + const vfloat bmaxX = *(const float*)(ptr + nf.farX); + const vfloat bmaxY = *(const float*)(ptr + nf.farY); + const vfloat bmaxZ = *(const float*)(ptr + nf.farZ); + + const vfloat rminX = (bminX - ray.org.x) * ray.rdir.x; + const vfloat rminY = (bminY - ray.org.y) * ray.rdir.y; + const vfloat rminZ = (bminZ - ray.org.z) * ray.rdir.z; + const vfloat rmaxX = (bmaxX - ray.org.x) * ray.rdir.x; + const vfloat rmaxY = (bmaxY - ray.org.y) * ray.rdir.y; + const vfloat rmaxZ = (bmaxZ - ray.org.z) * ray.rdir.z; + + const float round_up = 1.0f+3.0f*float(ulp); + const vfloat rmin = max(rminX, rminY, rminZ, vfloat(ray.tnear)); + const vfloat rmax = round_up * min(rmaxX, rmaxY, rmaxZ, vfloat(ray.tfar)); + + const vbool vmask_first_hit = rmin <= rmax; + + return movemask(vmask_first_hit); + } + } +} diff --git a/thirdparty/embree/kernels/common/accel.h b/thirdparty/embree/kernels/common/accel.h new file mode 100644 index 000000000000..f332d3655516 --- /dev/null +++ b/thirdparty/embree/kernels/common/accel.h @@ -0,0 +1,556 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "ray.h" +#include "point_query.h" +#include "context.h" + +namespace embree +{ + class Scene; + + /*! Base class for the acceleration structure data. */ + class AccelData : public RefCount + { + ALIGNED_CLASS_(16); + public: + enum Type { TY_UNKNOWN = 0, TY_ACCELN = 1, TY_ACCEL_INSTANCE = 2, TY_BVH4 = 3, TY_BVH8 = 4 }; + + public: + AccelData (const Type type) + : bounds(empty), type(type) {} + + /*! notifies the acceleration structure about the deletion of some geometry */ + virtual void deleteGeometry(size_t geomID) {}; + + /*! clears the acceleration structure data */ + virtual void clear() = 0; + + /*! returns normal bounds */ + __forceinline BBox3fa getBounds() const { + return bounds.bounds(); + } + + /*! returns bounds for some time */ + __forceinline BBox3fa getBounds(float t) const { + return bounds.interpolate(t); + } + + /*! returns linear bounds */ + __forceinline LBBox3fa getLinearBounds() const { + return bounds; + } + + /*! checks if acceleration structure is empty */ + __forceinline bool isEmpty() const { + return bounds.bounds0.lower.x == float(pos_inf); + } + + public: + LBBox3fa bounds; // linear bounds + Type type; + }; + + /*! Base class for all intersectable and buildable acceleration structures. */ + class Accel : public AccelData + { + ALIGNED_CLASS_(16); + public: + + struct Intersectors; + + /*! Type of collide function */ + typedef void (*CollideFunc)(void* bvh0, void* bvh1, RTCCollideFunc callback, void* userPtr); + + /*! 
Type of point query function */ + typedef bool(*PointQueryFunc)(Intersectors* This, /*!< this pointer to accel */ + PointQuery* query, /*!< point query for lookup */ + PointQueryContext* context); /*!< point query context */ + + /*! Type of intersect function pointer for single rays. */ + typedef void (*IntersectFunc)(Intersectors* This, /*!< this pointer to accel */ + RTCRayHit& ray, /*!< ray to intersect */ + IntersectContext* context); + + /*! Type of intersect function pointer for ray packets of size 4. */ + typedef void (*IntersectFunc4)(const void* valid, /*!< pointer to valid mask */ + Intersectors* This, /*!< this pointer to accel */ + RTCRayHit4& ray, /*!< ray packet to intersect */ + IntersectContext* context); + + /*! Type of intersect function pointer for ray packets of size 8. */ + typedef void (*IntersectFunc8)(const void* valid, /*!< pointer to valid mask */ + Intersectors* This, /*!< this pointer to accel */ + RTCRayHit8& ray, /*!< ray packet to intersect */ + IntersectContext* context); + + /*! Type of intersect function pointer for ray packets of size 16. */ + typedef void (*IntersectFunc16)(const void* valid, /*!< pointer to valid mask */ + Intersectors* This, /*!< this pointer to accel */ + RTCRayHit16& ray, /*!< ray packet to intersect */ + IntersectContext* context); + + /*! Type of intersect function pointer for ray packets of size N. */ + typedef void (*IntersectFuncN)(Intersectors* This, /*!< this pointer to accel */ + RTCRayHitN** ray, /*!< ray stream to intersect */ + const size_t N, /*!< number of rays in stream */ + IntersectContext* context /*!< layout flags */); + + + /*! Type of occlusion function pointer for single rays. */ + typedef void (*OccludedFunc) (Intersectors* This, /*!< this pointer to accel */ + RTCRay& ray, /*!< ray to test occlusion */ + IntersectContext* context); + + /*! Type of occlusion function pointer for ray packets of size 4. */ + typedef void (*OccludedFunc4) (const void* valid, /*!< pointer to valid mask */ + Intersectors* This, /*!< this pointer to accel */ + RTCRay4& ray, /*!< ray packet to test occlusion. */ + IntersectContext* context); + + /*! Type of occlusion function pointer for ray packets of size 8. */ + typedef void (*OccludedFunc8) (const void* valid, /*!< pointer to valid mask */ + Intersectors* This, /*!< this pointer to accel */ + RTCRay8& ray, /*!< ray packet to test occlusion. */ + IntersectContext* context); + + /*! Type of occlusion function pointer for ray packets of size 16. */ + typedef void (*OccludedFunc16) (const void* valid, /*!< pointer to valid mask */ + Intersectors* This, /*!< this pointer to accel */ + RTCRay16& ray, /*!< ray packet to test occlusion. */ + IntersectContext* context); + + /*! Type of intersect function pointer for ray packets of size N. 
*/ + typedef void (*OccludedFuncN)(Intersectors* This, /*!< this pointer to accel */ + RTCRayN** ray, /*!< ray stream to test occlusion */ + const size_t N, /*!< number of rays in stream */ + IntersectContext* context /*!< layout flags */); + typedef void (*ErrorFunc) (); + + struct Collider + { + Collider (ErrorFunc error = nullptr) + : collide((CollideFunc)error), name(nullptr) {} + + Collider (CollideFunc collide, const char* name) + : collide(collide), name(name) {} + + operator bool() const { return name; } + + public: + CollideFunc collide; + const char* name; + }; + + struct Intersector1 + { + Intersector1 (ErrorFunc error = nullptr) + : intersect((IntersectFunc)error), occluded((OccludedFunc)error), name(nullptr) {} + + Intersector1 (IntersectFunc intersect, OccludedFunc occluded, const char* name) + : intersect(intersect), occluded(occluded), pointQuery(nullptr), name(name) {} + + Intersector1 (IntersectFunc intersect, OccludedFunc occluded, PointQueryFunc pointQuery, const char* name) + : intersect(intersect), occluded(occluded), pointQuery(pointQuery), name(name) {} + + operator bool() const { return name; } + + public: + static const char* type; + IntersectFunc intersect; + OccludedFunc occluded; + PointQueryFunc pointQuery; + const char* name; + }; + + struct Intersector4 + { + Intersector4 (ErrorFunc error = nullptr) + : intersect((IntersectFunc4)error), occluded((OccludedFunc4)error), name(nullptr) {} + + Intersector4 (IntersectFunc4 intersect, OccludedFunc4 occluded, const char* name) + : intersect(intersect), occluded(occluded), name(name) {} + + operator bool() const { return name; } + + public: + static const char* type; + IntersectFunc4 intersect; + OccludedFunc4 occluded; + const char* name; + }; + + struct Intersector8 + { + Intersector8 (ErrorFunc error = nullptr) + : intersect((IntersectFunc8)error), occluded((OccludedFunc8)error), name(nullptr) {} + + Intersector8 (IntersectFunc8 intersect, OccludedFunc8 occluded, const char* name) + : intersect(intersect), occluded(occluded), name(name) {} + + operator bool() const { return name; } + + public: + static const char* type; + IntersectFunc8 intersect; + OccludedFunc8 occluded; + const char* name; + }; + + struct Intersector16 + { + Intersector16 (ErrorFunc error = nullptr) + : intersect((IntersectFunc16)error), occluded((OccludedFunc16)error), name(nullptr) {} + + Intersector16 (IntersectFunc16 intersect, OccludedFunc16 occluded, const char* name) + : intersect(intersect), occluded(occluded), name(name) {} + + operator bool() const { return name; } + + public: + static const char* type; + IntersectFunc16 intersect; + OccludedFunc16 occluded; + const char* name; + }; + + struct IntersectorN + { + IntersectorN (ErrorFunc error = nullptr) + : intersect((IntersectFuncN)error), occluded((OccludedFuncN)error), name(nullptr) {} + + IntersectorN (IntersectFuncN intersect, OccludedFuncN occluded, const char* name) + : intersect(intersect), occluded(occluded), name(name) {} + + operator bool() const { return name; } + + public: + static const char* type; + IntersectFuncN intersect; + OccludedFuncN occluded; + const char* name; + }; + + struct Intersectors + { + Intersectors() + : ptr(nullptr), leafIntersector(nullptr), collider(nullptr), intersector1(nullptr), intersector4(nullptr), intersector8(nullptr), intersector16(nullptr), intersectorN(nullptr) {} + + Intersectors (ErrorFunc error) + : ptr(nullptr), leafIntersector(nullptr), collider(error), intersector1(error), intersector4(error), intersector8(error), 
intersector16(error), intersectorN(error) {} + + void print(size_t ident) + { + if (collider.name) { + for (size_t i=0; ibuild(); + bounds = accel->bounds; + } + + void deleteGeometry(size_t geomID) { + if (accel ) accel->deleteGeometry(geomID); + if (builder) builder->deleteGeometry(geomID); + } + + void clear() { + if (accel) accel->clear(); + if (builder) builder->clear(); + } + + private: + std::unique_ptr accel; + std::unique_ptr builder; + }; +} diff --git a/thirdparty/embree/kernels/common/acceln.cpp b/thirdparty/embree/kernels/common/acceln.cpp new file mode 100644 index 000000000000..c9f7e921932e --- /dev/null +++ b/thirdparty/embree/kernels/common/acceln.cpp @@ -0,0 +1,232 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "acceln.h" +#include "ray.h" +#include "../../include/embree3/rtcore_ray.h" +#include "../../common/algorithms/parallel_for.h" + +namespace embree +{ + AccelN::AccelN() + : Accel(AccelData::TY_ACCELN), accels() {} + + AccelN::~AccelN() + { + for (size_t i=0; iptr; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + changed |= This->accels[i]->intersectors.pointQuery(query,context); + return changed; + } + + void AccelN::intersect (Accel::Intersectors* This_in, RTCRayHit& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + This->accels[i]->intersectors.intersect(ray,context); + } + + void AccelN::intersect4 (const void* valid, Accel::Intersectors* This_in, RTCRayHit4& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + This->accels[i]->intersectors.intersect4(valid,ray,context); + } + + void AccelN::intersect8 (const void* valid, Accel::Intersectors* This_in, RTCRayHit8& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + This->accels[i]->intersectors.intersect8(valid,ray,context); + } + + void AccelN::intersect16 (const void* valid, Accel::Intersectors* This_in, RTCRayHit16& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + This->accels[i]->intersectors.intersect16(valid,ray,context); + } + + void AccelN::intersectN (Accel::Intersectors* This_in, RTCRayHitN** ray, const size_t N, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + This->accels[i]->intersectors.intersectN(ray,N,context); + } + + void AccelN::occluded (Accel::Intersectors* This_in, RTCRay& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) { + if (This->accels[i]->isEmpty()) continue; + This->accels[i]->intersectors.occluded(ray,context); + if (ray.tfar < 0.0f) break; + } + } + + void AccelN::occluded4 (const void* valid, Accel::Intersectors* This_in, RTCRay4& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) { + if (This->accels[i]->isEmpty()) continue; + This->accels[i]->intersectors.occluded4(valid,ray,context); +#if defined(__SSE2__) + vbool4 valid0 = asBool(((vint4*)valid)[0]); + vbool4 hit0 = ((vfloat4*)ray.tfar)[0] >= vfloat4(zero); + if (unlikely(none(valid0 & hit0))) break; +#endif + } + } + + 
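/* Note on the early-exit tests in the occluded*() variants here: Embree's
 * occlusion queries mark a hit by setting the ray's tfar to a negative value,
 * so a lane is still unoccluded while tfar >= 0. The scalar loop above stops
 * as soon as ray.tfar < 0.0f; the packet variants rebuild the mask of
 * (valid && tfar >= 0) lanes after each child structure and break once no
 * active lane remains, so the remaining acceleration structures are skipped
 * for fully occluded packets. */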
void AccelN::occluded8 (const void* valid, Accel::Intersectors* This_in, RTCRay8& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) { + if (This->accels[i]->isEmpty()) continue; + This->accels[i]->intersectors.occluded8(valid,ray,context); +#if defined(__SSE2__) // FIXME: use higher ISA + vbool4 valid0 = asBool(((vint4*)valid)[0]); + vbool4 hit0 = ((vfloat4*)ray.tfar)[0] >= vfloat4(zero); + vbool4 valid1 = asBool(((vint4*)valid)[1]); + vbool4 hit1 = ((vfloat4*)ray.tfar)[1] >= vfloat4(zero); + if (unlikely((none((valid0 & hit0) | (valid1 & hit1))))) break; +#endif + } + } + + void AccelN::occluded16 (const void* valid, Accel::Intersectors* This_in, RTCRay16& ray, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + for (size_t i=0; iaccels.size(); i++) { + if (This->accels[i]->isEmpty()) continue; + This->accels[i]->intersectors.occluded16(valid,ray,context); +#if defined(__SSE2__) // FIXME: use higher ISA + vbool4 valid0 = asBool(((vint4*)valid)[0]); + vbool4 hit0 = ((vfloat4*)ray.tfar)[0] >= vfloat4(zero); + vbool4 valid1 = asBool(((vint4*)valid)[1]); + vbool4 hit1 = ((vfloat4*)ray.tfar)[1] >= vfloat4(zero); + vbool4 valid2 = asBool(((vint4*)valid)[2]); + vbool4 hit2 = ((vfloat4*)ray.tfar)[2] >= vfloat4(zero); + vbool4 valid3 = asBool(((vint4*)valid)[3]); + vbool4 hit3 = ((vfloat4*)ray.tfar)[3] >= vfloat4(zero); + if (unlikely((none((valid0 & hit0) | (valid1 & hit1) | (valid2 & hit2) | (valid3 & hit3))))) break; +#endif + } + } + + void AccelN::occludedN (Accel::Intersectors* This_in, RTCRayN** ray, const size_t N, IntersectContext* context) + { + AccelN* This = (AccelN*)This_in->ptr; + size_t M = N; + for (size_t i=0; iaccels.size(); i++) + if (!This->accels[i]->isEmpty()) + This->accels[i]->intersectors.occludedN(ray,M,context); + } + + void AccelN::accels_print(size_t ident) + { + for (size_t i=0; iintersectors.print(ident+2); + } + } + + void AccelN::accels_immutable() + { + for (size_t i=0; iimmutable(); + } + + void AccelN::accels_build () + { + /* reduce memory consumption */ + accels.shrink_to_fit(); + + /* build all acceleration structures in parallel */ + parallel_for (accels.size(), [&] (size_t i) { + accels[i]->build(); + }); + + /* create list of non-empty acceleration structures */ + bool valid1 = true; + bool valid4 = true; + bool valid8 = true; + bool valid16 = true; + for (size_t i=0; iintersectors.intersector1; + valid4 &= (bool) accels[i]->intersectors.intersector4; + valid8 &= (bool) accels[i]->intersectors.intersector8; + valid16 &= (bool) accels[i]->intersectors.intersector16; + } + + if (accels.size() == 1) { + type = accels[0]->type; // FIXME: should just assign entire Accel + bounds = accels[0]->bounds; + intersectors = accels[0]->intersectors; + } + else + { + type = AccelData::TY_ACCELN; + intersectors.ptr = this; + intersectors.intersector1 = Intersector1(&intersect,&occluded,&pointQuery,valid1 ? "AccelN::intersector1": nullptr); + intersectors.intersector4 = Intersector4(&intersect4,&occluded4,valid4 ? "AccelN::intersector4" : nullptr); + intersectors.intersector8 = Intersector8(&intersect8,&occluded8,valid8 ? "AccelN::intersector8" : nullptr); + intersectors.intersector16 = Intersector16(&intersect16,&occluded16,valid16 ? "AccelN::intersector16": nullptr); + intersectors.intersectorN = IntersectorN(&intersectN,&occludedN,"AccelN::intersectorN"); + + /*! 
calculate bounds */ + bounds = empty; + for (size_t i=0; ibounds); + } + } + + void AccelN::accels_select(bool filter) + { + for (size_t i=0; iintersectors.select(filter); + } + + void AccelN::accels_deleteGeometry(size_t geomID) + { + for (size_t i=0; ideleteGeometry(geomID); + } + + void AccelN::accels_clear() + { + for (size_t i=0; iclear(); + } + } +} + diff --git a/thirdparty/embree/kernels/common/acceln.h b/thirdparty/embree/kernels/common/acceln.h new file mode 100644 index 000000000000..2edd98f647dd --- /dev/null +++ b/thirdparty/embree/kernels/common/acceln.h @@ -0,0 +1,49 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "accel.h" + +namespace embree +{ + /*! merges N acceleration structures together, by processing them in order */ + class AccelN : public Accel + { + public: + AccelN (); + ~AccelN(); + + public: + void accels_add(Accel* accel); + void accels_init(); + + public: + static bool pointQuery (Accel::Intersectors* This, PointQuery* query, PointQueryContext* context); + + public: + static void intersect (Accel::Intersectors* This, RTCRayHit& ray, IntersectContext* context); + static void intersect4 (const void* valid, Accel::Intersectors* This, RTCRayHit4& ray, IntersectContext* context); + static void intersect8 (const void* valid, Accel::Intersectors* This, RTCRayHit8& ray, IntersectContext* context); + static void intersect16 (const void* valid, Accel::Intersectors* This, RTCRayHit16& ray, IntersectContext* context); + static void intersectN (Accel::Intersectors* This, RTCRayHitN** ray, const size_t N, IntersectContext* context); + + public: + static void occluded (Accel::Intersectors* This, RTCRay& ray, IntersectContext* context); + static void occluded4 (const void* valid, Accel::Intersectors* This, RTCRay4& ray, IntersectContext* context); + static void occluded8 (const void* valid, Accel::Intersectors* This, RTCRay8& ray, IntersectContext* context); + static void occluded16 (const void* valid, Accel::Intersectors* This, RTCRay16& ray, IntersectContext* context); + static void occludedN (Accel::Intersectors* This, RTCRayN** ray, const size_t N, IntersectContext* context); + + public: + void accels_print(size_t ident); + void accels_immutable(); + void accels_build (); + void accels_select(bool filter); + void accels_deleteGeometry(size_t geomID); + void accels_clear (); + + public: + std::vector accels; + }; +} diff --git a/thirdparty/embree/kernels/common/accelset.cpp b/thirdparty/embree/kernels/common/accelset.cpp new file mode 100644 index 000000000000..79be1c4301bb --- /dev/null +++ b/thirdparty/embree/kernels/common/accelset.cpp @@ -0,0 +1,17 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "accelset.h" +#include "scene.h" + +namespace embree +{ + AccelSet::AccelSet (Device* device, Geometry::GType gtype, size_t numItems, size_t numTimeSteps) + : Geometry(device,gtype,(unsigned int)numItems,(unsigned int)numTimeSteps), boundsFunc(nullptr) {} + + AccelSet::IntersectorN::IntersectorN (ErrorFunc error) + : intersect((IntersectFuncN)error), occluded((OccludedFuncN)error), name(nullptr) {} + + AccelSet::IntersectorN::IntersectorN (IntersectFuncN intersect, OccludedFuncN occluded, const char* name) + : intersect(intersect), occluded(occluded), name(name) {} +} diff --git a/thirdparty/embree/kernels/common/accelset.h b/thirdparty/embree/kernels/common/accelset.h new file mode 100644 index 000000000000..3774b2accb08 --- /dev/null +++ 
b/thirdparty/embree/kernels/common/accelset.h @@ -0,0 +1,248 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "builder.h" +#include "geometry.h" +#include "ray.h" +#include "hit.h" + +namespace embree +{ + struct IntersectFunctionNArguments; + struct OccludedFunctionNArguments; + + typedef void (*ReportIntersectionFunc) (IntersectFunctionNArguments* args, const RTCFilterFunctionNArguments* filter_args); + typedef void (*ReportOcclusionFunc) (OccludedFunctionNArguments* args, const RTCFilterFunctionNArguments* filter_args); + + struct IntersectFunctionNArguments : public RTCIntersectFunctionNArguments + { + IntersectContext* internal_context; + Geometry* geometry; + ReportIntersectionFunc report; + }; + + struct OccludedFunctionNArguments : public RTCOccludedFunctionNArguments + { + IntersectContext* internal_context; + Geometry* geometry; + ReportOcclusionFunc report; + }; + + /*! Base class for set of acceleration structures. */ + class AccelSet : public Geometry + { + public: + typedef RTCIntersectFunctionN IntersectFuncN; + typedef RTCOccludedFunctionN OccludedFuncN; + typedef void (*ErrorFunc) (); + + struct IntersectorN + { + IntersectorN (ErrorFunc error = nullptr) ; + IntersectorN (IntersectFuncN intersect, OccludedFuncN occluded, const char* name); + + operator bool() const { return name; } + + public: + static const char* type; + IntersectFuncN intersect; + OccludedFuncN occluded; + const char* name; + }; + + public: + + /*! construction */ + AccelSet (Device* device, Geometry::GType gtype, size_t items, size_t numTimeSteps); + + /*! makes the acceleration structure immutable */ + virtual void immutable () {} + + /*! build accel */ + virtual void build () = 0; + + /*! check if the i'th primitive is valid between the specified time range */ + __forceinline bool valid(size_t i, const range& itime_range) const + { + for (size_t itime = itime_range.begin(); itime <= itime_range.end(); itime++) + if (!isvalid_non_empty(bounds(i,itime))) return false; + + return true; + } + + /*! Calculates the bounds of an item */ + __forceinline BBox3fa bounds(size_t i, size_t itime = 0) const + { + BBox3fa box; + assert(i < size()); + RTCBoundsFunctionArguments args; + args.geometryUserPtr = userPtr; + args.primID = (unsigned int)i; + args.timeStep = (unsigned int)itime; + args.bounds_o = (RTCBounds*)&box; + boundsFunc(&args); + return box; + } + + /*! calculates the linear bounds of the i'th item at the itime'th time segment */ + __forceinline LBBox3fa linearBounds(size_t i, size_t itime) const + { + BBox3fa box[2]; + assert(i < size()); + RTCBoundsFunctionArguments args; + args.geometryUserPtr = userPtr; + args.primID = (unsigned int)i; + args.timeStep = (unsigned int)(itime+0); + args.bounds_o = (RTCBounds*)&box[0]; + boundsFunc(&args); + args.timeStep = (unsigned int)(itime+1); + args.bounds_o = (RTCBounds*)&box[1]; + boundsFunc(&args); + return LBBox3fa(box[0],box[1]); + } + + /*! calculates the build bounds of the i'th item, if it's valid */ + __forceinline bool buildBounds(size_t i, BBox3fa* bbox = nullptr) const + { + const BBox3fa b = bounds(i); + if (bbox) *bbox = b; + return isvalid_non_empty(b); + } + + /*! 
calculates the build bounds of the i'th item at the itime'th time segment, if it's valid */ + __forceinline bool buildBounds(size_t i, size_t itime, BBox3fa& bbox) const + { + const LBBox3fa bounds = linearBounds(i,itime); + bbox = bounds.bounds0; // use bounding box of first timestep to build BVH + return isvalid_non_empty(bounds); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(size_t primID, const BBox1f& dt) const { + return LBBox3fa([&] (size_t itime) { return bounds(primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline bool linearBounds(size_t i, const BBox1f& time_range, LBBox3fa& bbox) const { + if (!valid(i, timeSegmentRange(time_range))) return false; + bbox = linearBounds(i, time_range); + return true; + } + + /* gets version info of topology */ + unsigned int getTopologyVersion() const { + return numPrimitives; + } + + /* returns true if topology changed */ + bool topologyChanged(unsigned int otherVersion) const { + return numPrimitives != otherVersion; + } + + public: + + /*! Intersects a single ray with the scene. */ + __forceinline void intersect (RayHit& ray, unsigned int geomID, unsigned int primID, IntersectContext* context, ReportIntersectionFunc report) + { + assert(primID < size()); + assert(intersectorN.intersect); + + int mask = -1; + IntersectFunctionNArguments args; + args.valid = &mask; + args.geometryUserPtr = userPtr; + args.context = context->user; + args.rayhit = (RTCRayHitN*)&ray; + args.N = 1; + args.geomID = geomID; + args.primID = primID; + args.internal_context = context; + args.geometry = this; + args.report = report; + + intersectorN.intersect(&args); + } + + /*! Tests if single ray is occluded by the scene. */ + __forceinline void occluded (Ray& ray, unsigned int geomID, unsigned int primID, IntersectContext* context, ReportOcclusionFunc report) + { + assert(primID < size()); + assert(intersectorN.occluded); + + int mask = -1; + OccludedFunctionNArguments args; + args.valid = &mask; + args.geometryUserPtr = userPtr; + args.context = context->user; + args.ray = (RTCRayN*)&ray; + args.N = 1; + args.geomID = geomID; + args.primID = primID; + args.internal_context = context; + args.geometry = this; + args.report = report; + + intersectorN.occluded(&args); + } + + /*! Intersects a packet of K rays with the scene. */ + template + __forceinline void intersect (const vbool& valid, RayHitK& ray, unsigned int geomID, unsigned int primID, IntersectContext* context, ReportIntersectionFunc report) + { + assert(primID < size()); + assert(intersectorN.intersect); + + vint mask = valid.mask32(); + IntersectFunctionNArguments args; + args.valid = (int*)&mask; + args.geometryUserPtr = userPtr; + args.context = context->user; + args.rayhit = (RTCRayHitN*)&ray; + args.N = K; + args.geomID = geomID; + args.primID = primID; + args.internal_context = context; + args.geometry = this; + args.report = report; + + intersectorN.intersect(&args); + } + + /*! Tests if a packet of K rays is occluded by the scene. 
*/ + template + __forceinline void occluded (const vbool& valid, RayK& ray, unsigned int geomID, unsigned int primID, IntersectContext* context, ReportOcclusionFunc report) + { + assert(primID < size()); + assert(intersectorN.occluded); + + vint mask = valid.mask32(); + OccludedFunctionNArguments args; + args.valid = (int*)&mask; + args.geometryUserPtr = userPtr; + args.context = context->user; + args.ray = (RTCRayN*)&ray; + args.N = K; + args.geomID = geomID; + args.primID = primID; + args.internal_context = context; + args.geometry = this; + args.report = report; + + intersectorN.occluded(&args); + } + + public: + RTCBoundsFunction boundsFunc; + IntersectorN intersectorN; + }; + +#define DEFINE_SET_INTERSECTORN(symbol,intersector) \ + AccelSet::IntersectorN symbol() { \ + return AccelSet::IntersectorN(intersector::intersect, \ + intersector::occluded, \ + TOSTRING(isa) "::" TOSTRING(symbol)); \ + } +} diff --git a/thirdparty/embree/kernels/common/alloc.cpp b/thirdparty/embree/kernels/common/alloc.cpp new file mode 100644 index 000000000000..f958a16f5653 --- /dev/null +++ b/thirdparty/embree/kernels/common/alloc.cpp @@ -0,0 +1,79 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "alloc.h" +#include "../../common/sys/thread.h" + +namespace embree +{ + __thread FastAllocator::ThreadLocal2* FastAllocator::thread_local_allocator2 = nullptr; + SpinLock FastAllocator::s_thread_local_allocators_lock; + std::vector> FastAllocator::s_thread_local_allocators; + + struct fast_allocator_regression_test : public RegressionTest + { + BarrierSys barrier; + std::atomic numFailed; + std::unique_ptr alloc; + + fast_allocator_regression_test() + : RegressionTest("fast_allocator_regression_test"), numFailed(0) + { + registerRegressionTest(this); + } + + static void thread_alloc(fast_allocator_regression_test* This) + { + FastAllocator::CachedAllocator threadalloc = This->alloc->getCachedAllocator(); + + size_t* ptrs[1000]; + for (size_t j=0; j<1000; j++) + { + This->barrier.wait(); + for (size_t i=0; i<1000; i++) { + ptrs[i] = (size_t*) threadalloc.malloc0(sizeof(size_t)+(i%32)); + *ptrs[i] = size_t(threadalloc.talloc0) + i; + } + for (size_t i=0; i<1000; i++) { + if (*ptrs[i] != size_t(threadalloc.talloc0) + i) + This->numFailed++; + } + This->barrier.wait(); + } + } + + bool run () + { + alloc = make_unique(new FastAllocator(nullptr,false)); + numFailed.store(0); + + size_t numThreads = getNumberOfLogicalThreads(); + barrier.init(numThreads+1); + + /* create threads */ + std::vector threads; + for (size_t i=0; ireset(); + barrier.wait(); + barrier.wait(); + } + + /* destroy threads */ + for (size_t i=0; idefaultBlockSize; + } + + /* Allocate aligned memory from the threads memory block. 
*/ + __forceinline void* malloc(FastAllocator* alloc, size_t bytes, size_t align = 16) + { + /* bind the thread local allocator to the proper FastAllocator*/ + parent->bind(alloc); + + assert(align <= maxAlignment); + bytesUsed += bytes; + + /* try to allocate in local block */ + size_t ofs = (align - cur) & (align-1); + cur += bytes + ofs; + if (likely(cur <= end)) { bytesWasted += ofs; return &ptr[cur - bytes]; } + cur -= bytes + ofs; + + /* if allocation is too large allocate with parent allocator */ + if (4*bytes > allocBlockSize) { + return alloc->malloc(bytes,maxAlignment,false); + } + + /* get new partial block if allocation failed */ + size_t blockSize = allocBlockSize; + ptr = (char*) alloc->malloc(blockSize,maxAlignment,true); + bytesWasted += end-cur; + cur = 0; end = blockSize; + + /* retry allocation */ + ofs = (align - cur) & (align-1); + cur += bytes + ofs; + if (likely(cur <= end)) { bytesWasted += ofs; return &ptr[cur - bytes]; } + cur -= bytes + ofs; + + /* get new full block if allocation failed */ + blockSize = allocBlockSize; + ptr = (char*) alloc->malloc(blockSize,maxAlignment,false); + bytesWasted += end-cur; + cur = 0; end = blockSize; + + /* retry allocation */ + ofs = (align - cur) & (align-1); + cur += bytes + ofs; + if (likely(cur <= end)) { bytesWasted += ofs; return &ptr[cur - bytes]; } + cur -= bytes + ofs; + + /* should never happen as large allocations get handled specially above */ + assert(false); + return nullptr; + } + + + /*! returns amount of used bytes */ + __forceinline size_t getUsedBytes() const { return bytesUsed; } + + /*! returns amount of free bytes */ + __forceinline size_t getFreeBytes() const { return end-cur; } + + /*! returns amount of wasted bytes */ + __forceinline size_t getWastedBytes() const { return bytesWasted; } + + private: + ThreadLocal2* parent; + char* ptr; //!< pointer to memory block + size_t cur; //!< current location of the allocator + size_t end; //!< end of the memory block + size_t allocBlockSize; //!< block size for allocations + size_t bytesUsed; //!< number of total bytes allocated + size_t bytesWasted; //!< number of bytes wasted + }; + + /*! Two thread local structures. */ + struct __aligned(64) ThreadLocal2 + { + ALIGNED_CLASS_(64); + public: + + __forceinline ThreadLocal2() + : alloc(nullptr), alloc0(this), alloc1(this) {} + + /*! bind to fast allocator */ + __forceinline void bind(FastAllocator* alloc_i) + { + assert(alloc_i); + if (alloc.load() == alloc_i) return; + Lock lock(mutex); + //if (alloc.load() == alloc_i) return; // not required as only one thread calls bind + if (alloc.load()) { + alloc.load()->bytesUsed += alloc0.getUsedBytes() + alloc1.getUsedBytes(); + alloc.load()->bytesFree += alloc0.getFreeBytes() + alloc1.getFreeBytes(); + alloc.load()->bytesWasted += alloc0.getWastedBytes() + alloc1.getWastedBytes(); + } + alloc0.init(alloc_i); + alloc1.init(alloc_i); + alloc.store(alloc_i); + alloc_i->join(this); + } + + /*! 
unbind to fast allocator */ + void unbind(FastAllocator* alloc_i) + { + assert(alloc_i); + if (alloc.load() != alloc_i) return; + Lock lock(mutex); + if (alloc.load() != alloc_i) return; // required as a different thread calls unbind + alloc.load()->bytesUsed += alloc0.getUsedBytes() + alloc1.getUsedBytes(); + alloc.load()->bytesFree += alloc0.getFreeBytes() + alloc1.getFreeBytes(); + alloc.load()->bytesWasted += alloc0.getWastedBytes() + alloc1.getWastedBytes(); + alloc0.init(nullptr); + alloc1.init(nullptr); + alloc.store(nullptr); + } + + public: + SpinLock mutex; //!< required as unbind is called from other threads + std::atomic alloc; //!< parent allocator + ThreadLocal alloc0; + ThreadLocal alloc1; + }; + + FastAllocator (Device* device, bool osAllocation) + : device(device), slotMask(0), usedBlocks(nullptr), freeBlocks(nullptr), use_single_mode(false), defaultBlockSize(PAGE_SIZE), estimatedSize(0), + growSize(PAGE_SIZE), maxGrowSize(maxAllocationSize), log2_grow_size_scale(0), bytesUsed(0), bytesFree(0), bytesWasted(0), atype(osAllocation ? OS_MALLOC : ALIGNED_MALLOC), + primrefarray(device,0) + { + for (size_t i=0; i& primrefarray_i) { + primrefarray = std::move(primrefarray_i); + } + + void unshare(mvector& primrefarray_o) + { + reset(); // this removes blocks that are allocated inside the shared primref array + primrefarray_o = std::move(primrefarray); + } + + /*! returns first fast thread local allocator */ + __forceinline ThreadLocal* _threadLocal() { + return &threadLocal2()->alloc0; + } + + void setOSallocation(bool flag) + { + atype = flag ? OS_MALLOC : ALIGNED_MALLOC; + } + + private: + + /*! returns both fast thread local allocators */ + __forceinline ThreadLocal2* threadLocal2() + { + ThreadLocal2* alloc = thread_local_allocator2; + if (alloc == nullptr) { + thread_local_allocator2 = alloc = new ThreadLocal2; + Lock lock(s_thread_local_allocators_lock); + s_thread_local_allocators.push_back(make_unique(alloc)); + } + return alloc; + } + + public: + + __forceinline void join(ThreadLocal2* alloc) + { + Lock lock(thread_local_allocators_lock); + thread_local_allocators.push_back(alloc); + } + + public: + + struct CachedAllocator + { + __forceinline CachedAllocator(void* ptr) + : alloc(nullptr), talloc0(nullptr), talloc1(nullptr) + { + assert(ptr == nullptr); + } + + __forceinline CachedAllocator(FastAllocator* alloc, ThreadLocal2* talloc) + : alloc(alloc), talloc0(&talloc->alloc0), talloc1(alloc->use_single_mode ? &talloc->alloc0 : &talloc->alloc1) {} + + __forceinline operator bool () const { + return alloc != nullptr; + } + + __forceinline void* operator() (size_t bytes, size_t align = 16) const { + return talloc0->malloc(alloc,bytes,align); + } + + __forceinline void* malloc0 (size_t bytes, size_t align = 16) const { + return talloc0->malloc(alloc,bytes,align); + } + + __forceinline void* malloc1 (size_t bytes, size_t align = 16) const { + return talloc1->malloc(alloc,bytes,align); + } + + public: + FastAllocator* alloc; + ThreadLocal* talloc0; + ThreadLocal* talloc1; + }; + + __forceinline CachedAllocator getCachedAllocator() { + return CachedAllocator(this,threadLocal2()); + } + + /*! 
Builder interface to create thread local allocator */ + struct Create + { + public: + __forceinline Create (FastAllocator* allocator) : allocator(allocator) {} + __forceinline CachedAllocator operator() () const { return allocator->getCachedAllocator(); } + + private: + FastAllocator* allocator; + }; + + void internal_fix_used_blocks() + { + /* move thread local blocks to global block list */ + for (size_t i = 0; i < MAX_THREAD_USED_BLOCK_SLOTS; i++) + { + while (threadBlocks[i].load() != nullptr) { + Block* nextUsedBlock = threadBlocks[i].load()->next; + threadBlocks[i].load()->next = usedBlocks.load(); + usedBlocks = threadBlocks[i].load(); + threadBlocks[i] = nextUsedBlock; + } + threadBlocks[i] = nullptr; + } + } + + static const size_t threadLocalAllocOverhead = 20; //! 20 means 5% parallel allocation overhead through unfilled thread local blocks +#if defined(__AVX512ER__) // KNL + static const size_t mainAllocOverheadStatic = 15; //! 15 means 7.5% allocation overhead through unfilled main alloc blocks +#else + static const size_t mainAllocOverheadStatic = 20; //! 20 means 5% allocation overhead through unfilled main alloc blocks +#endif + static const size_t mainAllocOverheadDynamic = 8; //! 20 means 12.5% allocation overhead through unfilled main alloc blocks + + /* calculates a single threaded threshold for the builders such + * that for small scenes the overhead of partly allocated blocks + * per thread is low */ + size_t fixSingleThreadThreshold(size_t branchingFactor, size_t defaultThreshold, size_t numPrimitives, size_t bytesEstimated) + { + if (numPrimitives == 0 || bytesEstimated == 0) + return defaultThreshold; + + /* calculate block size in bytes to fulfill threadLocalAllocOverhead constraint */ + const size_t single_mode_factor = use_single_mode ? 1 : 2; + const size_t threadCount = TaskScheduler::threadCount(); + const size_t singleThreadBytes = single_mode_factor*threadLocalAllocOverhead*defaultBlockSize; + + /* if we do not have to limit number of threads use optimal thresdhold */ + if ( (bytesEstimated+(singleThreadBytes-1))/singleThreadBytes >= threadCount) + return defaultThreshold; + + /* otherwise limit number of threads by calculating proper single thread threshold */ + else { + double bytesPerPrimitive = double(bytesEstimated)/double(numPrimitives); + return size_t(ceil(branchingFactor*singleThreadBytes/bytesPerPrimitive)); + } + } + + __forceinline size_t alignSize(size_t i) { + return (i+127)/128*128; + } + + /*! initializes the grow size */ + __forceinline void initGrowSizeAndNumSlots(size_t bytesEstimated, bool fast) + { + /* we do not need single thread local allocator mode */ + use_single_mode = false; + + /* calculate growSize such that at most mainAllocationOverhead gets wasted when a block stays unused */ + size_t mainAllocOverhead = fast ? 
mainAllocOverheadDynamic : mainAllocOverheadStatic; + size_t blockSize = alignSize(bytesEstimated/mainAllocOverhead); + growSize = maxGrowSize = clamp(blockSize,size_t(1024),maxAllocationSize); + + /* if we reached the maxAllocationSize for growSize, we can + * increase the number of allocation slots by still guaranteeing + * the mainAllocationOverhead */ + slotMask = 0x0; + + if (MAX_THREAD_USED_BLOCK_SLOTS >= 2 && bytesEstimated > 2*mainAllocOverhead*growSize) slotMask = 0x1; + if (MAX_THREAD_USED_BLOCK_SLOTS >= 4 && bytesEstimated > 4*mainAllocOverhead*growSize) slotMask = 0x3; + if (MAX_THREAD_USED_BLOCK_SLOTS >= 8 && bytesEstimated > 8*mainAllocOverhead*growSize) slotMask = 0x7; + if (MAX_THREAD_USED_BLOCK_SLOTS >= 8 && bytesEstimated > 16*mainAllocOverhead*growSize) { growSize *= 2; } /* if the overhead is tiny, double the growSize */ + + /* set the thread local alloc block size */ + size_t defaultBlockSizeSwitch = PAGE_SIZE+maxAlignment; + + /* for sufficiently large scene we can increase the defaultBlockSize over the defaultBlockSizeSwitch size */ +#if 0 // we do not do this as a block size of 4160 if for some reason best for KNL + const size_t threadCount = TaskScheduler::threadCount(); + const size_t single_mode_factor = use_single_mode ? 1 : 2; + const size_t singleThreadBytes = single_mode_factor*threadLocalAllocOverhead*defaultBlockSizeSwitch; + if (bytesEstimated+(singleThreadBytes-1))/singleThreadBytes >= threadCount) + defaultBlockSize = min(max(defaultBlockSizeSwitch,bytesEstimated/(single_mode_factor*threadLocalAllocOverhead*threadCount)),growSize); + + /* otherwise we grow the defaultBlockSize up to defaultBlockSizeSwitch */ + else +#endif + defaultBlockSize = clamp(blockSize,size_t(1024),defaultBlockSizeSwitch); + + if (bytesEstimated == 0) { + maxGrowSize = maxAllocationSize; // special mode if builder cannot estimate tree size + defaultBlockSize = defaultBlockSizeSwitch; + } + log2_grow_size_scale = 0; + + if (device->alloc_main_block_size != 0) growSize = device->alloc_main_block_size; + if (device->alloc_num_main_slots >= 1 ) slotMask = 0x0; + if (device->alloc_num_main_slots >= 2 ) slotMask = 0x1; + if (device->alloc_num_main_slots >= 4 ) slotMask = 0x3; + if (device->alloc_num_main_slots >= 8 ) slotMask = 0x7; + if (device->alloc_thread_block_size != 0) defaultBlockSize = device->alloc_thread_block_size; + if (device->alloc_single_thread_alloc != -1) use_single_mode = device->alloc_single_thread_alloc; + } + + /*! initializes the allocator */ + void init(size_t bytesAllocate, size_t bytesReserve, size_t bytesEstimate) + { + internal_fix_used_blocks(); + /* distribute the allocation to multiple thread block slots */ + slotMask = MAX_THREAD_USED_BLOCK_SLOTS-1; // FIXME: remove + if (usedBlocks.load() || freeBlocks.load()) { reset(); return; } + if (bytesReserve == 0) bytesReserve = bytesAllocate; + freeBlocks = Block::create(device,bytesAllocate,bytesReserve,nullptr,atype); + estimatedSize = bytesEstimate; + initGrowSizeAndNumSlots(bytesEstimate,true); + } + + /*! initializes the allocator */ + void init_estimate(size_t bytesEstimate) + { + internal_fix_used_blocks(); + if (usedBlocks.load() || freeBlocks.load()) { reset(); return; } + /* single allocator mode ? */ + estimatedSize = bytesEstimate; + //initGrowSizeAndNumSlots(bytesEstimate,false); + initGrowSizeAndNumSlots(bytesEstimate,false); + + } + + /*! 
frees state not required after build */ + __forceinline void cleanup() + { + internal_fix_used_blocks(); + + /* unbind all thread local allocators */ + for (auto alloc : thread_local_allocators) alloc->unbind(this); + thread_local_allocators.clear(); + } + + /*! resets the allocator, memory blocks get reused */ + void reset () + { + internal_fix_used_blocks(); + + bytesUsed.store(0); + bytesFree.store(0); + bytesWasted.store(0); + + /* reset all used blocks and move them to begin of free block list */ + while (usedBlocks.load() != nullptr) { + usedBlocks.load()->reset_block(); + Block* nextUsedBlock = usedBlocks.load()->next; + usedBlocks.load()->next = freeBlocks.load(); + freeBlocks = usedBlocks.load(); + usedBlocks = nextUsedBlock; + } + + /* remove all shared blocks as they are re-added during build */ + freeBlocks.store(Block::remove_shared_blocks(freeBlocks.load())); + + for (size_t i=0; iunbind(this); + thread_local_allocators.clear(); + } + + /*! frees all allocated memory */ + __forceinline void clear() + { + cleanup(); + bytesUsed.store(0); + bytesFree.store(0); + bytesWasted.store(0); + if (usedBlocks.load() != nullptr) usedBlocks.load()->clear_list(device); usedBlocks = nullptr; + if (freeBlocks.load() != nullptr) freeBlocks.load()->clear_list(device); freeBlocks = nullptr; + for (size_t i=0; imalloc(device,bytes,align,partial); + if (ptr) return ptr; + } + + /* throw error if allocation is too large */ + if (bytes > maxAllocationSize) + throw_RTCError(RTC_ERROR_UNKNOWN,"allocation is too large"); + + /* parallel block creation in case of no freeBlocks, avoids single global mutex */ + if (likely(freeBlocks.load() == nullptr)) + { + Lock lock(slotMutex[slot]); + if (myUsedBlocks == threadUsedBlocks[slot]) { + const size_t alignedBytes = (bytes+(align-1)) & ~(align-1); + const size_t allocSize = max(min(growSize,maxGrowSize),alignedBytes); + assert(allocSize >= bytes); + threadBlocks[slot] = threadUsedBlocks[slot] = Block::create(device,allocSize,allocSize,threadBlocks[slot],atype); // FIXME: a large allocation might throw away a block here! + // FIXME: a direct allocation should allocate inside the block here, and not in the next loop! a different thread could do some allocation and make the large allocation fail. + } + continue; + } + + /* if this fails allocate new block */ + { + Lock lock(mutex); + if (myUsedBlocks == threadUsedBlocks[slot]) + { + if (freeBlocks.load() != nullptr) { + Block* nextFreeBlock = freeBlocks.load()->next; + freeBlocks.load()->next = usedBlocks; + __memory_barrier(); + usedBlocks = freeBlocks.load(); + threadUsedBlocks[slot] = freeBlocks.load(); + freeBlocks = nextFreeBlock; + } else { + const size_t allocSize = min(growSize*incGrowSizeScale(),maxGrowSize); + usedBlocks = threadUsedBlocks[slot] = Block::create(device,allocSize,allocSize,usedBlocks,atype); // FIXME: a large allocation should get delivered directly, like above! + } + } + } + } + } + + /*! 
add new block */ + void addBlock(void* ptr, ssize_t bytes) + { + Lock lock(mutex); + const size_t sizeof_Header = offsetof(Block,data[0]); + void* aptr = (void*) ((((size_t)ptr)+maxAlignment-1) & ~(maxAlignment-1)); + size_t ofs = (size_t) aptr - (size_t) ptr; + bytes -= ofs; + if (bytes < 4096) return; // ignore empty or very small blocks + freeBlocks = new (aptr) Block(SHARED,bytes-sizeof_Header,bytes-sizeof_Header,freeBlocks,ofs); + } + + /* special allocation only used from morton builder only a single time for each build */ + void* specialAlloc(size_t bytes) + { + assert(freeBlocks.load() != nullptr && freeBlocks.load()->getBlockAllocatedBytes() >= bytes); + return freeBlocks.load()->ptr(); + } + + struct Statistics + { + Statistics () + : bytesUsed(0), bytesFree(0), bytesWasted(0) {} + + Statistics (size_t bytesUsed, size_t bytesFree, size_t bytesWasted) + : bytesUsed(bytesUsed), bytesFree(bytesFree), bytesWasted(bytesWasted) {} + + Statistics (FastAllocator* alloc, AllocationType atype, bool huge_pages = false) + : bytesUsed(0), bytesFree(0), bytesWasted(0) + { + Block* usedBlocks = alloc->usedBlocks.load(); + Block* freeBlocks = alloc->freeBlocks.load(); + if (usedBlocks) bytesUsed += usedBlocks->getUsedBytes(atype,huge_pages); + if (freeBlocks) bytesFree += freeBlocks->getAllocatedBytes(atype,huge_pages); + if (usedBlocks) bytesFree += usedBlocks->getFreeBytes(atype,huge_pages); + if (freeBlocks) bytesWasted += freeBlocks->getWastedBytes(atype,huge_pages); + if (usedBlocks) bytesWasted += usedBlocks->getWastedBytes(atype,huge_pages); + } + + std::string str(size_t numPrimitives) + { + std::stringstream str; + str.setf(std::ios::fixed, std::ios::floatfield); + str << "used = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesUsed << " MB, " + << "free = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesFree << " MB, " + << "wasted = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesWasted << " MB, " + << "total = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesAllocatedTotal() << " MB, " + << "#bytes/prim = " << std::setw(6) << std::setprecision(2) << double(bytesAllocatedTotal())/double(numPrimitives); + return str.str(); + } + + friend Statistics operator+ ( const Statistics& a, const Statistics& b) + { + return Statistics(a.bytesUsed+b.bytesUsed, + a.bytesFree+b.bytesFree, + a.bytesWasted+b.bytesWasted); + } + + size_t bytesAllocatedTotal() const { + return bytesUsed + bytesFree + bytesWasted; + } + + public: + size_t bytesUsed; + size_t bytesFree; + size_t bytesWasted; + }; + + Statistics getStatistics(AllocationType atype, bool huge_pages = false) { + return Statistics(this,atype,huge_pages); + } + + size_t getUsedBytes() { + return bytesUsed; + } + + size_t getWastedBytes() { + return bytesWasted; + } + + struct AllStatistics + { + AllStatistics (FastAllocator* alloc) + + : bytesUsed(alloc->bytesUsed), + bytesFree(alloc->bytesFree), + bytesWasted(alloc->bytesWasted), + stat_all(alloc,ANY_TYPE), + stat_malloc(alloc,ALIGNED_MALLOC), + stat_4K(alloc,OS_MALLOC,false), + stat_2M(alloc,OS_MALLOC,true), + stat_shared(alloc,SHARED) {} + + AllStatistics (size_t bytesUsed, + size_t bytesFree, + size_t bytesWasted, + Statistics stat_all, + Statistics stat_malloc, + Statistics stat_4K, + Statistics stat_2M, + Statistics stat_shared) + + : bytesUsed(bytesUsed), + bytesFree(bytesFree), + bytesWasted(bytesWasted), + stat_all(stat_all), + stat_malloc(stat_malloc), + stat_4K(stat_4K), + stat_2M(stat_2M), + stat_shared(stat_shared) {} + + friend AllStatistics operator+ 
(const AllStatistics& a, const AllStatistics& b) + { + return AllStatistics(a.bytesUsed+b.bytesUsed, + a.bytesFree+b.bytesFree, + a.bytesWasted+b.bytesWasted, + a.stat_all + b.stat_all, + a.stat_malloc + b.stat_malloc, + a.stat_4K + b.stat_4K, + a.stat_2M + b.stat_2M, + a.stat_shared + b.stat_shared); + } + + void print(size_t numPrimitives) + { + std::stringstream str0; + str0.setf(std::ios::fixed, std::ios::floatfield); + str0 << " alloc : " + << "used = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesUsed << " MB, " + << " " + << "#bytes/prim = " << std::setw(6) << std::setprecision(2) << double(bytesUsed)/double(numPrimitives); + std::cout << str0.str() << std::endl; + + std::stringstream str1; + str1.setf(std::ios::fixed, std::ios::floatfield); + str1 << " alloc : " + << "used = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesUsed << " MB, " + << "free = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesFree << " MB, " + << "wasted = " << std::setw(7) << std::setprecision(3) << 1E-6f*bytesWasted << " MB, " + << "total = " << std::setw(7) << std::setprecision(3) << 1E-6f*(bytesUsed+bytesFree+bytesWasted) << " MB, " + << "#bytes/prim = " << std::setw(6) << std::setprecision(2) << double(bytesUsed+bytesFree+bytesWasted)/double(numPrimitives); + std::cout << str1.str() << std::endl; + + std::cout << " total : " << stat_all.str(numPrimitives) << std::endl; + std::cout << " 4K : " << stat_4K.str(numPrimitives) << std::endl; + std::cout << " 2M : " << stat_2M.str(numPrimitives) << std::endl; + std::cout << " malloc: " << stat_malloc.str(numPrimitives) << std::endl; + std::cout << " shared: " << stat_shared.str(numPrimitives) << std::endl; + } + + private: + size_t bytesUsed; + size_t bytesFree; + size_t bytesWasted; + Statistics stat_all; + Statistics stat_malloc; + Statistics stat_4K; + Statistics stat_2M; + Statistics stat_shared; + }; + + void print_blocks() + { + std::cout << " estimatedSize = " << estimatedSize << ", slotMask = " << slotMask << ", use_single_mode = " << use_single_mode << ", maxGrowSize = " << maxGrowSize << ", defaultBlockSize = " << defaultBlockSize << std::endl; + + std::cout << " used blocks = "; + if (usedBlocks.load() != nullptr) usedBlocks.load()->print_list(); + std::cout << "[END]" << std::endl; + + std::cout << " free blocks = "; + if (freeBlocks.load() != nullptr) freeBlocks.load()->print_list(); + std::cout << "[END]" << std::endl; + } + + private: + + struct Block + { + static Block* create(MemoryMonitorInterface* device, size_t bytesAllocate, size_t bytesReserve, Block* next, AllocationType atype) + { + /* We avoid using os_malloc for small blocks as this could + * cause a risk of fragmenting the virtual address space and + * reach the limit of vm.max_map_count = 65k under Linux. 
*/ + if (atype == OS_MALLOC && bytesAllocate < maxAllocationSize) + atype = ALIGNED_MALLOC; + + /* we need to additionally allocate some header */ + const size_t sizeof_Header = offsetof(Block,data[0]); + bytesAllocate = sizeof_Header+bytesAllocate; + bytesReserve = sizeof_Header+bytesReserve; + + /* consume full 4k pages with using os_malloc */ + if (atype == OS_MALLOC) { + bytesAllocate = ((bytesAllocate+PAGE_SIZE-1) & ~(PAGE_SIZE-1)); + bytesReserve = ((bytesReserve +PAGE_SIZE-1) & ~(PAGE_SIZE-1)); + } + + /* either use alignedMalloc or os_malloc */ + void *ptr = nullptr; + if (atype == ALIGNED_MALLOC) + { + /* special handling for default block size */ + if (bytesAllocate == (2*PAGE_SIZE_2M)) + { + const size_t alignment = maxAlignment; + if (device) device->memoryMonitor(bytesAllocate+alignment,false); + ptr = alignedMalloc(bytesAllocate,alignment); + + /* give hint to transparently convert these pages to 2MB pages */ + const size_t ptr_aligned_begin = ((size_t)ptr) & ~size_t(PAGE_SIZE_2M-1); + os_advise((void*)(ptr_aligned_begin + 0),PAGE_SIZE_2M); // may fail if no memory mapped before block + os_advise((void*)(ptr_aligned_begin + 1*PAGE_SIZE_2M),PAGE_SIZE_2M); + os_advise((void*)(ptr_aligned_begin + 2*PAGE_SIZE_2M),PAGE_SIZE_2M); // may fail if no memory mapped after block + + return new (ptr) Block(ALIGNED_MALLOC,bytesAllocate-sizeof_Header,bytesAllocate-sizeof_Header,next,alignment); + } + else + { + const size_t alignment = maxAlignment; + if (device) device->memoryMonitor(bytesAllocate+alignment,false); + ptr = alignedMalloc(bytesAllocate,alignment); + return new (ptr) Block(ALIGNED_MALLOC,bytesAllocate-sizeof_Header,bytesAllocate-sizeof_Header,next,alignment); + } + } + else if (atype == OS_MALLOC) + { + if (device) device->memoryMonitor(bytesAllocate,false); + bool huge_pages; ptr = os_malloc(bytesReserve,huge_pages); + return new (ptr) Block(OS_MALLOC,bytesAllocate-sizeof_Header,bytesReserve-sizeof_Header,next,0,huge_pages); + } + else + assert(false); + + return NULL; + } + + Block (AllocationType atype, size_t bytesAllocate, size_t bytesReserve, Block* next, size_t wasted, bool huge_pages = false) + : cur(0), allocEnd(bytesAllocate), reserveEnd(bytesReserve), next(next), wasted(wasted), atype(atype), huge_pages(huge_pages) + { + assert((((size_t)&data[0]) & (maxAlignment-1)) == 0); + } + + static Block* remove_shared_blocks(Block* head) + { + Block** prev_next = &head; + for (Block* block = head; block; block = block->next) { + if (block->atype == SHARED) *prev_next = block->next; + else prev_next = &block->next; + } + return head; + } + + void clear_list(MemoryMonitorInterface* device) + { + Block* block = this; + while (block) { + Block* next = block->next; + block->clear_block(device); + block = next; + } + } + + void clear_block (MemoryMonitorInterface* device) + { + const size_t sizeof_Header = offsetof(Block,data[0]); + const ssize_t sizeof_Alloced = wasted+sizeof_Header+getBlockAllocatedBytes(); + + if (atype == ALIGNED_MALLOC) { + alignedFree(this); + if (device) device->memoryMonitor(-sizeof_Alloced,true); + } + + else if (atype == OS_MALLOC) { + size_t sizeof_This = sizeof_Header+reserveEnd; + os_free(this,sizeof_This,huge_pages); + if (device) device->memoryMonitor(-sizeof_Alloced,true); + } + + else /* if (atype == SHARED) */ { + } + } + + void* malloc(MemoryMonitorInterface* device, size_t& bytes_in, size_t align, bool partial) + { + size_t bytes = bytes_in; + assert(align <= maxAlignment); + bytes = (bytes+(align-1)) & ~(align-1); + if (unlikely(cur+bytes > 
reserveEnd && !partial)) return nullptr; + const size_t i = cur.fetch_add(bytes); + if (unlikely(i+bytes > reserveEnd && !partial)) return nullptr; + if (unlikely(i > reserveEnd)) return nullptr; + bytes_in = bytes = min(bytes,reserveEnd-i); + + if (i+bytes > allocEnd) { + if (device) device->memoryMonitor(i+bytes-max(i,allocEnd),true); + } + return &data[i]; + } + + void* ptr() { + return &data[cur]; + } + + void reset_block () + { + allocEnd = max(allocEnd,(size_t)cur); + cur = 0; + } + + size_t getBlockUsedBytes() const { + return min(size_t(cur),reserveEnd); + } + + size_t getBlockFreeBytes() const { + return getBlockAllocatedBytes() - getBlockUsedBytes(); + } + + size_t getBlockAllocatedBytes() const { + return min(max(allocEnd,size_t(cur)),reserveEnd); + } + + size_t getBlockWastedBytes() const { + const size_t sizeof_Header = offsetof(Block,data[0]); + return sizeof_Header + wasted; + } + + size_t getBlockReservedBytes() const { + return reserveEnd; + } + + bool hasType(AllocationType atype_i, bool huge_pages_i) const + { + if (atype_i == ANY_TYPE ) return true; + else if (atype == OS_MALLOC) return atype_i == atype && huge_pages_i == huge_pages; + else return atype_i == atype; + } + + size_t getUsedBytes(AllocationType atype, bool huge_pages = false) const { + size_t bytes = 0; + for (const Block* block = this; block; block = block->next) { + if (!block->hasType(atype,huge_pages)) continue; + bytes += block->getBlockUsedBytes(); + } + return bytes; + } + + size_t getFreeBytes(AllocationType atype, bool huge_pages = false) const { + size_t bytes = 0; + for (const Block* block = this; block; block = block->next) { + if (!block->hasType(atype,huge_pages)) continue; + bytes += block->getBlockFreeBytes(); + } + return bytes; + } + + size_t getWastedBytes(AllocationType atype, bool huge_pages = false) const { + size_t bytes = 0; + for (const Block* block = this; block; block = block->next) { + if (!block->hasType(atype,huge_pages)) continue; + bytes += block->getBlockWastedBytes(); + } + return bytes; + } + + size_t getAllocatedBytes(AllocationType atype, bool huge_pages = false) const { + size_t bytes = 0; + for (const Block* block = this; block; block = block->next) { + if (!block->hasType(atype,huge_pages)) continue; + bytes += block->getBlockAllocatedBytes(); + } + return bytes; + } + + void print_list () + { + for (const Block* block = this; block; block = block->next) + block->print_block(); + } + + void print_block() const + { + if (atype == ALIGNED_MALLOC) std::cout << "A"; + else if (atype == OS_MALLOC) std::cout << "O"; + else if (atype == SHARED) std::cout << "S"; + if (huge_pages) std::cout << "H"; + size_t bytesUsed = getBlockUsedBytes(); + size_t bytesFree = getBlockFreeBytes(); + size_t bytesWasted = getBlockWastedBytes(); + std::cout << "[" << bytesUsed << ", " << bytesFree << ", " << bytesWasted << "] "; + } + + public: + std::atomic cur; //!< current location of the allocator + std::atomic allocEnd; //!< end of the allocated memory region + std::atomic reserveEnd; //!< end of the reserved memory region + Block* next; //!< pointer to next block in list + size_t wasted; //!< amount of memory wasted through block alignment + AllocationType atype; //!< allocation mode of the block + bool huge_pages; //!< whether the block uses huge pages + char align[maxAlignment-5*sizeof(size_t)-sizeof(AllocationType)-sizeof(bool)]; //!< align data to maxAlignment + char data[1]; //!< here starts memory to use for allocations + }; + + private: + Device* device; + SpinLock mutex; + size_t 
slotMask; + std::atomic threadUsedBlocks[MAX_THREAD_USED_BLOCK_SLOTS]; + std::atomic usedBlocks; + std::atomic freeBlocks; + + std::atomic threadBlocks[MAX_THREAD_USED_BLOCK_SLOTS]; + SpinLock slotMutex[MAX_THREAD_USED_BLOCK_SLOTS]; + + bool use_single_mode; + size_t defaultBlockSize; + size_t estimatedSize; + size_t growSize; + size_t maxGrowSize; + std::atomic log2_grow_size_scale; //!< log2 of scaling factor for grow size // FIXME: remove + std::atomic bytesUsed; + std::atomic bytesFree; + std::atomic bytesWasted; + static __thread ThreadLocal2* thread_local_allocator2; + static SpinLock s_thread_local_allocators_lock; + static std::vector> s_thread_local_allocators; + SpinLock thread_local_allocators_lock; + std::vector thread_local_allocators; + AllocationType atype; + mvector primrefarray; //!< primrefarray used to allocate nodes + }; +} diff --git a/thirdparty/embree/kernels/common/buffer.h b/thirdparty/embree/kernels/common/buffer.h new file mode 100644 index 000000000000..02d319c59dd7 --- /dev/null +++ b/thirdparty/embree/kernels/common/buffer.h @@ -0,0 +1,263 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "device.h" + +namespace embree +{ + /*! Implements an API data buffer object. This class may or may not own the data. */ + class Buffer : public RefCount + { + public: + /*! Buffer construction */ + Buffer() + : device(nullptr), ptr(nullptr), numBytes(0), shared(false) {} + + /*! Buffer construction */ + Buffer(Device* device, size_t numBytes_in, void* ptr_in = nullptr) + : device(device), numBytes(numBytes_in) + { + device->refInc(); + + if (ptr_in) + { + shared = true; + ptr = (char*)ptr_in; + } + else + { + shared = false; + alloc(); + } + } + + /*! Buffer destruction */ + ~Buffer() { + free(); + device->refDec(); + } + + /*! this class is not copyable */ + private: + Buffer(const Buffer& other) DELETED; // do not implement + Buffer& operator =(const Buffer& other) DELETED; // do not implement + + public: + /* inits and allocates the buffer */ + void create(Device* device_in, size_t numBytes_in) + { + init(device_in, numBytes_in); + alloc(); + } + + /* inits the buffer */ + void init(Device* device_in, size_t numBytes_in) + { + free(); + device = device_in; + ptr = nullptr; + numBytes = numBytes_in; + shared = false; + } + + /*! sets shared buffer */ + void set(Device* device_in, void* ptr_in, size_t numBytes_in) + { + free(); + device = device_in; + ptr = (char*)ptr_in; + if (numBytes_in != (size_t)-1) + numBytes = numBytes_in; + shared = true; + } + + /*! allocated buffer */ + void alloc() + { + if (device) + device->memoryMonitor(this->bytes(), false); + size_t b = (this->bytes()+15) & ssize_t(-16); + ptr = (char*)alignedMalloc(b,16); + } + + /*! frees the buffer */ + void free() + { + if (shared) return; + alignedFree(ptr); + if (device) + device->memoryMonitor(-ssize_t(this->bytes()), true); + ptr = nullptr; + } + + /*! gets buffer pointer */ + void* data() + { + /* report error if buffer is not existing */ + if (!device) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer specified"); + + /* return buffer */ + return ptr; + } + + /*! returns pointer to first element */ + __forceinline char* getPtr() const { + return ptr; + } + + /*! returns the number of bytes of the buffer */ + __forceinline size_t bytes() const { + return numBytes; + } + + /*! 
returns true of the buffer is not empty */ + __forceinline operator bool() const { + return ptr; + } + + public: + Device* device; //!< device to report memory usage to + char* ptr; //!< pointer to buffer data + size_t numBytes; //!< number of bytes in the buffer + bool shared; //!< set if memory is shared with application + }; + + /*! An untyped contiguous range of a buffer. This class does not own the buffer content. */ + class RawBufferView + { + public: + /*! Buffer construction */ + RawBufferView() + : ptr_ofs(nullptr), stride(0), num(0), format(RTC_FORMAT_UNDEFINED), modCounter(1), modified(true), userData(0) {} + + public: + /*! sets the buffer view */ + void set(const Ref& buffer_in, size_t offset_in, size_t stride_in, size_t num_in, RTCFormat format_in) + { + if ((offset_in + stride_in * num_in) > (stride_in * buffer_in->numBytes)) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "buffer range out of bounds"); + + ptr_ofs = buffer_in->ptr + offset_in; + stride = stride_in; + num = num_in; + format = format_in; + modCounter++; + modified = true; + buffer = buffer_in; + } + + /*! returns pointer to the first element */ + __forceinline char* getPtr() const { + return ptr_ofs; + } + + /*! returns pointer to the i'th element */ + __forceinline char* getPtr(size_t i) const + { + assert(i otherModCounter; + } + + /*! mark buffer as modified or unmodified */ + __forceinline bool isLocalModified() const { + return modified; + } + + /*! clear local modified flag */ + __forceinline void clearLocalModified() { + modified = false; + } + + /*! returns true of the buffer is not empty */ + __forceinline operator bool() const { + return ptr_ofs; + } + + /*! checks padding to 16 byte check, fails hard */ + __forceinline void checkPadding16() const + { + if (ptr_ofs && num) + volatile int MAYBE_UNUSED w = *((int*)getPtr(size()-1)+3); // FIXME: is failing hard avoidable? + } + + public: + char* ptr_ofs; //!< base pointer plus offset + size_t stride; //!< stride of the buffer in bytes + size_t num; //!< number of elements in the buffer + RTCFormat format; //!< format of the buffer + unsigned int modCounter; //!< version ID of this buffer + bool modified; //!< local modified data + int userData; //!< special data + Ref buffer; //!< reference to the parent buffer + }; + + /*! A typed contiguous range of a buffer. This class does not own the buffer content. */ + template + class BufferView : public RawBufferView + { + public: + typedef T value_type; + + /*! access to the ith element of the buffer */ + __forceinline T& operator [](size_t i) { assert(i + class BufferView : public RawBufferView + { + public: + typedef Vec3fa value_type; + + /*! 
access to the ith element of the buffer */ + __forceinline const Vec3fa operator [](size_t i) const + { + assert(i + struct ProgressMonitorClosure : BuildProgressMonitor + { + public: + ProgressMonitorClosure (const Closure& closure) : closure(closure) {} + void operator() (size_t dn) const { closure(dn); } + private: + const Closure closure; + }; + template __forceinline const ProgressMonitorClosure BuildProgressMonitorFromClosure(const Closure& closure) { + return ProgressMonitorClosure(closure); + } + + struct LineSegments; + struct TriangleMesh; + struct QuadMesh; + struct UserGeometry; + + class Scene; + + typedef void (*createLineSegmentsAccelTy)(Scene* scene, LineSegments* mesh, AccelData*& accel, Builder*& builder); + typedef void (*createTriangleMeshAccelTy)(Scene* scene, unsigned int geomID, AccelData*& accel, Builder*& builder); + typedef void (*createQuadMeshAccelTy)(Scene* scene, unsigned int geomID, AccelData*& accel, Builder*& builder); + typedef void (*createUserGeometryAccelTy)(Scene* scene, unsigned int geomID, AccelData*& accel, Builder*& builder); + +} diff --git a/thirdparty/embree/kernels/common/context.h b/thirdparty/embree/kernels/common/context.h new file mode 100644 index 000000000000..d0185a74f2e5 --- /dev/null +++ b/thirdparty/embree/kernels/common/context.h @@ -0,0 +1,131 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "rtcore.h" +#include "point_query.h" + +namespace embree +{ + class Scene; + + struct IntersectContext + { + public: + __forceinline IntersectContext(Scene* scene, RTCIntersectContext* user_context) + : scene(scene), user(user_context) {} + + __forceinline bool hasContextFilter() const { + return user->filter != nullptr; + } + + __forceinline bool isCoherent() const { + return embree::isCoherent(user->flags); + } + + __forceinline bool isIncoherent() const { + return embree::isIncoherent(user->flags); + } + + public: + Scene* scene; + RTCIntersectContext* user; + }; + + template + __forceinline Vec4vf enlargeRadiusToMinWidth(const IntersectContext* context, const Geometry* geom, const Vec3vf& ray_org, const Vec4vf& v) + { +#if RTC_MIN_WIDTH + const vfloat d = length(Vec3vf(v) - ray_org); + const vfloat r = clamp(context->user->minWidthDistanceFactor*d, v.w, geom->maxRadiusScale*v.w); + return Vec4vf(v.x,v.y,v.z,r); +#else + return v; +#endif + } + + template + __forceinline Vec3ff enlargeRadiusToMinWidth(const IntersectContext* context, const Geometry* geom, const Vec3fa& ray_org, const Vec3ff& v) + { +#if RTC_MIN_WIDTH + const float d = length(Vec3fa(v) - ray_org); + const float r = clamp(context->user->minWidthDistanceFactor*d, v.w, geom->maxRadiusScale*v.w); + return Vec3ff(v.x,v.y,v.z,r); +#else + return v; +#endif + } + + enum PointQueryType + { + POINT_QUERY_TYPE_UNDEFINED = 0, + POINT_QUERY_TYPE_SPHERE = 1, + POINT_QUERY_TYPE_AABB = 2, + }; + + typedef bool (*PointQueryFunction)(struct RTCPointQueryFunctionArguments* args); + + struct PointQueryContext + { + public: + __forceinline PointQueryContext(Scene* scene, + PointQuery* query_ws, + PointQueryType query_type, + PointQueryFunction func, + RTCPointQueryContext* userContext, + float similarityScale, + void* userPtr) + : scene(scene) + , query_ws(query_ws) + , query_type(query_type) + , func(func) + , userContext(userContext) + , similarityScale(similarityScale) + , userPtr(userPtr) + , primID(RTC_INVALID_GEOMETRY_ID) + , geomID(RTC_INVALID_GEOMETRY_ID) + , query_radius(query_ws->radius) + { + if 
(query_type == POINT_QUERY_TYPE_AABB) { + assert(similarityScale == 0.f); + updateAABB(); + } + if (userContext->instStackSize == 0) { + assert(similarityScale == 1.f); + } + } + + public: + __forceinline void updateAABB() + { + if (likely(query_ws->radius == (float)inf || userContext->instStackSize == 0)) { + query_radius = Vec3fa(query_ws->radius); + return; + } + + const AffineSpace3fa m = AffineSpace3fa_load_unaligned((AffineSpace3fa*)userContext->world2inst[userContext->instStackSize-1]); + BBox3fa bbox(Vec3fa(-query_ws->radius), Vec3fa(query_ws->radius)); + bbox = xfmBounds(m, bbox); + query_radius = 0.5f * (bbox.upper - bbox.lower); + } + +public: + Scene* scene; + + PointQuery* query_ws; // the original world space point query + PointQueryType query_type; + PointQueryFunction func; + RTCPointQueryContext* userContext; + const float similarityScale; + + void* userPtr; + + unsigned int primID; + unsigned int geomID; + + Vec3fa query_radius; // used if the query is converted to an AABB internally + }; +} + diff --git a/thirdparty/embree/kernels/common/default.h b/thirdparty/embree/kernels/common/default.h new file mode 100644 index 000000000000..3db53413bc57 --- /dev/null +++ b/thirdparty/embree/kernels/common/default.h @@ -0,0 +1,268 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../../common/sys/platform.h" +#include "../../common/sys/sysinfo.h" +#include "../../common/sys/thread.h" +#include "../../common/sys/alloc.h" +#include "../../common/sys/ref.h" +#include "../../common/sys/intrinsics.h" +#include "../../common/sys/atomic.h" +#include "../../common/sys/mutex.h" +#include "../../common/sys/vector.h" +#include "../../common/sys/array.h" +#include "../../common/sys/string.h" +#include "../../common/sys/regression.h" +#include "../../common/sys/vector.h" + +#include "../../common/math/math.h" +#include "../../common/math/transcendental.h" +#include "../../common/simd/simd.h" +#include "../../common/math/vec2.h" +#include "../../common/math/vec3.h" +#include "../../common/math/vec4.h" +#include "../../common/math/vec2fa.h" +#include "../../common/math/vec3fa.h" +#include "../../common/math/interval.h" +#include "../../common/math/bbox.h" +#include "../../common/math/obbox.h" +#include "../../common/math/lbbox.h" +#include "../../common/math/linearspace2.h" +#include "../../common/math/linearspace3.h" +#include "../../common/math/affinespace.h" +#include "../../common/math/range.h" +#include "../../common/lexers/tokenstream.h" + +#include "../../common/tasking/taskscheduler.h" + +#define COMMA , + +#include "../config.h" +#include "isa.h" +#include "stat.h" +#include "profile.h" +#include "rtcore.h" +#include "vector.h" +#include "state.h" +#include "instance_stack.h" + +#include +#include +#include +#include +#include +#include + +namespace embree +{ + //////////////////////////////////////////////////////////////////////////////// + /// Vec2 shortcuts + //////////////////////////////////////////////////////////////////////////////// + + template using Vec2vf = Vec2>; + template using Vec2vd = Vec2>; + template using Vec2vr = Vec2>; + template using Vec2vi = Vec2>; + template using Vec2vl = Vec2>; + template using Vec2vb = Vec2>; + template using Vec2vbf = Vec2>; + template using Vec2vbd = Vec2>; + + typedef Vec2 Vec2vf4; + typedef Vec2 Vec2vd4; + typedef Vec2 Vec2vr4; + typedef Vec2 Vec2vi4; + typedef Vec2 Vec2vl4; + typedef Vec2 Vec2vb4; + typedef Vec2 Vec2vbf4; + typedef Vec2 Vec2vbd4; + + typedef Vec2 Vec2vf8; 
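  // Editor's note (illustrative, not part of the Embree sources): these VecNvfK
  // shortcuts are structure-of-arrays bundles -- Vec2vf4, for instance, is a Vec2 whose
  // x/y members are each a 4-wide SIMD float, so it holds the coordinates of four points
  // side by side; the Vec3/Vec4 variants below follow the same pattern. A plain-array
  // picture of that layout, with a hypothetical name:
  struct Vec2x4SoaSketch
  {
    float x[4];  // x components of four points
    float y[4];  // y components of four points
  };
  // Arithmetic on such a bundle advances four points (or rays) per instruction once the
  // arrays map onto SIMD registers (vfloat4), which is what the ray-packet paths rely on.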
+ typedef Vec2 Vec2vd8; + typedef Vec2 Vec2vr8; + typedef Vec2 Vec2vi8; + typedef Vec2 Vec2vl8; + typedef Vec2 Vec2vb8; + typedef Vec2 Vec2vbf8; + typedef Vec2 Vec2vbd8; + + typedef Vec2 Vec2vf16; + typedef Vec2 Vec2vd16; + typedef Vec2 Vec2vr16; + typedef Vec2 Vec2vi16; + typedef Vec2 Vec2vl16; + typedef Vec2 Vec2vb16; + typedef Vec2 Vec2vbf16; + typedef Vec2 Vec2vbd16; + + typedef Vec2 Vec2vfx; + typedef Vec2 Vec2vdx; + typedef Vec2 Vec2vrx; + typedef Vec2 Vec2vix; + typedef Vec2 Vec2vlx; + typedef Vec2 Vec2vbx; + typedef Vec2 Vec2vbfx; + typedef Vec2 Vec2vbdx; + + //////////////////////////////////////////////////////////////////////////////// + /// Vec3 shortcuts + //////////////////////////////////////////////////////////////////////////////// + + template using Vec3vf = Vec3>; + template using Vec3vd = Vec3>; + template using Vec3vr = Vec3>; + template using Vec3vi = Vec3>; + template using Vec3vl = Vec3>; + template using Vec3vb = Vec3>; + template using Vec3vbf = Vec3>; + template using Vec3vbd = Vec3>; + + typedef Vec3 Vec3vf4; + typedef Vec3 Vec3vd4; + typedef Vec3 Vec3vr4; + typedef Vec3 Vec3vi4; + typedef Vec3 Vec3vl4; + typedef Vec3 Vec3vb4; + typedef Vec3 Vec3vbf4; + typedef Vec3 Vec3vbd4; + + typedef Vec3 Vec3vf8; + typedef Vec3 Vec3vd8; + typedef Vec3 Vec3vr8; + typedef Vec3 Vec3vi8; + typedef Vec3 Vec3vl8; + typedef Vec3 Vec3vb8; + typedef Vec3 Vec3vbf8; + typedef Vec3 Vec3vbd8; + + typedef Vec3 Vec3vf16; + typedef Vec3 Vec3vd16; + typedef Vec3 Vec3vr16; + typedef Vec3 Vec3vi16; + typedef Vec3 Vec3vl16; + typedef Vec3 Vec3vb16; + typedef Vec3 Vec3vbf16; + typedef Vec3 Vec3vbd16; + + typedef Vec3 Vec3vfx; + typedef Vec3 Vec3vdx; + typedef Vec3 Vec3vrx; + typedef Vec3 Vec3vix; + typedef Vec3 Vec3vlx; + typedef Vec3 Vec3vbx; + typedef Vec3 Vec3vbfx; + typedef Vec3 Vec3vbdx; + + //////////////////////////////////////////////////////////////////////////////// + /// Vec4 shortcuts + //////////////////////////////////////////////////////////////////////////////// + + template using Vec4vf = Vec4>; + template using Vec4vd = Vec4>; + template using Vec4vr = Vec4>; + template using Vec4vi = Vec4>; + template using Vec4vl = Vec4>; + template using Vec4vb = Vec4>; + template using Vec4vbf = Vec4>; + template using Vec4vbd = Vec4>; + + typedef Vec4 Vec4vf4; + typedef Vec4 Vec4vd4; + typedef Vec4 Vec4vr4; + typedef Vec4 Vec4vi4; + typedef Vec4 Vec4vl4; + typedef Vec4 Vec4vb4; + typedef Vec4 Vec4vbf4; + typedef Vec4 Vec4vbd4; + + typedef Vec4 Vec4vf8; + typedef Vec4 Vec4vd8; + typedef Vec4 Vec4vr8; + typedef Vec4 Vec4vi8; + typedef Vec4 Vec4vl8; + typedef Vec4 Vec4vb8; + typedef Vec4 Vec4vbf8; + typedef Vec4 Vec4vbd8; + + typedef Vec4 Vec4vf16; + typedef Vec4 Vec4vd16; + typedef Vec4 Vec4vr16; + typedef Vec4 Vec4vi16; + typedef Vec4 Vec4vl16; + typedef Vec4 Vec4vb16; + typedef Vec4 Vec4vbf16; + typedef Vec4 Vec4vbd16; + + typedef Vec4 Vec4vfx; + typedef Vec4 Vec4vdx; + typedef Vec4 Vec4vrx; + typedef Vec4 Vec4vix; + typedef Vec4 Vec4vlx; + typedef Vec4 Vec4vbx; + typedef Vec4 Vec4vbfx; + typedef Vec4 Vec4vbdx; + + //////////////////////////////////////////////////////////////////////////////// + /// Other shortcuts + //////////////////////////////////////////////////////////////////////////////// + + template using BBox3vf = BBox>; + typedef BBox BBox3vf4; + typedef BBox BBox3vf8; + typedef BBox BBox3vf16; + + /* calculate time segment itime and fractional time ftime */ + __forceinline int getTimeSegment(float time, float numTimeSegments, float& ftime) + { + const float timeScaled = time 
* numTimeSegments; + const float itimef = clamp(floorf(timeScaled), 0.0f, numTimeSegments-1.0f); + ftime = timeScaled - itimef; + return int(itimef); + } + + __forceinline int getTimeSegment(float time, float start_time, float end_time, float numTimeSegments, float& ftime) + { + const float timeScaled = (time-start_time)/(end_time-start_time) * numTimeSegments; + const float itimef = clamp(floorf(timeScaled), 0.0f, numTimeSegments-1.0f); + ftime = timeScaled - itimef; + return int(itimef); + } + + template + __forceinline vint getTimeSegment(const vfloat& time, const vfloat& numTimeSegments, vfloat& ftime) + { + const vfloat timeScaled = time * numTimeSegments; + const vfloat itimef = clamp(floor(timeScaled), vfloat(zero), numTimeSegments-1.0f); + ftime = timeScaled - itimef; + return vint(itimef); + } + + template + __forceinline vint getTimeSegment(const vfloat& time, const vfloat& start_time, const vfloat& end_time, const vfloat& numTimeSegments, vfloat& ftime) + { + const vfloat timeScaled = (time-start_time)/(end_time-start_time) * numTimeSegments; + const vfloat itimef = clamp(floor(timeScaled), vfloat(zero), numTimeSegments-1.0f); + ftime = timeScaled - itimef; + return vint(itimef); + } + + /* calculate overlapping time segment range */ + __forceinline range getTimeSegmentRange(const BBox1f& time_range, float numTimeSegments) + { + const float round_up = 1.0f+2.0f*float(ulp); // corrects inaccuracies to precisely match time step + const float round_down = 1.0f-2.0f*float(ulp); + const int itime_lower = (int)max(floor(round_up *time_range.lower*numTimeSegments), 0.0f); + const int itime_upper = (int)min(ceil (round_down*time_range.upper*numTimeSegments), numTimeSegments); + return make_range(itime_lower, itime_upper); + } + + /* calculate overlapping time segment range */ + __forceinline range getTimeSegmentRange(const BBox1f& range, BBox1f time_range, float numTimeSegments) + { + const float lower = (range.lower-time_range.lower)/time_range.size(); + const float upper = (range.upper-time_range.lower)/time_range.size(); + return getTimeSegmentRange(BBox1f(lower,upper),numTimeSegments); + } +} diff --git a/thirdparty/embree/kernels/common/device.cpp b/thirdparty/embree/kernels/common/device.cpp new file mode 100644 index 000000000000..d8c3d9c748f5 --- /dev/null +++ b/thirdparty/embree/kernels/common/device.cpp @@ -0,0 +1,560 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "device.h" +#include "../hash.h" +#include "scene_triangle_mesh.h" +#include "scene_user_geometry.h" +#include "scene_instance.h" +#include "scene_curves.h" +#include "scene_subdiv_mesh.h" + +#include "../subdiv/tessellation_cache.h" + +#include "acceln.h" +#include "geometry.h" + +#include "../geometry/cylinder.h" + +#include "../bvh/bvh4_factory.h" +#include "../bvh/bvh8_factory.h" + +#include "../../common/tasking/taskscheduler.h" +#include "../../common/sys/alloc.h" + +namespace embree +{ + /*! 
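(editor's note: a short worked example of the time-segment helpers above follows.) */

  // Editor's sketch (illustrative, not part of the Embree sources): the scalar
  // getTimeSegment() above maps a motion-blur time in [0,1] onto a segment index plus a
  // fractional offset. For numTimeSegments = 4 and time = 0.6: timeScaled = 2.4,
  // itime = clamp(floor(2.4)) = 2 and ftime = 0.4, i.e. the sample lies 40% of the way
  // between time steps 2 and 3. As a callable usage example (hypothetical name):
  inline int timeSegmentExampleSketch(float& ftime_out) {
    return getTimeSegment(0.6f /*time*/, 4.0f /*numTimeSegments*/, ftime_out); // returns 2, ftime_out == 0.4
  }

  /*!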
some global variables that can be set via rtcSetParameter1i for debugging purposes */ + ssize_t Device::debug_int0 = 0; + ssize_t Device::debug_int1 = 0; + ssize_t Device::debug_int2 = 0; + ssize_t Device::debug_int3 = 0; + + DECLARE_SYMBOL2(RayStreamFilterFuncs,rayStreamFilterFuncs); + + static MutexSys g_mutex; + static std::map g_cache_size_map; + static std::map g_num_threads_map; + + Device::Device (const char* cfg) + { + /* check that CPU supports lowest ISA */ + if (!hasISA(ISA)) { + throw_RTCError(RTC_ERROR_UNSUPPORTED_CPU,"CPU does not support " ISA_STR); + } + + /* set default frequency level for detected CPU */ + switch (getCPUModel()) { + case CPU::UNKNOWN: frequency_level = FREQUENCY_SIMD256; break; + case CPU::XEON_ICE_LAKE: frequency_level = FREQUENCY_SIMD256; break; + case CPU::CORE_ICE_LAKE: frequency_level = FREQUENCY_SIMD256; break; + case CPU::CORE_TIGER_LAKE: frequency_level = FREQUENCY_SIMD128; break; + case CPU::CORE_COMET_LAKE: frequency_level = FREQUENCY_SIMD128; break; + case CPU::CORE_CANNON_LAKE:frequency_level = FREQUENCY_SIMD128; break; + case CPU::CORE_KABY_LAKE: frequency_level = FREQUENCY_SIMD128; break; + case CPU::XEON_SKY_LAKE: frequency_level = FREQUENCY_SIMD128; break; + case CPU::CORE_SKY_LAKE: frequency_level = FREQUENCY_SIMD128; break; + case CPU::XEON_BROADWELL: frequency_level = FREQUENCY_SIMD256; break; + case CPU::CORE_BROADWELL: frequency_level = FREQUENCY_SIMD256; break; + case CPU::XEON_HASWELL: frequency_level = FREQUENCY_SIMD256; break; + case CPU::CORE_HASWELL: frequency_level = FREQUENCY_SIMD256; break; + case CPU::XEON_IVY_BRIDGE: frequency_level = FREQUENCY_SIMD256; break; + case CPU::CORE_IVY_BRIDGE: frequency_level = FREQUENCY_SIMD256; break; + case CPU::SANDY_BRIDGE: frequency_level = FREQUENCY_SIMD256; break; + case CPU::NEHALEM: frequency_level = FREQUENCY_SIMD128; break; + case CPU::CORE2: frequency_level = FREQUENCY_SIMD128; break; + case CPU::CORE1: frequency_level = FREQUENCY_SIMD128; break; + } + + /* initialize global state */ +#if defined(EMBREE_CONFIG) + State::parseString(EMBREE_CONFIG); +#endif + State::parseString(cfg); + if (!ignore_config_files && FileName::executableFolder() != FileName("")) + State::parseFile(FileName::executableFolder()+FileName(".embree" TOSTRING(RTC_VERSION_MAJOR))); + if (!ignore_config_files && FileName::homeFolder() != FileName("")) + State::parseFile(FileName::homeFolder()+FileName(".embree" TOSTRING(RTC_VERSION_MAJOR))); + State::verify(); + + /* check whether selected ISA is supported by the HW, as the user could have forced an unsupported ISA */ + if (!checkISASupport()) { + throw_RTCError(RTC_ERROR_UNSUPPORTED_CPU,"CPU does not support selected ISA"); + } + + /*! do some internal tests */ + assert(isa::Cylinder::verify()); + + /*! enable huge page support if desired */ +#if defined(__WIN32__) + if (State::enable_selockmemoryprivilege) + State::hugepages_success &= win_enable_selockmemoryprivilege(State::verbosity(3)); +#endif + State::hugepages_success &= os_init(State::hugepages,State::verbosity(3)); + + /*! set tessellation cache size */ + setCacheSize( State::tessellation_cache_size ); + + /*! 
enable some floating point exceptions to catch bugs */ + if (State::float_exceptions) + { + int exceptions = _MM_MASK_MASK; + //exceptions &= ~_MM_MASK_INVALID; + exceptions &= ~_MM_MASK_DENORM; + exceptions &= ~_MM_MASK_DIV_ZERO; + //exceptions &= ~_MM_MASK_OVERFLOW; + //exceptions &= ~_MM_MASK_UNDERFLOW; + //exceptions &= ~_MM_MASK_INEXACT; + _MM_SET_EXCEPTION_MASK(exceptions); + } + + /* print info header */ + if (State::verbosity(1)) + print(); + if (State::verbosity(2)) + State::print(); + + /* register all algorithms */ + bvh4_factory = make_unique(new BVH4Factory(enabled_builder_cpu_features, enabled_cpu_features)); + +#if defined(EMBREE_TARGET_SIMD8) + bvh8_factory = make_unique(new BVH8Factory(enabled_builder_cpu_features, enabled_cpu_features)); +#endif + + /* setup tasking system */ + initTaskingSystem(numThreads); + + /* ray stream SOA to AOS conversion */ +#if defined(EMBREE_RAY_PACKETS) + RayStreamFilterFuncsType rayStreamFilterFuncs; + SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(enabled_cpu_features,rayStreamFilterFuncs); + rayStreamFilters = rayStreamFilterFuncs(); +#endif + } + + Device::~Device () + { + setCacheSize(0); + exitTaskingSystem(); + } + + std::string getEnabledTargets() + { + std::string v; +#if defined(EMBREE_TARGET_SSE2) + v += "SSE2 "; +#endif +#if defined(EMBREE_TARGET_SSE42) + v += "SSE4.2 "; +#endif +#if defined(EMBREE_TARGET_AVX) + v += "AVX "; +#endif +#if defined(EMBREE_TARGET_AVX2) + v += "AVX2 "; +#endif +#if defined(EMBREE_TARGET_AVX512KNL) + v += "AVX512KNL "; +#endif +#if defined(EMBREE_TARGET_AVX512SKX) + v += "AVX512SKX "; +#endif + return v; + } + + std::string getEmbreeFeatures() + { + std::string v; +#if defined(EMBREE_RAY_MASK) + v += "raymasks "; +#endif +#if defined (EMBREE_BACKFACE_CULLING) + v += "backfaceculling "; +#endif +#if defined (EMBREE_BACKFACE_CULLING_CURVES) + v += "backfacecullingcurves "; +#endif +#if defined(EMBREE_FILTER_FUNCTION) + v += "intersection_filter "; +#endif +#if defined (EMBREE_COMPACT_POLYS) + v += "compact_polys "; +#endif + return v; + } + + void Device::print() + { + const int cpu_features = getCPUFeatures(); + std::cout << std::endl; + std::cout << "Embree Ray Tracing Kernels " << RTC_VERSION_STRING << " (" << RTC_HASH << ")" << std::endl; + std::cout << " Compiler : " << getCompilerName() << std::endl; + std::cout << " Build : "; +#if defined(DEBUG) + std::cout << "Debug " << std::endl; +#else + std::cout << "Release " << std::endl; +#endif + std::cout << " Platform : " << getPlatformName() << std::endl; + std::cout << " CPU : " << stringOfCPUModel(getCPUModel()) << " (" << getCPUVendor() << ")" << std::endl; + std::cout << " Threads : " << getNumberOfLogicalThreads() << std::endl; + std::cout << " ISA : " << stringOfCPUFeatures(cpu_features) << std::endl; + std::cout << " Targets : " << supportedTargetList(cpu_features) << std::endl; + const bool hasFTZ = _mm_getcsr() & _MM_FLUSH_ZERO_ON; + const bool hasDAZ = _mm_getcsr() & _MM_DENORMALS_ZERO_ON; + std::cout << " MXCSR : " << "FTZ=" << hasFTZ << ", DAZ=" << hasDAZ << std::endl; + std::cout << " Config" << std::endl; + std::cout << " Threads : " << (numThreads ? 
toString(numThreads) : std::string("default")) << std::endl; + std::cout << " ISA : " << stringOfCPUFeatures(enabled_cpu_features) << std::endl; + std::cout << " Targets : " << supportedTargetList(enabled_cpu_features) << " (supported)" << std::endl; + std::cout << " " << getEnabledTargets() << " (compile time enabled)" << std::endl; + std::cout << " Features: " << getEmbreeFeatures() << std::endl; + std::cout << " Tasking : "; +#if defined(TASKING_TBB) + std::cout << "TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR << " "; + #if TBB_INTERFACE_VERSION >= 12002 + std::cout << "TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << TBB_runtime_interface_version() << " "; + #else + std::cout << "TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version() << " "; + #endif +#endif +#if defined(TASKING_INTERNAL) + std::cout << "internal_tasking_system "; +#endif +#if defined(TASKING_PPL) + std::cout << "PPL "; +#endif + std::cout << std::endl; + + /* check of FTZ and DAZ flags are set in CSR */ + if (!hasFTZ || !hasDAZ) + { +#if !defined(_DEBUG) + if (State::verbosity(1)) +#endif + { + std::cout << std::endl; + std::cout << "================================================================================" << std::endl; + std::cout << " WARNING: \"Flush to Zero\" or \"Denormals are Zero\" mode not enabled " << std::endl + << " in the MXCSR control and status register. This can have a severe " << std::endl + << " performance impact. Please enable these modes for each application " << std::endl + << " thread the following way:" << std::endl + << std::endl + << " #include \"xmmintrin.h\"" << std::endl + << " #include \"pmmintrin.h\"" << std::endl + << std::endl + << " _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);" << std::endl + << " _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);" << std::endl; + std::cout << "================================================================================" << std::endl; + std::cout << std::endl; + } + } + std::cout << std::endl; + } + + void Device::setDeviceErrorCode(RTCError error) + { + RTCError* stored_error = errorHandler.error(); + if (*stored_error == RTC_ERROR_NONE) + *stored_error = error; + } + + RTCError Device::getDeviceErrorCode() + { + RTCError* stored_error = errorHandler.error(); + RTCError error = *stored_error; + *stored_error = RTC_ERROR_NONE; + return error; + } + + void Device::setThreadErrorCode(RTCError error) + { + RTCError* stored_error = g_errorHandler.error(); + if (*stored_error == RTC_ERROR_NONE) + *stored_error = error; + } + + RTCError Device::getThreadErrorCode() + { + RTCError* stored_error = g_errorHandler.error(); + RTCError error = *stored_error; + *stored_error = RTC_ERROR_NONE; + return error; + } + + void Device::process_error(Device* device, RTCError error, const char* str) + { + /* store global error code when device construction failed */ + if (!device) + return setThreadErrorCode(error); + + /* print error when in verbose mode */ + if (device->verbosity(1)) + { + switch (error) { + case RTC_ERROR_NONE : std::cerr << "Embree: No error"; break; + case RTC_ERROR_UNKNOWN : std::cerr << "Embree: Unknown error"; break; + case RTC_ERROR_INVALID_ARGUMENT : std::cerr << "Embree: Invalid argument"; break; + case RTC_ERROR_INVALID_OPERATION: std::cerr << "Embree: Invalid operation"; break; + case RTC_ERROR_OUT_OF_MEMORY : std::cerr << "Embree: Out of memory"; break; + case RTC_ERROR_UNSUPPORTED_CPU : std::cerr << "Embree: Unsupported CPU"; break; + 
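      // Editor's sketch (illustrative, not part of the Embree sources): the warning
      // printed by print() above asks applications to enable flush-to-zero and
      // denormals-are-zero on every thread that calls into Embree; the intrinsics come
      // from <xmmintrin.h>/<pmmintrin.h>, as the message notes. A standalone per-thread
      // setup, with a hypothetical name:
      struct FtzDazSketch {
        static void enableForThisThread() {
          _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);          // FTZ: flush denormal results to zero
          _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);  // DAZ: treat denormal inputs as zero
        }
      };
      // Applications would call this once at the start of each render thread, before
      // tracing any rays, to avoid the performance penalty the warning describes.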
default : std::cerr << "Embree: Invalid error code"; break; + }; + if (str) std::cerr << ", (" << str << ")"; + std::cerr << std::endl; + } + + /* call user specified error callback */ + if (device->error_function) + device->error_function(device->error_function_userptr,error,str); + + /* record error code */ + device->setDeviceErrorCode(error); + } + + void Device::memoryMonitor(ssize_t bytes, bool post) + { + if (State::memory_monitor_function && bytes != 0) { + if (!State::memory_monitor_function(State::memory_monitor_userptr,bytes,post)) { + if (bytes > 0) { // only throw exception when we allocate memory to never throw inside a destructor + throw_RTCError(RTC_ERROR_OUT_OF_MEMORY,"memory monitor forced termination"); + } + } + } + } + + size_t getMaxNumThreads() + { + size_t maxNumThreads = 0; + for (std::map::iterator i=g_num_threads_map.begin(); i != g_num_threads_map.end(); i++) + maxNumThreads = max(maxNumThreads, (*i).second); + if (maxNumThreads == 0) + maxNumThreads = std::numeric_limits::max(); + return maxNumThreads; + } + + size_t getMaxCacheSize() + { + size_t maxCacheSize = 0; + for (std::map::iterator i=g_cache_size_map.begin(); i!= g_cache_size_map.end(); i++) + maxCacheSize = max(maxCacheSize, (*i).second); + return maxCacheSize; + } + + void Device::setCacheSize(size_t bytes) + { +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + Lock lock(g_mutex); + if (bytes == 0) g_cache_size_map.erase(this); + else g_cache_size_map[this] = bytes; + + size_t maxCacheSize = getMaxCacheSize(); + resizeTessellationCache(maxCacheSize); +#endif + } + + void Device::initTaskingSystem(size_t numThreads) + { + Lock lock(g_mutex); + if (numThreads == 0) + g_num_threads_map[this] = std::numeric_limits::max(); + else + g_num_threads_map[this] = numThreads; + + /* create task scheduler */ + size_t maxNumThreads = getMaxNumThreads(); + TaskScheduler::create(maxNumThreads,State::set_affinity,State::start_threads); +#if USE_TASK_ARENA + const size_t nThreads = min(maxNumThreads,TaskScheduler::threadCount()); + const size_t uThreads = min(max(numUserThreads,(size_t)1),nThreads); + arena = make_unique(new tbb::task_arena((int)nThreads,(unsigned int)uThreads)); +#endif + } + + void Device::exitTaskingSystem() + { + Lock lock(g_mutex); + g_num_threads_map.erase(this); + + /* terminate tasking system */ + if (g_num_threads_map.size() == 0) { + TaskScheduler::destroy(); + } + /* or configure new number of threads */ + else { + size_t maxNumThreads = getMaxNumThreads(); + TaskScheduler::create(maxNumThreads,State::set_affinity,State::start_threads); + } +#if USE_TASK_ARENA + arena.reset(); +#endif + } + + void Device::setProperty(const RTCDeviceProperty prop, ssize_t val) + { + /* hidden internal properties */ + switch ((size_t)prop) + { + case 1000000: debug_int0 = val; return; + case 1000001: debug_int1 = val; return; + case 1000002: debug_int2 = val; return; + case 1000003: debug_int3 = val; return; + } + + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "unknown writable property"); + } + + ssize_t Device::getProperty(const RTCDeviceProperty prop) + { + size_t iprop = (size_t)prop; + + /* get name of internal regression test */ + if (iprop >= 2000000 && iprop < 3000000) + { + RegressionTest* test = getRegressionTest(iprop-2000000); + if (test) return (ssize_t) test->name.c_str(); + else return 0; + } + + /* run internal regression test */ + if (iprop >= 3000000 && iprop < 4000000) + { + RegressionTest* test = getRegressionTest(iprop-3000000); + if (test) return test->run(); + else return 0; + } + + /* documented 
properties */ + switch (prop) + { + case RTC_DEVICE_PROPERTY_VERSION_MAJOR: return RTC_VERSION_MAJOR; + case RTC_DEVICE_PROPERTY_VERSION_MINOR: return RTC_VERSION_MINOR; + case RTC_DEVICE_PROPERTY_VERSION_PATCH: return RTC_VERSION_PATCH; + case RTC_DEVICE_PROPERTY_VERSION : return RTC_VERSION; + +#if defined(EMBREE_TARGET_SIMD4) && defined(EMBREE_RAY_PACKETS) + case RTC_DEVICE_PROPERTY_NATIVE_RAY4_SUPPORTED: return hasISA(SSE2); +#else + case RTC_DEVICE_PROPERTY_NATIVE_RAY4_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_TARGET_SIMD8) && defined(EMBREE_RAY_PACKETS) + case RTC_DEVICE_PROPERTY_NATIVE_RAY8_SUPPORTED: return hasISA(AVX); +#else + case RTC_DEVICE_PROPERTY_NATIVE_RAY8_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_TARGET_SIMD16) && defined(EMBREE_RAY_PACKETS) + case RTC_DEVICE_PROPERTY_NATIVE_RAY16_SUPPORTED: return hasISA(AVX512KNL) | hasISA(AVX512SKX); +#else + case RTC_DEVICE_PROPERTY_NATIVE_RAY16_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_RAY_PACKETS) + case RTC_DEVICE_PROPERTY_RAY_STREAM_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_RAY_STREAM_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_RAY_MASK) + case RTC_DEVICE_PROPERTY_RAY_MASK_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_RAY_MASK_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_BACKFACE_CULLING) + case RTC_DEVICE_PROPERTY_BACKFACE_CULLING_ENABLED: return 1; +#else + case RTC_DEVICE_PROPERTY_BACKFACE_CULLING_ENABLED: return 0; +#endif + +#if defined(EMBREE_BACKFACE_CULLING_CURVES) + case RTC_DEVICE_PROPERTY_BACKFACE_CULLING_CURVES_ENABLED: return 1; +#else + case RTC_DEVICE_PROPERTY_BACKFACE_CULLING_CURVES_ENABLED: return 0; +#endif + +#if defined(EMBREE_COMPACT_POLYS) + case RTC_DEVICE_PROPERTY_COMPACT_POLYS_ENABLED: return 1; +#else + case RTC_DEVICE_PROPERTY_COMPACT_POLYS_ENABLED: return 0; +#endif + +#if defined(EMBREE_FILTER_FUNCTION) + case RTC_DEVICE_PROPERTY_FILTER_FUNCTION_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_FILTER_FUNCTION_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_IGNORE_INVALID_RAYS) + case RTC_DEVICE_PROPERTY_IGNORE_INVALID_RAYS_ENABLED: return 1; +#else + case RTC_DEVICE_PROPERTY_IGNORE_INVALID_RAYS_ENABLED: return 0; +#endif + +#if defined(TASKING_INTERNAL) + case RTC_DEVICE_PROPERTY_TASKING_SYSTEM: return 0; +#endif + +#if defined(TASKING_TBB) + case RTC_DEVICE_PROPERTY_TASKING_SYSTEM: return 1; +#endif + +#if defined(TASKING_PPL) + case RTC_DEVICE_PROPERTY_TASKING_SYSTEM: return 2; +#endif + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + case RTC_DEVICE_PROPERTY_TRIANGLE_GEOMETRY_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_TRIANGLE_GEOMETRY_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + case RTC_DEVICE_PROPERTY_QUAD_GEOMETRY_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_QUAD_GEOMETRY_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_GEOMETRY_CURVE) + case RTC_DEVICE_PROPERTY_CURVE_GEOMETRY_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_CURVE_GEOMETRY_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + case RTC_DEVICE_PROPERTY_SUBDIVISION_GEOMETRY_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_SUBDIVISION_GEOMETRY_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_GEOMETRY_USER) + case RTC_DEVICE_PROPERTY_USER_GEOMETRY_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_USER_GEOMETRY_SUPPORTED: return 0; +#endif + +#if defined(EMBREE_GEOMETRY_POINT) + case RTC_DEVICE_PROPERTY_POINT_GEOMETRY_SUPPORTED: return 1; +#else + case 
RTC_DEVICE_PROPERTY_POINT_GEOMETRY_SUPPORTED: return 0; +#endif + +#if defined(TASKING_PPL) + case RTC_DEVICE_PROPERTY_JOIN_COMMIT_SUPPORTED: return 0; +#elif defined(TASKING_TBB) && (TBB_INTERFACE_VERSION_MAJOR < 8) + case RTC_DEVICE_PROPERTY_JOIN_COMMIT_SUPPORTED: return 0; +#else + case RTC_DEVICE_PROPERTY_JOIN_COMMIT_SUPPORTED: return 1; +#endif + +#if defined(TASKING_TBB) && TASKING_TBB_USE_TASK_ISOLATION + case RTC_DEVICE_PROPERTY_PARALLEL_COMMIT_SUPPORTED: return 1; +#else + case RTC_DEVICE_PROPERTY_PARALLEL_COMMIT_SUPPORTED: return 0; +#endif + + default: throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "unknown readable property"); break; + }; + } +} diff --git a/thirdparty/embree/kernels/common/device.h b/thirdparty/embree/kernels/common/device.h new file mode 100644 index 000000000000..e9a81bb1091d --- /dev/null +++ b/thirdparty/embree/kernels/common/device.h @@ -0,0 +1,85 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "state.h" +#include "accel.h" + +namespace embree +{ + class BVH4Factory; + class BVH8Factory; + + class Device : public State, public MemoryMonitorInterface + { + ALIGNED_CLASS_(16); + + public: + + /*! Device construction */ + Device (const char* cfg); + + /*! Device destruction */ + virtual ~Device (); + + /*! prints info about the device */ + void print(); + + /*! sets the error code */ + void setDeviceErrorCode(RTCError error); + + /*! returns and clears the error code */ + RTCError getDeviceErrorCode(); + + /*! sets the error code */ + static void setThreadErrorCode(RTCError error); + + /*! returns and clears the error code */ + static RTCError getThreadErrorCode(); + + /*! processes error codes, do not call directly */ + static void process_error(Device* device, RTCError error, const char* str); + + /*! invokes the memory monitor callback */ + void memoryMonitor(ssize_t bytes, bool post); + + /*! sets the size of the software cache. */ + void setCacheSize(size_t bytes); + + /*! sets a property */ + void setProperty(const RTCDeviceProperty prop, ssize_t val); + + /*! gets a property */ + ssize_t getProperty(const RTCDeviceProperty prop); + + private: + + /*! initializes the tasking system */ + void initTaskingSystem(size_t numThreads); + + /*! shuts down the tasking system */ + void exitTaskingSystem(); + + /*! 
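a usage sketch of the device property query follows. */

    // Editor's sketch (illustrative, not part of the Embree sources): getProperty() above
    // backs the public rtcGetDeviceProperty() entry point, which applications use to
    // probe the device before picking a code path, e.g. (hypothetical helper name):
    static bool deviceSupportsRay8Sketch(RTCDevice device) {
      return rtcGetDeviceProperty(device, RTC_DEVICE_PROPERTY_NATIVE_RAY8_SUPPORTED) != 0;
    }

    /*!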
some variables that can be set via rtcSetParameter1i for debugging purposes */ + public: + static ssize_t debug_int0; + static ssize_t debug_int1; + static ssize_t debug_int2; + static ssize_t debug_int3; + + public: + std::unique_ptr bvh4_factory; +#if defined(EMBREE_TARGET_SIMD8) + std::unique_ptr bvh8_factory; +#endif + +#if USE_TASK_ARENA + std::unique_ptr arena; +#endif + + /* ray streams filter */ + RayStreamFilterFuncs rayStreamFilters; + }; +} diff --git a/thirdparty/embree/kernels/common/geometry.cpp b/thirdparty/embree/kernels/common/geometry.cpp new file mode 100644 index 000000000000..b3aa8e33960d --- /dev/null +++ b/thirdparty/embree/kernels/common/geometry.cpp @@ -0,0 +1,259 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "geometry.h" +#include "scene.h" + +namespace embree +{ + const char* Geometry::gtype_names[Geometry::GTY_END] = + { + "flat_linear_curve", + "round_linear_curve", + "oriented_linear_curve", + "", + "flat_bezier_curve", + "round_bezier_curve", + "oriented_bezier_curve", + "", + "flat_bspline_curve", + "round_bspline_curve", + "oriented_bspline_curve", + "", + "flat_hermite_curve", + "round_hermite_curve", + "oriented_hermite_curve", + "", + "flat_catmull_rom_curve", + "round_catmull_rom_curve", + "oriented_catmull_rom_curve", + "", + "triangles", + "quads", + "grid", + "subdivs", + "", + "sphere", + "disc", + "oriented_disc", + "", + "usergeom", + "instance_cheap", + "instance_expensive", + }; + + Geometry::Geometry (Device* device, GType gtype, unsigned int numPrimitives, unsigned int numTimeSteps) + : device(device), userPtr(nullptr), + numPrimitives(numPrimitives), numTimeSteps(unsigned(numTimeSteps)), fnumTimeSegments(float(numTimeSteps-1)), time_range(0.0f,1.0f), + mask(-1), + gtype(gtype), + gsubtype(GTY_SUBTYPE_DEFAULT), + quality(RTC_BUILD_QUALITY_MEDIUM), + state((unsigned)State::MODIFIED), + enabled(true), + intersectionFilterN(nullptr), occlusionFilterN(nullptr), pointQueryFunc(nullptr) + { + device->refInc(); + } + + Geometry::~Geometry() + { + device->refDec(); + } + + void Geometry::setNumPrimitives(unsigned int numPrimitives_in) + { + if (numPrimitives_in == numPrimitives) return; + + numPrimitives = numPrimitives_in; + + Geometry::update(); + } + + void Geometry::setNumTimeSteps (unsigned int numTimeSteps_in) + { + if (numTimeSteps_in == numTimeSteps) { + return; + } + + numTimeSteps = numTimeSteps_in; + fnumTimeSegments = float(numTimeSteps_in-1); + + Geometry::update(); + } + + void Geometry::setTimeRange (const BBox1f range) + { + time_range = range; + Geometry::update(); + } + + void Geometry::update() + { + ++modCounter_; // FIXME: required? 
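    // Editor's sketch (illustrative, not part of the Embree sources): modCounter_
    // implements version-based dirty tracking -- every mutation bumps the counter and
    // consumers remember the last value they processed, rebuilding only when it changed
    // (the buffer views use the same idea with their own modCounter). The pattern in
    // minimal form, with hypothetical names:
    struct VersionedSketch { unsigned version = 1; void mutate() { ++version; } };
    struct ConsumerSketch {
      unsigned lastSeen = 0;
      bool refreshIfNeeded(const VersionedSketch& v) {
        if (v.version == lastSeen) return false;  // nothing changed since the last rebuild
        lastSeen = v.version;                     // ...rebuild cached data here...
        return true;
      }
    };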
+ state = (unsigned)State::MODIFIED; + } + + void Geometry::commit() + { + ++modCounter_; + state = (unsigned)State::COMMITTED; + } + + void Geometry::preCommit() + { + if (State::MODIFIED == (State)state) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"geometry not committed"); + } + + void Geometry::postCommit() + { + } + + void Geometry::enable () + { + if (isEnabled()) + return; + + enabled = true; + ++modCounter_; + } + + void Geometry::disable () + { + if (isDisabled()) + return; + + enabled = false; + ++modCounter_; + } + + void Geometry::setUserData (void* ptr) + { + userPtr = ptr; + } + + void Geometry::setIntersectionFilterFunctionN (RTCFilterFunctionN filter) + { + if (!(getTypeMask() & (MTY_TRIANGLE_MESH | MTY_QUAD_MESH | MTY_CURVES | MTY_SUBDIV_MESH | MTY_USER_GEOMETRY | MTY_GRID_MESH))) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"filter functions not supported for this geometry"); + + intersectionFilterN = filter; + } + + void Geometry::setOcclusionFilterFunctionN (RTCFilterFunctionN filter) + { + if (!(getTypeMask() & (MTY_TRIANGLE_MESH | MTY_QUAD_MESH | MTY_CURVES | MTY_SUBDIV_MESH | MTY_USER_GEOMETRY | MTY_GRID_MESH))) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"filter functions not supported for this geometry"); + + occlusionFilterN = filter; + } + + void Geometry::setPointQueryFunction (RTCPointQueryFunction func) + { + pointQueryFunc = func; + } + + void Geometry::interpolateN(const RTCInterpolateNArguments* const args) + { + const void* valid_i = args->valid; + const unsigned* primIDs = args->primIDs; + const float* u = args->u; + const float* v = args->v; + unsigned int N = args->N; + RTCBufferType bufferType = args->bufferType; + unsigned int bufferSlot = args->bufferSlot; + float* P = args->P; + float* dPdu = args->dPdu; + float* dPdv = args->dPdv; + float* ddPdudu = args->ddPdudu; + float* ddPdvdv = args->ddPdvdv; + float* ddPdudv = args->ddPdudv; + unsigned int valueCount = args->valueCount; + + if (valueCount > 256) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"maximally 256 floating point values can be interpolated per vertex"); + const int* valid = (const int*) valid_i; + + __aligned(64) float P_tmp[256]; + __aligned(64) float dPdu_tmp[256]; + __aligned(64) float dPdv_tmp[256]; + __aligned(64) float ddPdudu_tmp[256]; + __aligned(64) float ddPdvdv_tmp[256]; + __aligned(64) float ddPdudv_tmp[256]; + + float* Pt = P ? 
P_tmp : nullptr; + float* dPdut = nullptr, *dPdvt = nullptr; + if (dPdu) { dPdut = dPdu_tmp; dPdvt = dPdv_tmp; } + float* ddPdudut = nullptr, *ddPdvdvt = nullptr, *ddPdudvt = nullptr; + if (ddPdudu) { ddPdudut = ddPdudu_tmp; ddPdvdvt = ddPdvdv_tmp; ddPdudvt = ddPdudv_tmp; } + + for (unsigned int i=0; iprimID < size()); + + RTCPointQueryFunctionArguments args; + args.query = (RTCPointQuery*)context->query_ws; + args.userPtr = context->userPtr; + args.primID = context->primID; + args.geomID = context->geomID; + args.context = context->userContext; + args.similarityScale = context->similarityScale; + + bool update = false; + if(context->func) update |= context->func(&args); + if(pointQueryFunc) update |= pointQueryFunc(&args); + + if (update && context->userContext->instStackSize > 0) + { + // update point query + if (context->query_type == POINT_QUERY_TYPE_AABB) { + context->updateAABB(); + } else { + assert(context->similarityScale > 0.f); + query->radius = context->query_ws->radius * context->similarityScale; + } + } + return update; + } +} diff --git a/thirdparty/embree/kernels/common/geometry.h b/thirdparty/embree/kernels/common/geometry.h new file mode 100644 index 000000000000..953974bfd29c --- /dev/null +++ b/thirdparty/embree/kernels/common/geometry.h @@ -0,0 +1,582 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "device.h" +#include "buffer.h" +#include "../common/point_query.h" +#include "../builders/priminfo.h" + +namespace embree +{ + class Scene; + class Geometry; + + struct GeometryCounts + { + __forceinline GeometryCounts() + : numFilterFunctions(0), + numTriangles(0), numMBTriangles(0), + numQuads(0), numMBQuads(0), + numBezierCurves(0), numMBBezierCurves(0), + numLineSegments(0), numMBLineSegments(0), + numSubdivPatches(0), numMBSubdivPatches(0), + numUserGeometries(0), numMBUserGeometries(0), + numInstancesCheap(0), numMBInstancesCheap(0), + numInstancesExpensive(0), numMBInstancesExpensive(0), + numGrids(0), numMBGrids(0), + numPoints(0), numMBPoints(0) {} + + __forceinline size_t size() const { + return numTriangles + numQuads + numBezierCurves + numLineSegments + numSubdivPatches + numUserGeometries + numInstancesCheap + numInstancesExpensive + numGrids + numPoints + + numMBTriangles + numMBQuads + numMBBezierCurves + numMBLineSegments + numMBSubdivPatches + numMBUserGeometries + numMBInstancesCheap + numMBInstancesExpensive + numMBGrids + numMBPoints; + } + + __forceinline unsigned int enabledGeometryTypesMask() const + { + unsigned int mask = 0; + if (numTriangles) mask |= 1 << 0; + if (numQuads) mask |= 1 << 1; + if (numBezierCurves+numLineSegments) mask |= 1 << 2; + if (numSubdivPatches) mask |= 1 << 3; + if (numUserGeometries) mask |= 1 << 4; + if (numInstancesCheap) mask |= 1 << 5; + if (numInstancesExpensive) mask |= 1 << 6; + if (numGrids) mask |= 1 << 7; + if (numPoints) mask |= 1 << 8; + + unsigned int maskMB = 0; + if (numMBTriangles) maskMB |= 1 << 0; + if (numMBQuads) maskMB |= 1 << 1; + if (numMBBezierCurves+numMBLineSegments) maskMB |= 1 << 2; + if (numMBSubdivPatches) maskMB |= 1 << 3; + if (numMBUserGeometries) maskMB |= 1 << 4; + if (numMBInstancesCheap) maskMB |= 1 << 5; + if (numMBInstancesExpensive) maskMB |= 1 << 6; + if (numMBGrids) maskMB |= 1 << 7; + if (numMBPoints) maskMB |= 1 << 8; + + return (mask<<8) + maskMB; + } + + __forceinline GeometryCounts operator+ (GeometryCounts const & rhs) const + { + GeometryCounts ret; + ret.numFilterFunctions = 
numFilterFunctions + rhs.numFilterFunctions; + ret.numTriangles = numTriangles + rhs.numTriangles; + ret.numMBTriangles = numMBTriangles + rhs.numMBTriangles; + ret.numQuads = numQuads + rhs.numQuads; + ret.numMBQuads = numMBQuads + rhs.numMBQuads; + ret.numBezierCurves = numBezierCurves + rhs.numBezierCurves; + ret.numMBBezierCurves = numMBBezierCurves + rhs.numMBBezierCurves; + ret.numLineSegments = numLineSegments + rhs.numLineSegments; + ret.numMBLineSegments = numMBLineSegments + rhs.numMBLineSegments; + ret.numSubdivPatches = numSubdivPatches + rhs.numSubdivPatches; + ret.numMBSubdivPatches = numMBSubdivPatches + rhs.numMBSubdivPatches; + ret.numUserGeometries = numUserGeometries + rhs.numUserGeometries; + ret.numMBUserGeometries = numMBUserGeometries + rhs.numMBUserGeometries; + ret.numInstancesCheap = numInstancesCheap + rhs.numInstancesCheap; + ret.numMBInstancesCheap = numMBInstancesCheap + rhs.numMBInstancesCheap; + ret.numInstancesExpensive = numInstancesExpensive + rhs.numInstancesExpensive; + ret.numMBInstancesExpensive = numMBInstancesExpensive + rhs.numMBInstancesExpensive; + ret.numGrids = numGrids + rhs.numGrids; + ret.numMBGrids = numMBGrids + rhs.numMBGrids; + ret.numPoints = numPoints + rhs.numPoints; + ret.numMBPoints = numMBPoints + rhs.numMBPoints; + + return ret; + } + + size_t numFilterFunctions; //!< number of geometries with filter functions enabled + size_t numTriangles; //!< number of enabled triangles + size_t numMBTriangles; //!< number of enabled motion blured triangles + size_t numQuads; //!< number of enabled quads + size_t numMBQuads; //!< number of enabled motion blurred quads + size_t numBezierCurves; //!< number of enabled curves + size_t numMBBezierCurves; //!< number of enabled motion blurred curves + size_t numLineSegments; //!< number of enabled line segments + size_t numMBLineSegments; //!< number of enabled line motion blurred segments + size_t numSubdivPatches; //!< number of enabled subdivision patches + size_t numMBSubdivPatches; //!< number of enabled motion blured subdivision patches + size_t numUserGeometries; //!< number of enabled user geometries + size_t numMBUserGeometries; //!< number of enabled motion blurred user geometries + size_t numInstancesCheap; //!< number of enabled cheap instances + size_t numMBInstancesCheap; //!< number of enabled motion blurred cheap instances + size_t numInstancesExpensive; //!< number of enabled expensive instances + size_t numMBInstancesExpensive; //!< number of enabled motion blurred expensive instances + size_t numGrids; //!< number of enabled grid geometries + size_t numMBGrids; //!< number of enabled motion blurred grid geometries + size_t numPoints; //!< number of enabled points + size_t numMBPoints; //!< number of enabled motion blurred points + }; + + /*! Base class all geometries are derived from */ + class Geometry : public RefCount + { + friend class Scene; + public: + + /*! 
type of geometry */ + enum GType + { + GTY_FLAT_LINEAR_CURVE = 0, + GTY_ROUND_LINEAR_CURVE = 1, + GTY_ORIENTED_LINEAR_CURVE = 2, + GTY_CONE_LINEAR_CURVE = 3, + + GTY_FLAT_BEZIER_CURVE = 4, + GTY_ROUND_BEZIER_CURVE = 5, + GTY_ORIENTED_BEZIER_CURVE = 6, + + GTY_FLAT_BSPLINE_CURVE = 8, + GTY_ROUND_BSPLINE_CURVE = 9, + GTY_ORIENTED_BSPLINE_CURVE = 10, + + GTY_FLAT_HERMITE_CURVE = 12, + GTY_ROUND_HERMITE_CURVE = 13, + GTY_ORIENTED_HERMITE_CURVE = 14, + + GTY_FLAT_CATMULL_ROM_CURVE = 16, + GTY_ROUND_CATMULL_ROM_CURVE = 17, + GTY_ORIENTED_CATMULL_ROM_CURVE = 18, + + GTY_TRIANGLE_MESH = 20, + GTY_QUAD_MESH = 21, + GTY_GRID_MESH = 22, + GTY_SUBDIV_MESH = 23, + + GTY_SPHERE_POINT = 25, + GTY_DISC_POINT = 26, + GTY_ORIENTED_DISC_POINT = 27, + + GTY_USER_GEOMETRY = 29, + GTY_INSTANCE_CHEAP = 30, + GTY_INSTANCE_EXPENSIVE = 31, + GTY_END = 32, + + GTY_BASIS_LINEAR = 0, + GTY_BASIS_BEZIER = 4, + GTY_BASIS_BSPLINE = 8, + GTY_BASIS_HERMITE = 12, + GTY_BASIS_CATMULL_ROM = 16, + GTY_BASIS_MASK = 28, + + GTY_SUBTYPE_FLAT_CURVE = 0, + GTY_SUBTYPE_ROUND_CURVE = 1, + GTY_SUBTYPE_ORIENTED_CURVE = 2, + GTY_SUBTYPE_MASK = 3, + }; + + enum GSubType + { + GTY_SUBTYPE_DEFAULT= 0, + GTY_SUBTYPE_INSTANCE_LINEAR = 0, + GTY_SUBTYPE_INSTANCE_QUATERNION = 1 + }; + + enum GTypeMask + { + MTY_FLAT_LINEAR_CURVE = 1ul << GTY_FLAT_LINEAR_CURVE, + MTY_ROUND_LINEAR_CURVE = 1ul << GTY_ROUND_LINEAR_CURVE, + MTY_CONE_LINEAR_CURVE = 1ul << GTY_CONE_LINEAR_CURVE, + MTY_ORIENTED_LINEAR_CURVE = 1ul << GTY_ORIENTED_LINEAR_CURVE, + + MTY_FLAT_BEZIER_CURVE = 1ul << GTY_FLAT_BEZIER_CURVE, + MTY_ROUND_BEZIER_CURVE = 1ul << GTY_ROUND_BEZIER_CURVE, + MTY_ORIENTED_BEZIER_CURVE = 1ul << GTY_ORIENTED_BEZIER_CURVE, + + MTY_FLAT_BSPLINE_CURVE = 1ul << GTY_FLAT_BSPLINE_CURVE, + MTY_ROUND_BSPLINE_CURVE = 1ul << GTY_ROUND_BSPLINE_CURVE, + MTY_ORIENTED_BSPLINE_CURVE = 1ul << GTY_ORIENTED_BSPLINE_CURVE, + + MTY_FLAT_HERMITE_CURVE = 1ul << GTY_FLAT_HERMITE_CURVE, + MTY_ROUND_HERMITE_CURVE = 1ul << GTY_ROUND_HERMITE_CURVE, + MTY_ORIENTED_HERMITE_CURVE = 1ul << GTY_ORIENTED_HERMITE_CURVE, + + MTY_FLAT_CATMULL_ROM_CURVE = 1ul << GTY_FLAT_CATMULL_ROM_CURVE, + MTY_ROUND_CATMULL_ROM_CURVE = 1ul << GTY_ROUND_CATMULL_ROM_CURVE, + MTY_ORIENTED_CATMULL_ROM_CURVE = 1ul << GTY_ORIENTED_CATMULL_ROM_CURVE, + + MTY_CURVE2 = MTY_FLAT_LINEAR_CURVE | MTY_ROUND_LINEAR_CURVE | MTY_CONE_LINEAR_CURVE | MTY_ORIENTED_LINEAR_CURVE, + + MTY_CURVE4 = MTY_FLAT_BEZIER_CURVE | MTY_ROUND_BEZIER_CURVE | MTY_ORIENTED_BEZIER_CURVE | + MTY_FLAT_BSPLINE_CURVE | MTY_ROUND_BSPLINE_CURVE | MTY_ORIENTED_BSPLINE_CURVE | + MTY_FLAT_HERMITE_CURVE | MTY_ROUND_HERMITE_CURVE | MTY_ORIENTED_HERMITE_CURVE | + MTY_FLAT_CATMULL_ROM_CURVE | MTY_ROUND_CATMULL_ROM_CURVE | MTY_ORIENTED_CATMULL_ROM_CURVE, + + MTY_SPHERE_POINT = 1ul << GTY_SPHERE_POINT, + MTY_DISC_POINT = 1ul << GTY_DISC_POINT, + MTY_ORIENTED_DISC_POINT = 1ul << GTY_ORIENTED_DISC_POINT, + + MTY_POINTS = MTY_SPHERE_POINT | MTY_DISC_POINT | MTY_ORIENTED_DISC_POINT, + + MTY_CURVES = MTY_CURVE2 | MTY_CURVE4 | MTY_POINTS, + + MTY_TRIANGLE_MESH = 1ul << GTY_TRIANGLE_MESH, + MTY_QUAD_MESH = 1ul << GTY_QUAD_MESH, + MTY_GRID_MESH = 1ul << GTY_GRID_MESH, + MTY_SUBDIV_MESH = 1ul << GTY_SUBDIV_MESH, + MTY_USER_GEOMETRY = 1ul << GTY_USER_GEOMETRY, + + MTY_INSTANCE_CHEAP = 1ul << GTY_INSTANCE_CHEAP, + MTY_INSTANCE_EXPENSIVE = 1ul << GTY_INSTANCE_EXPENSIVE, + MTY_INSTANCE = MTY_INSTANCE_CHEAP | MTY_INSTANCE_EXPENSIVE + }; + + static const char* gtype_names[GTY_END]; + + enum class State : unsigned { + MODIFIED = 0, + COMMITTED = 1, + }; + + public: + + /*! 
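a small sketch of how the type bits above combine follows. */

    // Editor's sketch (illustrative, not part of the Embree sources): every GTY_* value
    // owns one bit and the MTY_* masks are unions of those bits, so "is this geometry
    // one of these kinds?" becomes a single AND -- the same test that
    // setIntersectionFilterFunctionN() in geometry.cpp performs. The pattern in minimal
    // form, with hypothetical names:
    enum SketchType { SKETCH_TRIANGLE = 0, SKETCH_QUAD = 1, SKETCH_CURVE = 2 };
    enum SketchMask {
      SKETCH_MTY_TRIANGLE = 1u << SKETCH_TRIANGLE,
      SKETCH_MTY_QUAD     = 1u << SKETCH_QUAD,
      SKETCH_MTY_CURVE    = 1u << SKETCH_CURVE,
      SKETCH_MTY_SURFACES = SKETCH_MTY_TRIANGLE | SKETCH_MTY_QUAD  // analogous to MTY_CURVES above
    };
    static bool isSurfaceSketch(SketchType t) { return ((1u << t) & SKETCH_MTY_SURFACES) != 0; }

    /*!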
Geometry constructor */ + Geometry (Device* device, GType gtype, unsigned int numPrimitives, unsigned int numTimeSteps); + + /*! Geometry destructor */ + virtual ~Geometry(); + + public: + + /*! tests if geometry is enabled */ + __forceinline bool isEnabled() const { return enabled; } + + /*! tests if geometry is disabled */ + __forceinline bool isDisabled() const { return !isEnabled(); } + + /*! tests if that geometry has some filter function set */ + __forceinline bool hasFilterFunctions () const { + return (intersectionFilterN != nullptr) || (occlusionFilterN != nullptr); + } + + /*! returns geometry type */ + __forceinline GType getType() const { return gtype; } + + /*! returns curve type */ + __forceinline GType getCurveType() const { return (GType)(gtype & GTY_SUBTYPE_MASK); } + + /*! returns curve basis */ + __forceinline GType getCurveBasis() const { return (GType)(gtype & GTY_BASIS_MASK); } + + /*! returns geometry type mask */ + __forceinline GTypeMask getTypeMask() const { return (GTypeMask)(1 << gtype); } + + /*! returns number of primitives */ + __forceinline size_t size() const { return numPrimitives; } + + /*! sets the number of primitives */ + virtual void setNumPrimitives(unsigned int numPrimitives_in); + + /*! sets number of time steps */ + virtual void setNumTimeSteps (unsigned int numTimeSteps_in); + + /*! sets motion blur time range */ + void setTimeRange (const BBox1f range); + + /*! sets number of vertex attributes */ + virtual void setVertexAttributeCount (unsigned int N) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! sets number of topologies */ + virtual void setTopologyCount (unsigned int N) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! sets the build quality */ + void setBuildQuality(RTCBuildQuality quality_in) + { + this->quality = quality_in; + Geometry::update(); + } + + /* calculate time segment itime and fractional time ftime */ + __forceinline int timeSegment(float time, float& ftime) const { + return getTimeSegment(time,time_range.lower,time_range.upper,fnumTimeSegments,ftime); + } + + template + __forceinline vint timeSegment(const vfloat& time, vfloat& ftime) const { + return getTimeSegment(time,vfloat(time_range.lower),vfloat(time_range.upper),vfloat(fnumTimeSegments),ftime); + } + + /* calculate overlapping time segment range */ + __forceinline range timeSegmentRange(const BBox1f& range) const { + return getTimeSegmentRange(range,time_range,fnumTimeSegments); + } + + /* returns time that corresponds to time step */ + __forceinline float timeStep(const int i) const { + assert(i>=0 && i<(int)numTimeSteps); + return time_range.lower + time_range.size()*float(i)/fnumTimeSegments; + } + + /*! for all geometries */ + public: + + /*! Enable geometry. */ + virtual void enable(); + + /*! Update geometry. */ + void update(); + + /*! commit of geometry */ + virtual void commit(); + + /*! Update geometry buffer. */ + virtual void updateBuffer(RTCBufferType type, unsigned int slot) { + update(); // update everything for geometries not supporting this call + } + + /*! Disable geometry. */ + virtual void disable(); + + /*! Verify the geometry */ + virtual bool verify() { return true; } + + /*! called before every build */ + virtual void preCommit(); + + /*! 
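a spelled-out version of the time-step mapping above follows. */

    // Editor's note (illustrative, not part of the Embree sources): timeStep(i) above
    // spreads the numTimeSteps key frames uniformly over time_range; with
    // time_range = [0,1] and numTimeSteps = 3 (fnumTimeSegments = 2) the vertex buffers
    // correspond to times 0.0, 0.5 and 1.0. The same formula, standalone:
    static float timeStepSketch(float lower, float upper, int i, float fnumSegments) {
      return lower + (upper - lower) * float(i) / fnumSegments;  // == timeStep(i) for time_range [lower,upper]
    }

    /*!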
called after every build */ + virtual void postCommit(); + + virtual void addElementsToCount (GeometryCounts & counts) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + }; + + /*! sets constant tessellation rate for the geometry */ + virtual void setTessellationRate(float N) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Sets the maximal curve radius scale allowed by min-width feature. */ + virtual void setMaxRadiusScale(float s) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Set user data pointer. */ + virtual void setUserData(void* ptr); + + /*! Get user data pointer. */ + __forceinline void* getUserData() const { + return userPtr; + } + + /*! interpolates user data to the specified u/v location */ + virtual void interpolate(const RTCInterpolateArguments* const args) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! interpolates user data to the specified u/v locations */ + virtual void interpolateN(const RTCInterpolateNArguments* const args); + + /* point query api */ + bool pointQuery(PointQuery* query, PointQueryContext* context); + + /*! for subdivision surfaces only */ + public: + virtual void setSubdivisionMode (unsigned topologyID, RTCSubdivisionMode mode) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + virtual void setVertexAttributeTopology(unsigned int vertexBufferSlot, unsigned int indexBufferSlot) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Set displacement function. */ + virtual void setDisplacementFunction (RTCDisplacementFunctionN filter) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + virtual unsigned int getFirstHalfEdge(unsigned int faceID) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + virtual unsigned int getFace(unsigned int edgeID) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + virtual unsigned int getNextHalfEdge(unsigned int edgeID) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + virtual unsigned int getPreviousHalfEdge(unsigned int edgeID) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + virtual unsigned int getOppositeHalfEdge(unsigned int topologyID, unsigned int edgeID) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! get fast access to first vertex buffer if applicable */ + virtual float * getCompactVertexArray () const { + return nullptr; + } + + /*! Returns the modified counter - how many times the geo has been modified */ + __forceinline unsigned int getModCounter () const { + return modCounter_; + } + + /*! for triangle meshes and bezier curves only */ + public: + + + /*! Sets ray mask. */ + virtual void setMask(unsigned mask) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Sets specified buffer. */ + virtual void setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! 
Gets specified buffer. */ + virtual void* getBuffer(RTCBufferType type, unsigned int slot) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Set intersection filter function for ray packets of size N. */ + virtual void setIntersectionFilterFunctionN (RTCFilterFunctionN filterN); + + /*! Set occlusion filter function for ray packets of size N. */ + virtual void setOcclusionFilterFunctionN (RTCFilterFunctionN filterN); + + /*! for instances only */ + public: + + /*! Sets the instanced scene */ + virtual void setInstancedScene(const Ref& scene) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Sets transformation of the instance */ + virtual void setTransform(const AffineSpace3fa& transform, unsigned int timeStep) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Sets transformation of the instance */ + virtual void setQuaternionDecomposition(const AffineSpace3ff& qd, unsigned int timeStep) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Returns the transformation of the instance */ + virtual AffineSpace3fa getTransform(float time) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! for user geometries only */ + public: + + /*! Set bounds function. */ + virtual void setBoundsFunction (RTCBoundsFunction bounds, void* userPtr) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Set intersect function for ray packets of size N. */ + virtual void setIntersectFunctionN (RTCIntersectFunctionN intersect) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Set occlusion function for ray packets of size N. */ + virtual void setOccludedFunctionN (RTCOccludedFunctionN occluded) { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation not supported for this geometry"); + } + + /*! Set point query function. */ + void setPointQueryFunction(RTCPointQueryFunction func); + + /*! 
returns number of time segments */ + __forceinline unsigned numTimeSegments () const { + return numTimeSteps-1; + } + + public: + + virtual PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"createPrimRefArray not implemented for this geometry"); + } + + virtual PrimInfo createPrimRefArrayMB(mvector& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"createPrimRefMBArray not implemented for this geometry"); + } + + virtual PrimInfoMB createPrimRefMBArray(mvector& prims, const BBox1f& t0t1, const range& r, size_t k, unsigned int geomID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"createPrimRefMBArray not implemented for this geometry"); + } + + virtual LinearSpace3fa computeAlignedSpace(const size_t primID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"computeAlignedSpace not implemented for this geometry"); + } + + virtual LinearSpace3fa computeAlignedSpaceMB(const size_t primID, const BBox1f time_range) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"computeAlignedSpace not implemented for this geometry"); + } + + virtual Vec3fa computeDirection(unsigned int primID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"computeDirection not implemented for this geometry"); + } + + virtual Vec3fa computeDirection(unsigned int primID, size_t time) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"computeDirection not implemented for this geometry"); + } + + virtual BBox3fa vbounds(size_t primID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"vbounds not implemented for this geometry"); + } + + virtual BBox3fa vbounds(const LinearSpace3fa& space, size_t primID) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"vbounds not implemented for this geometry"); + } + + virtual BBox3fa vbounds(const Vec3fa& ofs, const float scale, const float r_scale0, const LinearSpace3fa& space, size_t i, size_t itime = 0) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"vbounds not implemented for this geometry"); + } + + virtual LBBox3fa vlinearBounds(size_t primID, const BBox1f& time_range) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"vlinearBounds not implemented for this geometry"); + } + + virtual LBBox3fa vlinearBounds(const LinearSpace3fa& space, size_t primID, const BBox1f& time_range) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"vlinearBounds not implemented for this geometry"); + } + + virtual LBBox3fa vlinearBounds(const Vec3fa& ofs, const float scale, const float r_scale0, const LinearSpace3fa& space, size_t primID, const BBox1f& time_range) const { + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"vlinearBounds not implemented for this geometry"); + } + + public: + __forceinline bool hasIntersectionFilter() const { return intersectionFilterN != nullptr; } + __forceinline bool hasOcclusionFilter() const { return occlusionFilterN != nullptr; } + + public: + Device* device; //!< device this geometry belongs to + + void* userPtr; //!< user pointer + unsigned int numPrimitives; //!< number of primitives of this geometry + + unsigned int numTimeSteps; //!< number of time steps + float fnumTimeSegments; //!< number of time segments (precalculation) + BBox1f time_range; //!< motion blur time range + + unsigned int mask; //!< for masking out geometry + unsigned int modCounter_ = 1; //!< counter for every modification - used to rebuild scenes when geo is modified + + struct { + GType gtype : 
8; //!< geometry type + GSubType gsubtype : 8; //!< geometry subtype + RTCBuildQuality quality : 3; //!< build quality for geometry + unsigned state : 2; + bool enabled : 1; //!< true if geometry is enabled + }; + + RTCFilterFunctionN intersectionFilterN; + RTCFilterFunctionN occlusionFilterN; + RTCPointQueryFunction pointQueryFunc; + }; +} diff --git a/thirdparty/embree/kernels/common/hit.h b/thirdparty/embree/kernels/common/hit.h new file mode 100644 index 000000000000..32a198cdfe4f --- /dev/null +++ b/thirdparty/embree/kernels/common/hit.h @@ -0,0 +1,114 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "ray.h" +#include "instance_stack.h" + +namespace embree +{ + /* Hit structure for K hits */ + template + struct HitK + { + /* Default construction does nothing */ + __forceinline HitK() {} + + /* Constructs a hit */ + __forceinline HitK(const RTCIntersectContext* context, const vuint& geomID, const vuint& primID, const vfloat& u, const vfloat& v, const Vec3vf& Ng) + : Ng(Ng), u(u), v(v), primID(primID), geomID(geomID) + { + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + instID[l] = RTC_INVALID_GEOMETRY_ID; + instance_id_stack::copy(context->instID, instID); + } + + /* Returns the size of the hit */ + static __forceinline size_t size() { return K; } + + public: + Vec3vf Ng; // geometry normal + vfloat u; // barycentric u coordinate of hit + vfloat v; // barycentric v coordinate of hit + vuint primID; // primitive ID + vuint geomID; // geometry ID + vuint instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // instance ID + }; + + /* Specialization for a single hit */ + template<> + struct __aligned(16) HitK<1> + { + /* Default construction does nothing */ + __forceinline HitK() {} + + /* Constructs a hit */ + __forceinline HitK(const RTCIntersectContext* context, unsigned int geomID, unsigned int primID, float u, float v, const Vec3fa& Ng) + : Ng(Ng.x,Ng.y,Ng.z), u(u), v(v), primID(primID), geomID(geomID) + { + instance_id_stack::copy(context->instID, instID); + } + + /* Returns the size of the hit */ + static __forceinline size_t size() { return 1; } + + public: + Vec3 Ng; // geometry normal + float u; // barycentric u coordinate of hit + float v; // barycentric v coordinate of hit + unsigned int primID; // primitive ID + unsigned int geomID; // geometry ID + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // instance ID + }; + + /* Shortcuts */ + typedef HitK<1> Hit; + typedef HitK<4> Hit4; + typedef HitK<8> Hit8; + typedef HitK<16> Hit16; + + /* Outputs hit to stream */ + template + __forceinline embree_ostream operator<<(embree_ostream cout, const HitK& ray) + { + cout << "{ " << embree_endl + << " Ng = " << ray.Ng << embree_endl + << " u = " << ray.u << embree_endl + << " v = " << ray.v << embree_endl + << " primID = " << ray.primID << embree_endl + << " geomID = " << ray.geomID << embree_endl + << " instID ="; + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + { + cout << " " << ray.instID[l]; + } + cout << embree_endl; + return cout << "}"; + } + + template + __forceinline void copyHitToRay(RayHit& ray, const Hit& hit) + { + ray.Ng = hit.Ng; + ray.u = hit.u; + ray.v = hit.v; + ray.primID = hit.primID; + ray.geomID = hit.geomID; + instance_id_stack::copy(hit.instID, ray.instID); + } + + template + __forceinline void copyHitToRay(const vbool &mask, RayHitK &ray, const HitK &hit) + { + vfloat::storeu(mask,&ray.Ng.x, hit.Ng.x); + vfloat::storeu(mask,&ray.Ng.y, hit.Ng.y); + 
vfloat::storeu(mask,&ray.Ng.z, hit.Ng.z); + vfloat::storeu(mask,&ray.u, hit.u); + vfloat::storeu(mask,&ray.v, hit.v); + vuint::storeu(mask,&ray.primID, hit.primID); + vuint::storeu(mask,&ray.geomID, hit.geomID); + instance_id_stack::copy(hit.instID, ray.instID, mask); + } +} diff --git a/thirdparty/embree/kernels/common/instance_stack.h b/thirdparty/embree/kernels/common/instance_stack.h new file mode 100644 index 000000000000..d7e3637f7b2b --- /dev/null +++ b/thirdparty/embree/kernels/common/instance_stack.h @@ -0,0 +1,199 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "rtcore.h" + +namespace embree { +namespace instance_id_stack { + +static_assert(RTC_MAX_INSTANCE_LEVEL_COUNT > 0, + "RTC_MAX_INSTANCE_LEVEL_COUNT must be greater than 0."); + +/******************************************************************************* + * Instance ID stack manipulation. + * This is used from the instance intersector. + ******************************************************************************/ + +/* + * Push an instance to the stack. + */ +RTC_FORCEINLINE bool push(RTCIntersectContext* context, + unsigned instanceId) +{ +#if RTC_MAX_INSTANCE_LEVEL_COUNT > 1 + const bool spaceAvailable = context->instStackSize < RTC_MAX_INSTANCE_LEVEL_COUNT; + /* We assert here because instances are silently dropped when the stack is full. + This might be quite hard to find in production. */ + assert(spaceAvailable); + if (likely(spaceAvailable)) + context->instID[context->instStackSize++] = instanceId; + return spaceAvailable; +#else + const bool spaceAvailable = (context->instID[0] == RTC_INVALID_GEOMETRY_ID); + assert(spaceAvailable); + if (likely(spaceAvailable)) + context->instID[0] = instanceId; + return spaceAvailable; +#endif +} + + +/* + * Pop the last instance pushed to the stack. + * Do not call on an empty stack. + */ +RTC_FORCEINLINE void pop(RTCIntersectContext* context) +{ + assert(context); +#if RTC_MAX_INSTANCE_LEVEL_COUNT > 1 + assert(context->instStackSize > 0); + context->instID[--context->instStackSize] = RTC_INVALID_GEOMETRY_ID; +#else + assert(context->instID[0] != RTC_INVALID_GEOMETRY_ID); + context->instID[0] = RTC_INVALID_GEOMETRY_ID; +#endif +} + +/******************************************************************************* + * Optimized instance id stack copy. + * The copy() function at the bottom of this block will either copy full + * stacks or copy only until the last valid element has been copied, depending + * on RTC_MAX_INSTANCE_LEVEL_COUNT. + ******************************************************************************/ + +/* + * Plain array assignment. This works for scalar->scalar, + * scalar->vector, and vector->vector. + */ +template +RTC_FORCEINLINE void level_copy(unsigned level, Src* src, Tgt* tgt) +{ + tgt[level] = src[level]; +} + +/* + * Masked SIMD vector->vector store. + */ +template +RTC_FORCEINLINE void level_copy(unsigned level, const vuint* src, vuint* tgt, const vbool& mask) +{ + vuint::storeu(mask, tgt + level, src[level]); +} + +/* + * Masked scalar->SIMD vector store. + */ +template +RTC_FORCEINLINE void level_copy(unsigned level, const unsigned* src, vuint* tgt, const vbool& mask) +{ + vuint::store(mask, tgt + level, src[level]); +} + +/* + * Indexed assign from vector to scalar. + */ +template +RTC_FORCEINLINE void level_copy(unsigned level, const vuint* src, unsigned* tgt, const size_t& idx) +{ + tgt[level] = src[level][idx]; +} + +/* + * Indexed assign from scalar to vector. 
+ */ +template +RTC_FORCEINLINE void level_copy(unsigned level, const unsigned* src, vuint* tgt, const size_t& idx) +{ + tgt[level][idx] = src[level]; +} + +/* + * Indexed assign from vector to vector. + */ +template +RTC_FORCEINLINE void level_copy(unsigned level, const vuint* src, vuint* tgt, const size_t& i, const size_t& j) +{ + tgt[level][j] = src[level][i]; +} + +/* + * Check if the given stack level is valid. + * These are only used for large max stack sizes. + */ +RTC_FORCEINLINE bool level_valid(unsigned level, const unsigned* stack) +{ + return stack[level] != RTC_INVALID_GEOMETRY_ID; +} +RTC_FORCEINLINE bool level_valid(unsigned level, const unsigned* stack, const size_t& /*i*/) +{ + return stack[level] != RTC_INVALID_GEOMETRY_ID; +} +template +RTC_FORCEINLINE bool level_valid(unsigned level, const unsigned* stack, const vbool& /*mask*/) +{ + return stack[level] != RTC_INVALID_GEOMETRY_ID; +} + +template +RTC_FORCEINLINE bool level_valid(unsigned level, const vuint* stack) +{ + return any(stack[level] != RTC_INVALID_GEOMETRY_ID); +} +template +RTC_FORCEINLINE bool level_valid(unsigned level, const vuint* stack, const vbool& mask) +{ + return any(mask & (stack[level] != RTC_INVALID_GEOMETRY_ID)); +} + +template +RTC_FORCEINLINE bool level_valid(unsigned level, const vuint* stack, const size_t& i) +{ + return stack[level][i] != RTC_INVALID_GEOMETRY_ID; +} +template +RTC_FORCEINLINE bool level_valid(unsigned level, const vuint* stack, const size_t& i, const size_t& /*j*/) +{ + return stack[level][i] != RTC_INVALID_GEOMETRY_ID; +} + +/* + * Copy an instance ID stack. + * + * This function automatically selects a LevelFunctor from the above Assign + * structs. + */ +template +RTC_FORCEINLINE void copy(Src src, Tgt tgt, Args&&... args) +{ +#if (RTC_MAX_INSTANCE_LEVEL_COUNT == 1) + /* + * Avoid all loops for only one level. + */ + level_copy(0, src, tgt, std::forward(args)...); + +#elif (RTC_MAX_INSTANCE_LEVEL_COUNT <= 4) + /* + * It is faster to avoid the valid test for low level counts. + * Just copy the whole stack. + */ + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + level_copy(l, src, tgt, std::forward(args)...); + +#else + /* + * For general stack sizes, it pays off to test for validity. 
+ */ + bool valid = true; + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT && valid; ++l) + { + level_copy(l, src, tgt, std::forward(args)...); + valid = level_valid(l, src, std::forward(args)...); + } +#endif +} + +} // namespace instance_id_stack +} // namespace embree + diff --git a/thirdparty/embree/kernels/common/isa.h b/thirdparty/embree/kernels/common/isa.h new file mode 100644 index 000000000000..9fd1ea58b749 --- /dev/null +++ b/thirdparty/embree/kernels/common/isa.h @@ -0,0 +1,271 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../../common/sys/platform.h" +#include "../../common/sys/sysinfo.h" + +namespace embree +{ +#define DEFINE_SYMBOL2(type,name) \ + typedef type (*name##Func)(); \ + name##Func name; + +#define DECLARE_SYMBOL2(type,name) \ + namespace sse2 { extern type name(); } \ + namespace sse42 { extern type name(); } \ + namespace avx { extern type name(); } \ + namespace avx2 { extern type name(); } \ + namespace avx512knl { extern type name(); } \ + namespace avx512skx { extern type name(); } \ + void name##_error2() { throw_RTCError(RTC_ERROR_UNKNOWN,"internal error in ISA selection for " TOSTRING(name)); } \ + type name##_error() { return type(name##_error2); } \ + type name##_zero() { return type(nullptr); } + +#define DECLARE_ISA_FUNCTION(type,symbol,args) \ + namespace sse2 { extern type symbol(args); } \ + namespace sse42 { extern type symbol(args); } \ + namespace avx { extern type symbol(args); } \ + namespace avx2 { extern type symbol(args); } \ + namespace avx512knl { extern type symbol(args); } \ + namespace avx512skx { extern type symbol(args); } \ + inline type symbol##_error(args) { throw_RTCError(RTC_ERROR_UNSUPPORTED_CPU,"function " TOSTRING(symbol) " not supported by your CPU"); } \ + typedef type (*symbol##Ty)(args); \ + +#define DEFINE_ISA_FUNCTION(type,symbol,args) \ + typedef type (*symbol##Func)(args); \ + symbol##Func symbol; + +#define ZERO_SYMBOL(features,intersector) \ + intersector = intersector##_zero; + +#define INIT_SYMBOL(features,intersector) \ + intersector = decltype(intersector)(intersector##_error); + +#define SELECT_SYMBOL_DEFAULT(features,intersector) \ + intersector = isa::intersector; + +#if defined(__SSE__) +#if !defined(EMBREE_TARGET_SIMD4) +#define EMBREE_TARGET_SIMD4 +#endif +#endif + +#if defined(EMBREE_TARGET_SSE42) +#define SELECT_SYMBOL_SSE42(features,intersector) \ + if ((features & SSE42) == SSE42) intersector = sse42::intersector; +#else +#define SELECT_SYMBOL_SSE42(features,intersector) +#endif + +#if defined(EMBREE_TARGET_AVX) || defined(__AVX__) +#if !defined(EMBREE_TARGET_SIMD8) +#define EMBREE_TARGET_SIMD8 +#endif +#if defined(__AVX__) // if default ISA is >= AVX we treat AVX target as default target +#define SELECT_SYMBOL_AVX(features,intersector) \ + if ((features & ISA) == ISA) intersector = isa::intersector; +#else +#define SELECT_SYMBOL_AVX(features,intersector) \ + if ((features & AVX) == AVX) intersector = avx::intersector; +#endif +#else +#define SELECT_SYMBOL_AVX(features,intersector) +#endif + +#if defined(EMBREE_TARGET_AVX2) +#if !defined(EMBREE_TARGET_SIMD8) +#define EMBREE_TARGET_SIMD8 +#endif +#define SELECT_SYMBOL_AVX2(features,intersector) \ + if ((features & AVX2) == AVX2) intersector = avx2::intersector; +#else +#define SELECT_SYMBOL_AVX2(features,intersector) +#endif + +#if defined(EMBREE_TARGET_AVX512KNL) +#if !defined(EMBREE_TARGET_SIMD16) +#define EMBREE_TARGET_SIMD16 +#endif +#define 
SELECT_SYMBOL_AVX512KNL(features,intersector) \ + if ((features & AVX512KNL) == AVX512KNL) intersector = avx512knl::intersector; +#else +#define SELECT_SYMBOL_AVX512KNL(features,intersector) +#endif + +#if defined(EMBREE_TARGET_AVX512SKX) +#if !defined(EMBREE_TARGET_SIMD16) +#define EMBREE_TARGET_SIMD16 +#endif +#define SELECT_SYMBOL_AVX512SKX(features,intersector) \ + if ((features & AVX512SKX) == AVX512SKX) intersector = avx512skx::intersector; +#else +#define SELECT_SYMBOL_AVX512SKX(features,intersector) +#endif + +#define SELECT_SYMBOL_DEFAULT_SSE42(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_SSE42_AVX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_SSE42_AVX_AVX2_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX_AVX2(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX_AVX512KNL_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define 
SELECT_SYMBOL_DEFAULT_AVX_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX_AVX2(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX_AVX2_AVX512SKX(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_INIT_SSE42_AVX_AVX2(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX_AVX512KNL(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX_AVX512KNL_AVX512SKX(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX_AVX2_AVX512KNL_AVX512SKX(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_INIT_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_ZERO_SSE42_AVX_AVX2_AVX512KNL_AVX512SKX(features,intersector) \ + ZERO_SYMBOL(features,intersector); \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(features,intersector) \ + SELECT_SYMBOL_DEFAULT(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_INIT_AVX512KNL_AVX512SKX(features,intersector) \ + INIT_SYMBOL(features,intersector); \ + SELECT_SYMBOL_AVX512KNL(features,intersector); \ + SELECT_SYMBOL_AVX512SKX(features,intersector); + +#define SELECT_SYMBOL_SSE42_AVX_AVX2(features,intersector) \ + SELECT_SYMBOL_SSE42(features,intersector); \ + SELECT_SYMBOL_AVX(features,intersector); \ + SELECT_SYMBOL_AVX2(features,intersector); + + struct 
VerifyMultiTargetLinking { + static __noinline int getISA(int depth = 5) { + if (depth == 0) return ISA; + else return getISA(depth-1); + } + }; + namespace sse2 { int getISA(); }; + namespace sse42 { int getISA(); }; + namespace avx { int getISA(); }; + namespace avx2 { int getISA(); }; + namespace avx512knl { int getISA(); }; + namespace avx512skx { int getISA(); }; +} diff --git a/thirdparty/embree/kernels/common/motion_derivative.h b/thirdparty/embree/kernels/common/motion_derivative.h new file mode 100644 index 000000000000..82953f0e89a4 --- /dev/null +++ b/thirdparty/embree/kernels/common/motion_derivative.h @@ -0,0 +1,325 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../../common/math/affinespace.h" +#include "../../common/math/interval.h" + +#include + +namespace embree { + +#define MOTION_DERIVATIVE_ROOT_EPSILON 1e-4f + +static void motion_derivative_coefficients(const float *p, float *coeff); + +struct MotionDerivativeCoefficients +{ + float theta; + float coeffs[3*8*7]; + + MotionDerivativeCoefficients() {} + + // xfm0 and xfm1 are interpret as quaternion decomposition + MotionDerivativeCoefficients(AffineSpace3ff const& xfm0, AffineSpace3ff const& xfm1) + { + // cosTheta of the two quaternions + const float cosTheta = min(1.f, max(-1.f, + xfm0.l.vx.w * xfm1.l.vx.w + + xfm0.l.vy.w * xfm1.l.vy.w + + xfm0.l.vz.w * xfm1.l.vz.w + + xfm0.p.w * xfm1.p.w)); + + theta = std::acos(cosTheta); + Vec4f qperp(xfm1.p.w, xfm1.l.vx.w, xfm1.l.vy.w, xfm1.l.vz.w); + if (cosTheta < 0.995f) { + // compute perpendicular quaternion + qperp.x = xfm1.p.w - cosTheta * xfm0.p.w; + qperp.y = xfm1.l.vx.w - cosTheta * xfm0.l.vx.w; + qperp.z = xfm1.l.vy.w - cosTheta * xfm0.l.vy.w; + qperp.w = xfm1.l.vz.w - cosTheta * xfm0.l.vz.w; + qperp = normalize(qperp); + } + const float p[33] = { + theta, + xfm0.l.vx.y, xfm0.l.vx.z, xfm0.l.vy.z, // translation component of xfm0 + xfm1.l.vx.y, xfm1.l.vx.z, xfm1.l.vy.z, // translation component of xfm1 + xfm0.p.w, xfm0.l.vx.w, xfm0.l.vy.w, xfm0.l.vz.w, // quaternion of xfm0 + qperp.x, qperp.y, qperp.z, qperp.w, + xfm0.l.vx.x, xfm0.l.vy.x, xfm0.l.vz.x, xfm0.p.x, // scale/skew component of xfm0 + xfm0.l.vy.y, xfm0.l.vz.y, xfm0.p.y, + xfm0.l.vz.z, xfm0.p.z, + xfm1.l.vx.x, xfm1.l.vy.x, xfm1.l.vz.x, xfm1.p.x, // scale/skew component of xfm1 + xfm1.l.vy.y, xfm1.l.vz.y, xfm1.p.y, + xfm1.l.vz.z, xfm1.p.z + }; + motion_derivative_coefficients(p, coeffs); + } +}; + +struct MotionDerivative +{ + float twoTheta; + float c[8]; + + MotionDerivative(MotionDerivativeCoefficients const& mdc, + int dim, Vec3fa const& p0, Vec3fa const& p1) + : twoTheta(2.f*mdc.theta) + { + const float p[7] = { 1, p0.x, p0.y, p0.z, p1.x, p1.y, p1.z }; + for (int i = 0; i < 8; ++i) { + c[i] = 0; + for (int j = 0; j < 7; ++j) { + c[i] += mdc.coeffs[8*7*dim + i*7 + j] * p[j]; + } + } + } + + template + struct EvalMotionDerivative + { + MotionDerivative const& md; + float offset; + + EvalMotionDerivative(MotionDerivative const& md, float offset) : md(md), offset(offset) {} + + T operator()(T const& time) const { + return md.c[0] + md.c[1] * time + + (md.c[2] + md.c[3] * time + md.c[4] * time * time) * cos(md.twoTheta * time) + + (md.c[5] + md.c[6] * time + md.c[7] * time * time) * sin(md.twoTheta * time) + + offset; + } + }; + + unsigned int findRoots( + Interval1f const& interval, + float offset, + float* roots, + unsigned int maxNumRoots) + { + unsigned int numRoots = 0; + EvalMotionDerivative eval(*this, offset); + findRoots(eval, interval, 
numRoots, roots, maxNumRoots); + return numRoots; + } + + template + static void findRoots( + + Eval const& eval, + Interval1f const& interval, + unsigned int& numRoots, + float* roots, + unsigned int maxNumRoots) + { + Interval1f range = eval(interval); + if (range.lower > 0 || range.upper < 0 || range.lower >= range.upper) return; + + const float split = 0.5f * (interval.upper + interval.lower); + if (interval.upper-interval.lower < 1e-7f || abs(split-interval.lower) < 1e-7f || abs(split-interval.upper) < 1e-7f) + { + // check if the root already exists + for (unsigned int k = 0; k < numRoots && k < maxNumRoots; ++k) { + if (abs(roots[k]-split) < MOTION_DERIVATIVE_ROOT_EPSILON) + return; + } + if (numRoots < maxNumRoots) { + roots[numRoots++] = split; + } + if (numRoots > maxNumRoots) { + printf("error: more roots than expected\n"); // FIXME: workaround for ICC2019.4 compiler bug under macOS + return; + } + return; + } + + findRoots(eval, Interval1f(interval.lower, split), numRoots, roots, maxNumRoots); + findRoots(eval, Interval1f(split, interval.upper), numRoots, roots, maxNumRoots); + } +}; + +/****************************************************************************** + * Code generated with sympy 1.4 * + * See http://www.sympy.org/ for more information. * + * * + * see * + * * + * scripts/generate_motion_derivative_coefficients.py * + * * + * for how this code is generated * + * * + ******************************************************************************/ +static void motion_derivative_coefficients(const float *p, float *coeff) +{ + coeff[0] = -p[1] + p[4] - p[7]*p[9]*p[23] + p[7]*p[9]*p[32] + p[7]*p[10]*p[21] - p[7]*p[10]*p[30] - p[8]*p[9]*p[21] + p[8]*p[9]*p[30] - p[8]*p[10]*p[23] + p[8]*p[10]*p[32] + p[9]*p[9]*p[18] - p[9]*p[9]*p[27] + p[10]*p[10]*p[18] - p[10]*p[10]*p[27] - p[11]*p[13]*p[23] + p[11]*p[13]*p[32] + p[11]*p[14]*p[21] - p[11]*p[14]*p[30] - p[12]*p[13]*p[21] + p[12]*p[13]*p[30] - p[12]*p[14]*p[23] + p[12]*p[14]*p[32] + p[13]*p[13]*p[18] - p[13]*p[13]*p[27] + p[14]*p[14]*p[18] - p[14]*p[14]*p[27] - p[18] + p[27]; + coeff[1] = 2*p[9]*p[9]*p[15] - p[9]*p[9]*p[24] + 2*p[10]*p[10]*p[15] - p[10]*p[10]*p[24] + 2*p[13]*p[13]*p[15] - p[13]*p[13]*p[24] + 2*p[14]*p[14]*p[15] - p[14]*p[14]*p[24] - 2*p[15] + p[24]; + coeff[2] = 2*p[7]*p[10]*p[19] - p[7]*p[10]*p[28] - 2*p[8]*p[9]*p[19] + p[8]*p[9]*p[28] + 2*p[9]*p[9]*p[16] - p[9]*p[9]*p[25] + 2*p[10]*p[10]*p[16] - p[10]*p[10]*p[25] + 2*p[11]*p[14]*p[19] - p[11]*p[14]*p[28] - 2*p[12]*p[13]*p[19] + p[12]*p[13]*p[28] + 2*p[13]*p[13]*p[16] - p[13]*p[13]*p[25] + 2*p[14]*p[14]*p[16] - p[14]*p[14]*p[25] - 2*p[16] + p[25]; + coeff[3] = -2*p[7]*p[9]*p[22] + p[7]*p[9]*p[31] + 2*p[7]*p[10]*p[20] - p[7]*p[10]*p[29] - 2*p[8]*p[9]*p[20] + p[8]*p[9]*p[29] - 2*p[8]*p[10]*p[22] + p[8]*p[10]*p[31] + 2*p[9]*p[9]*p[17] - p[9]*p[9]*p[26] + 2*p[10]*p[10]*p[17] - p[10]*p[10]*p[26] - 2*p[11]*p[13]*p[22] + p[11]*p[13]*p[31] + 2*p[11]*p[14]*p[20] - p[11]*p[14]*p[29] - 2*p[12]*p[13]*p[20] + p[12]*p[13]*p[29] - 2*p[12]*p[14]*p[22] + p[12]*p[14]*p[31] + 2*p[13]*p[13]*p[17] - p[13]*p[13]*p[26] + 2*p[14]*p[14]*p[17] - p[14]*p[14]*p[26] - 2*p[17] + p[26]; + coeff[4] = (-p[9]*p[9] - p[10]*p[10] - p[13]*p[13] - p[14]*p[14] + 1)*p[15]; + coeff[5] = -p[7]*p[10]*p[19] + p[8]*p[9]*p[19] - p[9]*p[9]*p[16] - p[10]*p[10]*p[16] - p[11]*p[14]*p[19] + p[12]*p[13]*p[19] - p[13]*p[13]*p[16] - p[14]*p[14]*p[16] + p[16]; + coeff[6] = p[7]*p[9]*p[22] - p[7]*p[10]*p[20] + p[8]*p[9]*p[20] + p[8]*p[10]*p[22] - p[9]*p[9]*p[17] - p[10]*p[10]*p[17] + p[11]*p[13]*p[22] - 
p[11]*p[14]*p[20] + p[12]*p[13]*p[20] + p[12]*p[14]*p[22] - p[13]*p[13]*p[17] - p[14]*p[14]*p[17] + p[17]; + coeff[7] = 0; + coeff[8] = -2*p[9]*p[9]*p[15] + 2*p[9]*p[9]*p[24] - 2*p[10]*p[10]*p[15] + 2*p[10]*p[10]*p[24] - 2*p[13]*p[13]*p[15] + 2*p[13]*p[13]*p[24] - 2*p[14]*p[14]*p[15] + 2*p[14]*p[14]*p[24] + 2*p[15] - 2*p[24]; + coeff[9] = -2*p[7]*p[10]*p[19] + 2*p[7]*p[10]*p[28] + 2*p[8]*p[9]*p[19] - 2*p[8]*p[9]*p[28] - 2*p[9]*p[9]*p[16] + 2*p[9]*p[9]*p[25] - 2*p[10]*p[10]*p[16] + 2*p[10]*p[10]*p[25] - 2*p[11]*p[14]*p[19] + 2*p[11]*p[14]*p[28] + 2*p[12]*p[13]*p[19] - 2*p[12]*p[13]*p[28] - 2*p[13]*p[13]*p[16] + 2*p[13]*p[13]*p[25] - 2*p[14]*p[14]*p[16] + 2*p[14]*p[14]*p[25] + 2*p[16] - 2*p[25]; + coeff[10] = 2*p[7]*p[9]*p[22] - 2*p[7]*p[9]*p[31] - 2*p[7]*p[10]*p[20] + 2*p[7]*p[10]*p[29] + 2*p[8]*p[9]*p[20] - 2*p[8]*p[9]*p[29] + 2*p[8]*p[10]*p[22] - 2*p[8]*p[10]*p[31] - 2*p[9]*p[9]*p[17] + 2*p[9]*p[9]*p[26] - 2*p[10]*p[10]*p[17] + 2*p[10]*p[10]*p[26] + 2*p[11]*p[13]*p[22] - 2*p[11]*p[13]*p[31] - 2*p[11]*p[14]*p[20] + 2*p[11]*p[14]*p[29] + 2*p[12]*p[13]*p[20] - 2*p[12]*p[13]*p[29] + 2*p[12]*p[14]*p[22] - 2*p[12]*p[14]*p[31] - 2*p[13]*p[13]*p[17] + 2*p[13]*p[13]*p[26] - 2*p[14]*p[14]*p[17] + 2*p[14]*p[14]*p[26] + 2*p[17] - 2*p[26]; + coeff[11] = 2*p[9]*p[9]*p[15] - 2*p[9]*p[9]*p[24] + 2*p[10]*p[10]*p[15] - 2*p[10]*p[10]*p[24] + 2*p[13]*p[13]*p[15] - 2*p[13]*p[13]*p[24] + 2*p[14]*p[14]*p[15] - 2*p[14]*p[14]*p[24] - 2*p[15] + 2*p[24]; + coeff[12] = 2*p[7]*p[10]*p[19] - 2*p[7]*p[10]*p[28] - 2*p[8]*p[9]*p[19] + 2*p[8]*p[9]*p[28] + 2*p[9]*p[9]*p[16] - 2*p[9]*p[9]*p[25] + 2*p[10]*p[10]*p[16] - 2*p[10]*p[10]*p[25] + 2*p[11]*p[14]*p[19] - 2*p[11]*p[14]*p[28] - 2*p[12]*p[13]*p[19] + 2*p[12]*p[13]*p[28] + 2*p[13]*p[13]*p[16] - 2*p[13]*p[13]*p[25] + 2*p[14]*p[14]*p[16] - 2*p[14]*p[14]*p[25] - 2*p[16] + 2*p[25]; + coeff[13] = -2*p[7]*p[9]*p[22] + 2*p[7]*p[9]*p[31] + 2*p[7]*p[10]*p[20] - 2*p[7]*p[10]*p[29] - 2*p[8]*p[9]*p[20] + 2*p[8]*p[9]*p[29] - 2*p[8]*p[10]*p[22] + 2*p[8]*p[10]*p[31] + 2*p[9]*p[9]*p[17] - 2*p[9]*p[9]*p[26] + 2*p[10]*p[10]*p[17] - 2*p[10]*p[10]*p[26] - 2*p[11]*p[13]*p[22] + 2*p[11]*p[13]*p[31] + 2*p[11]*p[14]*p[20] - 2*p[11]*p[14]*p[29] - 2*p[12]*p[13]*p[20] + 2*p[12]*p[13]*p[29] - 2*p[12]*p[14]*p[22] + 2*p[12]*p[14]*p[31] + 2*p[13]*p[13]*p[17] - 2*p[13]*p[13]*p[26] + 2*p[14]*p[14]*p[17] - 2*p[14]*p[14]*p[26] - 2*p[17] + 2*p[26]; + coeff[14] = 2*p[0]*p[7]*p[11]*p[18] + 2*p[0]*p[7]*p[13]*p[23] - 2*p[0]*p[7]*p[14]*p[21] + 2*p[0]*p[8]*p[12]*p[18] + 2*p[0]*p[8]*p[13]*p[21] + 2*p[0]*p[8]*p[14]*p[23] + 2*p[0]*p[9]*p[11]*p[23] + 2*p[0]*p[9]*p[12]*p[21] - 2*p[0]*p[9]*p[13]*p[18] - 2*p[0]*p[10]*p[11]*p[21] + 2*p[0]*p[10]*p[12]*p[23] - 2*p[0]*p[10]*p[14]*p[18] - p[7]*p[9]*p[23] + p[7]*p[9]*p[32] + p[7]*p[10]*p[21] - p[7]*p[10]*p[30] - p[8]*p[9]*p[21] + p[8]*p[9]*p[30] - p[8]*p[10]*p[23] + p[8]*p[10]*p[32] + p[9]*p[9]*p[18] - p[9]*p[9]*p[27] + p[10]*p[10]*p[18] - p[10]*p[10]*p[27] + p[11]*p[13]*p[23] - p[11]*p[13]*p[32] - p[11]*p[14]*p[21] + p[11]*p[14]*p[30] + p[12]*p[13]*p[21] - p[12]*p[13]*p[30] + p[12]*p[14]*p[23] - p[12]*p[14]*p[32] - p[13]*p[13]*p[18] + p[13]*p[13]*p[27] - p[14]*p[14]*p[18] + p[14]*p[14]*p[27]; + coeff[15] = 2*p[0]*p[7]*p[11]*p[15] + 2*p[0]*p[8]*p[12]*p[15] - 2*p[0]*p[9]*p[13]*p[15] - 2*p[0]*p[10]*p[14]*p[15] + 2*p[9]*p[9]*p[15] - p[9]*p[9]*p[24] + 2*p[10]*p[10]*p[15] - p[10]*p[10]*p[24] - 2*p[13]*p[13]*p[15] + p[13]*p[13]*p[24] - 2*p[14]*p[14]*p[15] + p[14]*p[14]*p[24]; + coeff[16] = 2*p[0]*p[7]*p[11]*p[16] - 2*p[0]*p[7]*p[14]*p[19] + 2*p[0]*p[8]*p[12]*p[16] + 
2*p[0]*p[8]*p[13]*p[19] + 2*p[0]*p[9]*p[12]*p[19] - 2*p[0]*p[9]*p[13]*p[16] - 2*p[0]*p[10]*p[11]*p[19] - 2*p[0]*p[10]*p[14]*p[16] + 2*p[7]*p[10]*p[19] - p[7]*p[10]*p[28] - 2*p[8]*p[9]*p[19] + p[8]*p[9]*p[28] + 2*p[9]*p[9]*p[16] - p[9]*p[9]*p[25] + 2*p[10]*p[10]*p[16] - p[10]*p[10]*p[25] - 2*p[11]*p[14]*p[19] + p[11]*p[14]*p[28] + 2*p[12]*p[13]*p[19] - p[12]*p[13]*p[28] - 2*p[13]*p[13]*p[16] + p[13]*p[13]*p[25] - 2*p[14]*p[14]*p[16] + p[14]*p[14]*p[25]; + coeff[17] = 2*p[0]*p[7]*p[11]*p[17] + 2*p[0]*p[7]*p[13]*p[22] - 2*p[0]*p[7]*p[14]*p[20] + 2*p[0]*p[8]*p[12]*p[17] + 2*p[0]*p[8]*p[13]*p[20] + 2*p[0]*p[8]*p[14]*p[22] + 2*p[0]*p[9]*p[11]*p[22] + 2*p[0]*p[9]*p[12]*p[20] - 2*p[0]*p[9]*p[13]*p[17] - 2*p[0]*p[10]*p[11]*p[20] + 2*p[0]*p[10]*p[12]*p[22] - 2*p[0]*p[10]*p[14]*p[17] - 2*p[7]*p[9]*p[22] + p[7]*p[9]*p[31] + 2*p[7]*p[10]*p[20] - p[7]*p[10]*p[29] - 2*p[8]*p[9]*p[20] + p[8]*p[9]*p[29] - 2*p[8]*p[10]*p[22] + p[8]*p[10]*p[31] + 2*p[9]*p[9]*p[17] - p[9]*p[9]*p[26] + 2*p[10]*p[10]*p[17] - p[10]*p[10]*p[26] + 2*p[11]*p[13]*p[22] - p[11]*p[13]*p[31] - 2*p[11]*p[14]*p[20] + p[11]*p[14]*p[29] + 2*p[12]*p[13]*p[20] - p[12]*p[13]*p[29] + 2*p[12]*p[14]*p[22] - p[12]*p[14]*p[31] - 2*p[13]*p[13]*p[17] + p[13]*p[13]*p[26] - 2*p[14]*p[14]*p[17] + p[14]*p[14]*p[26]; + coeff[18] = (-p[9]*p[9] - p[10]*p[10] + p[13]*p[13] + p[14]*p[14])*p[15]; + coeff[19] = -p[7]*p[10]*p[19] + p[8]*p[9]*p[19] - p[9]*p[9]*p[16] - p[10]*p[10]*p[16] + p[11]*p[14]*p[19] - p[12]*p[13]*p[19] + p[13]*p[13]*p[16] + p[14]*p[14]*p[16]; + coeff[20] = p[7]*p[9]*p[22] - p[7]*p[10]*p[20] + p[8]*p[9]*p[20] + p[8]*p[10]*p[22] - p[9]*p[9]*p[17] - p[10]*p[10]*p[17] - p[11]*p[13]*p[22] + p[11]*p[14]*p[20] - p[12]*p[13]*p[20] - p[12]*p[14]*p[22] + p[13]*p[13]*p[17] + p[14]*p[14]*p[17]; + coeff[21] = 2*(-p[7]*p[11]*p[18] + p[7]*p[11]*p[27] - p[7]*p[13]*p[23] + p[7]*p[13]*p[32] + p[7]*p[14]*p[21] - p[7]*p[14]*p[30] - p[8]*p[12]*p[18] + p[8]*p[12]*p[27] - p[8]*p[13]*p[21] + p[8]*p[13]*p[30] - p[8]*p[14]*p[23] + p[8]*p[14]*p[32] - p[9]*p[11]*p[23] + p[9]*p[11]*p[32] - p[9]*p[12]*p[21] + p[9]*p[12]*p[30] + p[9]*p[13]*p[18] - p[9]*p[13]*p[27] + p[10]*p[11]*p[21] - p[10]*p[11]*p[30] - p[10]*p[12]*p[23] + p[10]*p[12]*p[32] + p[10]*p[14]*p[18] - p[10]*p[14]*p[27])*p[0]; + coeff[22] = -4*p[0]*p[7]*p[11]*p[15] + 2*p[0]*p[7]*p[11]*p[24] - 4*p[0]*p[8]*p[12]*p[15] + 2*p[0]*p[8]*p[12]*p[24] + 4*p[0]*p[9]*p[13]*p[15] - 2*p[0]*p[9]*p[13]*p[24] + 4*p[0]*p[10]*p[14]*p[15] - 2*p[0]*p[10]*p[14]*p[24] - 2*p[9]*p[9]*p[15] + 2*p[9]*p[9]*p[24] - 2*p[10]*p[10]*p[15] + 2*p[10]*p[10]*p[24] + 2*p[13]*p[13]*p[15] - 2*p[13]*p[13]*p[24] + 2*p[14]*p[14]*p[15] - 2*p[14]*p[14]*p[24]; + coeff[23] = -4*p[0]*p[7]*p[11]*p[16] + 2*p[0]*p[7]*p[11]*p[25] + 4*p[0]*p[7]*p[14]*p[19] - 2*p[0]*p[7]*p[14]*p[28] - 4*p[0]*p[8]*p[12]*p[16] + 2*p[0]*p[8]*p[12]*p[25] - 4*p[0]*p[8]*p[13]*p[19] + 2*p[0]*p[8]*p[13]*p[28] - 4*p[0]*p[9]*p[12]*p[19] + 2*p[0]*p[9]*p[12]*p[28] + 4*p[0]*p[9]*p[13]*p[16] - 2*p[0]*p[9]*p[13]*p[25] + 4*p[0]*p[10]*p[11]*p[19] - 2*p[0]*p[10]*p[11]*p[28] + 4*p[0]*p[10]*p[14]*p[16] - 2*p[0]*p[10]*p[14]*p[25] - 2*p[7]*p[10]*p[19] + 2*p[7]*p[10]*p[28] + 2*p[8]*p[9]*p[19] - 2*p[8]*p[9]*p[28] - 2*p[9]*p[9]*p[16] + 2*p[9]*p[9]*p[25] - 2*p[10]*p[10]*p[16] + 2*p[10]*p[10]*p[25] + 2*p[11]*p[14]*p[19] - 2*p[11]*p[14]*p[28] - 2*p[12]*p[13]*p[19] + 2*p[12]*p[13]*p[28] + 2*p[13]*p[13]*p[16] - 2*p[13]*p[13]*p[25] + 2*p[14]*p[14]*p[16] - 2*p[14]*p[14]*p[25]; + coeff[24] = -4*p[0]*p[7]*p[11]*p[17] + 2*p[0]*p[7]*p[11]*p[26] - 4*p[0]*p[7]*p[13]*p[22] + 2*p[0]*p[7]*p[13]*p[31] + 
4*p[0]*p[7]*p[14]*p[20] - 2*p[0]*p[7]*p[14]*p[29] - 4*p[0]*p[8]*p[12]*p[17] + 2*p[0]*p[8]*p[12]*p[26] - 4*p[0]*p[8]*p[13]*p[20] + 2*p[0]*p[8]*p[13]*p[29] - 4*p[0]*p[8]*p[14]*p[22] + 2*p[0]*p[8]*p[14]*p[31] - 4*p[0]*p[9]*p[11]*p[22] + 2*p[0]*p[9]*p[11]*p[31] - 4*p[0]*p[9]*p[12]*p[20] + 2*p[0]*p[9]*p[12]*p[29] + 4*p[0]*p[9]*p[13]*p[17] - 2*p[0]*p[9]*p[13]*p[26] + 4*p[0]*p[10]*p[11]*p[20] - 2*p[0]*p[10]*p[11]*p[29] - 4*p[0]*p[10]*p[12]*p[22] + 2*p[0]*p[10]*p[12]*p[31] + 4*p[0]*p[10]*p[14]*p[17] - 2*p[0]*p[10]*p[14]*p[26] + 2*p[7]*p[9]*p[22] - 2*p[7]*p[9]*p[31] - 2*p[7]*p[10]*p[20] + 2*p[7]*p[10]*p[29] + 2*p[8]*p[9]*p[20] - 2*p[8]*p[9]*p[29] + 2*p[8]*p[10]*p[22] - 2*p[8]*p[10]*p[31] - 2*p[9]*p[9]*p[17] + 2*p[9]*p[9]*p[26] - 2*p[10]*p[10]*p[17] + 2*p[10]*p[10]*p[26] - 2*p[11]*p[13]*p[22] + 2*p[11]*p[13]*p[31] + 2*p[11]*p[14]*p[20] - 2*p[11]*p[14]*p[29] - 2*p[12]*p[13]*p[20] + 2*p[12]*p[13]*p[29] - 2*p[12]*p[14]*p[22] + 2*p[12]*p[14]*p[31] + 2*p[13]*p[13]*p[17] - 2*p[13]*p[13]*p[26] + 2*p[14]*p[14]*p[17] - 2*p[14]*p[14]*p[26]; + coeff[25] = 2*p[0]*p[7]*p[11]*p[15] + 2*p[0]*p[8]*p[12]*p[15] - 2*p[0]*p[9]*p[13]*p[15] - 2*p[0]*p[10]*p[14]*p[15] + 2*p[9]*p[9]*p[15] - 2*p[9]*p[9]*p[24] + 2*p[10]*p[10]*p[15] - 2*p[10]*p[10]*p[24] - 2*p[13]*p[13]*p[15] + 2*p[13]*p[13]*p[24] - 2*p[14]*p[14]*p[15] + 2*p[14]*p[14]*p[24]; + coeff[26] = 2*p[0]*p[7]*p[11]*p[16] - 2*p[0]*p[7]*p[14]*p[19] + 2*p[0]*p[8]*p[12]*p[16] + 2*p[0]*p[8]*p[13]*p[19] + 2*p[0]*p[9]*p[12]*p[19] - 2*p[0]*p[9]*p[13]*p[16] - 2*p[0]*p[10]*p[11]*p[19] - 2*p[0]*p[10]*p[14]*p[16] + 2*p[7]*p[10]*p[19] - 2*p[7]*p[10]*p[28] - 2*p[8]*p[9]*p[19] + 2*p[8]*p[9]*p[28] + 2*p[9]*p[9]*p[16] - 2*p[9]*p[9]*p[25] + 2*p[10]*p[10]*p[16] - 2*p[10]*p[10]*p[25] - 2*p[11]*p[14]*p[19] + 2*p[11]*p[14]*p[28] + 2*p[12]*p[13]*p[19] - 2*p[12]*p[13]*p[28] - 2*p[13]*p[13]*p[16] + 2*p[13]*p[13]*p[25] - 2*p[14]*p[14]*p[16] + 2*p[14]*p[14]*p[25]; + coeff[27] = 2*p[0]*p[7]*p[11]*p[17] + 2*p[0]*p[7]*p[13]*p[22] - 2*p[0]*p[7]*p[14]*p[20] + 2*p[0]*p[8]*p[12]*p[17] + 2*p[0]*p[8]*p[13]*p[20] + 2*p[0]*p[8]*p[14]*p[22] + 2*p[0]*p[9]*p[11]*p[22] + 2*p[0]*p[9]*p[12]*p[20] - 2*p[0]*p[9]*p[13]*p[17] - 2*p[0]*p[10]*p[11]*p[20] + 2*p[0]*p[10]*p[12]*p[22] - 2*p[0]*p[10]*p[14]*p[17] - 2*p[7]*p[9]*p[22] + 2*p[7]*p[9]*p[31] + 2*p[7]*p[10]*p[20] - 2*p[7]*p[10]*p[29] - 2*p[8]*p[9]*p[20] + 2*p[8]*p[9]*p[29] - 2*p[8]*p[10]*p[22] + 2*p[8]*p[10]*p[31] + 2*p[9]*p[9]*p[17] - 2*p[9]*p[9]*p[26] + 2*p[10]*p[10]*p[17] - 2*p[10]*p[10]*p[26] + 2*p[11]*p[13]*p[22] - 2*p[11]*p[13]*p[31] - 2*p[11]*p[14]*p[20] + 2*p[11]*p[14]*p[29] + 2*p[12]*p[13]*p[20] - 2*p[12]*p[13]*p[29] + 2*p[12]*p[14]*p[22] - 2*p[12]*p[14]*p[31] - 2*p[13]*p[13]*p[17] + 2*p[13]*p[13]*p[26] - 2*p[14]*p[14]*p[17] + 2*p[14]*p[14]*p[26]; + coeff[28] = 0; + coeff[29] = 2*(p[7]*p[11]*p[15] - p[7]*p[11]*p[24] + p[8]*p[12]*p[15] - p[8]*p[12]*p[24] - p[9]*p[13]*p[15] + p[9]*p[13]*p[24] - p[10]*p[14]*p[15] + p[10]*p[14]*p[24])*p[0]; + coeff[30] = 2*(p[7]*p[11]*p[16] - p[7]*p[11]*p[25] - p[7]*p[14]*p[19] + p[7]*p[14]*p[28] + p[8]*p[12]*p[16] - p[8]*p[12]*p[25] + p[8]*p[13]*p[19] - p[8]*p[13]*p[28] + p[9]*p[12]*p[19] - p[9]*p[12]*p[28] - p[9]*p[13]*p[16] + p[9]*p[13]*p[25] - p[10]*p[11]*p[19] + p[10]*p[11]*p[28] - p[10]*p[14]*p[16] + p[10]*p[14]*p[25])*p[0]; + coeff[31] = 2*(p[7]*p[11]*p[17] - p[7]*p[11]*p[26] + p[7]*p[13]*p[22] - p[7]*p[13]*p[31] - p[7]*p[14]*p[20] + p[7]*p[14]*p[29] + p[8]*p[12]*p[17] - p[8]*p[12]*p[26] + p[8]*p[13]*p[20] - p[8]*p[13]*p[29] + p[8]*p[14]*p[22] - p[8]*p[14]*p[31] + p[9]*p[11]*p[22] - p[9]*p[11]*p[31] + p[9]*p[12]*p[20] 
- p[9]*p[12]*p[29] - p[9]*p[13]*p[17] + p[9]*p[13]*p[26] - p[10]*p[11]*p[20] + p[10]*p[11]*p[29] + p[10]*p[12]*p[22] - p[10]*p[12]*p[31] - p[10]*p[14]*p[17] + p[10]*p[14]*p[26])*p[0]; + coeff[32] = 2*(-p[7]*p[11]*p[15] + p[7]*p[11]*p[24] - p[8]*p[12]*p[15] + p[8]*p[12]*p[24] + p[9]*p[13]*p[15] - p[9]*p[13]*p[24] + p[10]*p[14]*p[15] - p[10]*p[14]*p[24])*p[0]; + coeff[33] = 2*(-p[7]*p[11]*p[16] + p[7]*p[11]*p[25] + p[7]*p[14]*p[19] - p[7]*p[14]*p[28] - p[8]*p[12]*p[16] + p[8]*p[12]*p[25] - p[8]*p[13]*p[19] + p[8]*p[13]*p[28] - p[9]*p[12]*p[19] + p[9]*p[12]*p[28] + p[9]*p[13]*p[16] - p[9]*p[13]*p[25] + p[10]*p[11]*p[19] - p[10]*p[11]*p[28] + p[10]*p[14]*p[16] - p[10]*p[14]*p[25])*p[0]; + coeff[34] = 2*(-p[7]*p[11]*p[17] + p[7]*p[11]*p[26] - p[7]*p[13]*p[22] + p[7]*p[13]*p[31] + p[7]*p[14]*p[20] - p[7]*p[14]*p[29] - p[8]*p[12]*p[17] + p[8]*p[12]*p[26] - p[8]*p[13]*p[20] + p[8]*p[13]*p[29] - p[8]*p[14]*p[22] + p[8]*p[14]*p[31] - p[9]*p[11]*p[22] + p[9]*p[11]*p[31] - p[9]*p[12]*p[20] + p[9]*p[12]*p[29] + p[9]*p[13]*p[17] - p[9]*p[13]*p[26] + p[10]*p[11]*p[20] - p[10]*p[11]*p[29] - p[10]*p[12]*p[22] + p[10]*p[12]*p[31] + p[10]*p[14]*p[17] - p[10]*p[14]*p[26])*p[0]; + coeff[35] = -2*p[0]*p[7]*p[9]*p[23] + 2*p[0]*p[7]*p[10]*p[21] - 2*p[0]*p[8]*p[9]*p[21] - 2*p[0]*p[8]*p[10]*p[23] + 2*p[0]*p[9]*p[9]*p[18] + 2*p[0]*p[10]*p[10]*p[18] + 2*p[0]*p[11]*p[13]*p[23] - 2*p[0]*p[11]*p[14]*p[21] + 2*p[0]*p[12]*p[13]*p[21] + 2*p[0]*p[12]*p[14]*p[23] - 2*p[0]*p[13]*p[13]*p[18] - 2*p[0]*p[14]*p[14]*p[18] - p[7]*p[11]*p[18] + p[7]*p[11]*p[27] - p[7]*p[13]*p[23] + p[7]*p[13]*p[32] + p[7]*p[14]*p[21] - p[7]*p[14]*p[30] - p[8]*p[12]*p[18] + p[8]*p[12]*p[27] - p[8]*p[13]*p[21] + p[8]*p[13]*p[30] - p[8]*p[14]*p[23] + p[8]*p[14]*p[32] - p[9]*p[11]*p[23] + p[9]*p[11]*p[32] - p[9]*p[12]*p[21] + p[9]*p[12]*p[30] + p[9]*p[13]*p[18] - p[9]*p[13]*p[27] + p[10]*p[11]*p[21] - p[10]*p[11]*p[30] - p[10]*p[12]*p[23] + p[10]*p[12]*p[32] + p[10]*p[14]*p[18] - p[10]*p[14]*p[27]; + coeff[36] = 2*p[0]*p[9]*p[9]*p[15] + 2*p[0]*p[10]*p[10]*p[15] - 2*p[0]*p[13]*p[13]*p[15] - 2*p[0]*p[14]*p[14]*p[15] - 2*p[7]*p[11]*p[15] + p[7]*p[11]*p[24] - 2*p[8]*p[12]*p[15] + p[8]*p[12]*p[24] + 2*p[9]*p[13]*p[15] - p[9]*p[13]*p[24] + 2*p[10]*p[14]*p[15] - p[10]*p[14]*p[24]; + coeff[37] = 2*p[0]*p[7]*p[10]*p[19] - 2*p[0]*p[8]*p[9]*p[19] + 2*p[0]*p[9]*p[9]*p[16] + 2*p[0]*p[10]*p[10]*p[16] - 2*p[0]*p[11]*p[14]*p[19] + 2*p[0]*p[12]*p[13]*p[19] - 2*p[0]*p[13]*p[13]*p[16] - 2*p[0]*p[14]*p[14]*p[16] - 2*p[7]*p[11]*p[16] + p[7]*p[11]*p[25] + 2*p[7]*p[14]*p[19] - p[7]*p[14]*p[28] - 2*p[8]*p[12]*p[16] + p[8]*p[12]*p[25] - 2*p[8]*p[13]*p[19] + p[8]*p[13]*p[28] - 2*p[9]*p[12]*p[19] + p[9]*p[12]*p[28] + 2*p[9]*p[13]*p[16] - p[9]*p[13]*p[25] + 2*p[10]*p[11]*p[19] - p[10]*p[11]*p[28] + 2*p[10]*p[14]*p[16] - p[10]*p[14]*p[25]; + coeff[38] = -2*p[0]*p[7]*p[9]*p[22] + 2*p[0]*p[7]*p[10]*p[20] - 2*p[0]*p[8]*p[9]*p[20] - 2*p[0]*p[8]*p[10]*p[22] + 2*p[0]*p[9]*p[9]*p[17] + 2*p[0]*p[10]*p[10]*p[17] + 2*p[0]*p[11]*p[13]*p[22] - 2*p[0]*p[11]*p[14]*p[20] + 2*p[0]*p[12]*p[13]*p[20] + 2*p[0]*p[12]*p[14]*p[22] - 2*p[0]*p[13]*p[13]*p[17] - 2*p[0]*p[14]*p[14]*p[17] - 2*p[7]*p[11]*p[17] + p[7]*p[11]*p[26] - 2*p[7]*p[13]*p[22] + p[7]*p[13]*p[31] + 2*p[7]*p[14]*p[20] - p[7]*p[14]*p[29] - 2*p[8]*p[12]*p[17] + p[8]*p[12]*p[26] - 2*p[8]*p[13]*p[20] + p[8]*p[13]*p[29] - 2*p[8]*p[14]*p[22] + p[8]*p[14]*p[31] - 2*p[9]*p[11]*p[22] + p[9]*p[11]*p[31] - 2*p[9]*p[12]*p[20] + p[9]*p[12]*p[29] + 2*p[9]*p[13]*p[17] - p[9]*p[13]*p[26] + 2*p[10]*p[11]*p[20] - p[10]*p[11]*p[29] - 2*p[10]*p[12]*p[22] + 
p[10]*p[12]*p[31] + 2*p[10]*p[14]*p[17] - p[10]*p[14]*p[26]; + coeff[39] = (p[7]*p[11] + p[8]*p[12] - p[9]*p[13] - p[10]*p[14])*p[15]; + coeff[40] = p[7]*p[11]*p[16] - p[7]*p[14]*p[19] + p[8]*p[12]*p[16] + p[8]*p[13]*p[19] + p[9]*p[12]*p[19] - p[9]*p[13]*p[16] - p[10]*p[11]*p[19] - p[10]*p[14]*p[16]; + coeff[41] = p[7]*p[11]*p[17] + p[7]*p[13]*p[22] - p[7]*p[14]*p[20] + p[8]*p[12]*p[17] + p[8]*p[13]*p[20] + p[8]*p[14]*p[22] + p[9]*p[11]*p[22] + p[9]*p[12]*p[20] - p[9]*p[13]*p[17] - p[10]*p[11]*p[20] + p[10]*p[12]*p[22] - p[10]*p[14]*p[17]; + coeff[42] = 2*(p[7]*p[9]*p[23] - p[7]*p[9]*p[32] - p[7]*p[10]*p[21] + p[7]*p[10]*p[30] + p[8]*p[9]*p[21] - p[8]*p[9]*p[30] + p[8]*p[10]*p[23] - p[8]*p[10]*p[32] - p[9]*p[9]*p[18] + p[9]*p[9]*p[27] - p[10]*p[10]*p[18] + p[10]*p[10]*p[27] - p[11]*p[13]*p[23] + p[11]*p[13]*p[32] + p[11]*p[14]*p[21] - p[11]*p[14]*p[30] - p[12]*p[13]*p[21] + p[12]*p[13]*p[30] - p[12]*p[14]*p[23] + p[12]*p[14]*p[32] + p[13]*p[13]*p[18] - p[13]*p[13]*p[27] + p[14]*p[14]*p[18] - p[14]*p[14]*p[27])*p[0]; + coeff[43] = -4*p[0]*p[9]*p[9]*p[15] + 2*p[0]*p[9]*p[9]*p[24] - 4*p[0]*p[10]*p[10]*p[15] + 2*p[0]*p[10]*p[10]*p[24] + 4*p[0]*p[13]*p[13]*p[15] - 2*p[0]*p[13]*p[13]*p[24] + 4*p[0]*p[14]*p[14]*p[15] - 2*p[0]*p[14]*p[14]*p[24] + 2*p[7]*p[11]*p[15] - 2*p[7]*p[11]*p[24] + 2*p[8]*p[12]*p[15] - 2*p[8]*p[12]*p[24] - 2*p[9]*p[13]*p[15] + 2*p[9]*p[13]*p[24] - 2*p[10]*p[14]*p[15] + 2*p[10]*p[14]*p[24]; + coeff[44] = -4*p[0]*p[7]*p[10]*p[19] + 2*p[0]*p[7]*p[10]*p[28] + 4*p[0]*p[8]*p[9]*p[19] - 2*p[0]*p[8]*p[9]*p[28] - 4*p[0]*p[9]*p[9]*p[16] + 2*p[0]*p[9]*p[9]*p[25] - 4*p[0]*p[10]*p[10]*p[16] + 2*p[0]*p[10]*p[10]*p[25] + 4*p[0]*p[11]*p[14]*p[19] - 2*p[0]*p[11]*p[14]*p[28] - 4*p[0]*p[12]*p[13]*p[19] + 2*p[0]*p[12]*p[13]*p[28] + 4*p[0]*p[13]*p[13]*p[16] - 2*p[0]*p[13]*p[13]*p[25] + 4*p[0]*p[14]*p[14]*p[16] - 2*p[0]*p[14]*p[14]*p[25] + 2*p[7]*p[11]*p[16] - 2*p[7]*p[11]*p[25] - 2*p[7]*p[14]*p[19] + 2*p[7]*p[14]*p[28] + 2*p[8]*p[12]*p[16] - 2*p[8]*p[12]*p[25] + 2*p[8]*p[13]*p[19] - 2*p[8]*p[13]*p[28] + 2*p[9]*p[12]*p[19] - 2*p[9]*p[12]*p[28] - 2*p[9]*p[13]*p[16] + 2*p[9]*p[13]*p[25] - 2*p[10]*p[11]*p[19] + 2*p[10]*p[11]*p[28] - 2*p[10]*p[14]*p[16] + 2*p[10]*p[14]*p[25]; + coeff[45] = 4*p[0]*p[7]*p[9]*p[22] - 2*p[0]*p[7]*p[9]*p[31] - 4*p[0]*p[7]*p[10]*p[20] + 2*p[0]*p[7]*p[10]*p[29] + 4*p[0]*p[8]*p[9]*p[20] - 2*p[0]*p[8]*p[9]*p[29] + 4*p[0]*p[8]*p[10]*p[22] - 2*p[0]*p[8]*p[10]*p[31] - 4*p[0]*p[9]*p[9]*p[17] + 2*p[0]*p[9]*p[9]*p[26] - 4*p[0]*p[10]*p[10]*p[17] + 2*p[0]*p[10]*p[10]*p[26] - 4*p[0]*p[11]*p[13]*p[22] + 2*p[0]*p[11]*p[13]*p[31] + 4*p[0]*p[11]*p[14]*p[20] - 2*p[0]*p[11]*p[14]*p[29] - 4*p[0]*p[12]*p[13]*p[20] + 2*p[0]*p[12]*p[13]*p[29] - 4*p[0]*p[12]*p[14]*p[22] + 2*p[0]*p[12]*p[14]*p[31] + 4*p[0]*p[13]*p[13]*p[17] - 2*p[0]*p[13]*p[13]*p[26] + 4*p[0]*p[14]*p[14]*p[17] - 2*p[0]*p[14]*p[14]*p[26] + 2*p[7]*p[11]*p[17] - 2*p[7]*p[11]*p[26] + 2*p[7]*p[13]*p[22] - 2*p[7]*p[13]*p[31] - 2*p[7]*p[14]*p[20] + 2*p[7]*p[14]*p[29] + 2*p[8]*p[12]*p[17] - 2*p[8]*p[12]*p[26] + 2*p[8]*p[13]*p[20] - 2*p[8]*p[13]*p[29] + 2*p[8]*p[14]*p[22] - 2*p[8]*p[14]*p[31] + 2*p[9]*p[11]*p[22] - 2*p[9]*p[11]*p[31] + 2*p[9]*p[12]*p[20] - 2*p[9]*p[12]*p[29] - 2*p[9]*p[13]*p[17] + 2*p[9]*p[13]*p[26] - 2*p[10]*p[11]*p[20] + 2*p[10]*p[11]*p[29] + 2*p[10]*p[12]*p[22] - 2*p[10]*p[12]*p[31] - 2*p[10]*p[14]*p[17] + 2*p[10]*p[14]*p[26]; + coeff[46] = 2*p[0]*p[9]*p[9]*p[15] + 2*p[0]*p[10]*p[10]*p[15] - 2*p[0]*p[13]*p[13]*p[15] - 2*p[0]*p[14]*p[14]*p[15] - 2*p[7]*p[11]*p[15] + 2*p[7]*p[11]*p[24] - 2*p[8]*p[12]*p[15] + 
2*p[8]*p[12]*p[24] + 2*p[9]*p[13]*p[15] - 2*p[9]*p[13]*p[24] + 2*p[10]*p[14]*p[15] - 2*p[10]*p[14]*p[24]; + coeff[47] = 2*p[0]*p[7]*p[10]*p[19] - 2*p[0]*p[8]*p[9]*p[19] + 2*p[0]*p[9]*p[9]*p[16] + 2*p[0]*p[10]*p[10]*p[16] - 2*p[0]*p[11]*p[14]*p[19] + 2*p[0]*p[12]*p[13]*p[19] - 2*p[0]*p[13]*p[13]*p[16] - 2*p[0]*p[14]*p[14]*p[16] - 2*p[7]*p[11]*p[16] + 2*p[7]*p[11]*p[25] + 2*p[7]*p[14]*p[19] - 2*p[7]*p[14]*p[28] - 2*p[8]*p[12]*p[16] + 2*p[8]*p[12]*p[25] - 2*p[8]*p[13]*p[19] + 2*p[8]*p[13]*p[28] - 2*p[9]*p[12]*p[19] + 2*p[9]*p[12]*p[28] + 2*p[9]*p[13]*p[16] - 2*p[9]*p[13]*p[25] + 2*p[10]*p[11]*p[19] - 2*p[10]*p[11]*p[28] + 2*p[10]*p[14]*p[16] - 2*p[10]*p[14]*p[25]; + coeff[48] = -2*p[0]*p[7]*p[9]*p[22] + 2*p[0]*p[7]*p[10]*p[20] - 2*p[0]*p[8]*p[9]*p[20] - 2*p[0]*p[8]*p[10]*p[22] + 2*p[0]*p[9]*p[9]*p[17] + 2*p[0]*p[10]*p[10]*p[17] + 2*p[0]*p[11]*p[13]*p[22] - 2*p[0]*p[11]*p[14]*p[20] + 2*p[0]*p[12]*p[13]*p[20] + 2*p[0]*p[12]*p[14]*p[22] - 2*p[0]*p[13]*p[13]*p[17] - 2*p[0]*p[14]*p[14]*p[17] - 2*p[7]*p[11]*p[17] + 2*p[7]*p[11]*p[26] - 2*p[7]*p[13]*p[22] + 2*p[7]*p[13]*p[31] + 2*p[7]*p[14]*p[20] - 2*p[7]*p[14]*p[29] - 2*p[8]*p[12]*p[17] + 2*p[8]*p[12]*p[26] - 2*p[8]*p[13]*p[20] + 2*p[8]*p[13]*p[29] - 2*p[8]*p[14]*p[22] + 2*p[8]*p[14]*p[31] - 2*p[9]*p[11]*p[22] + 2*p[9]*p[11]*p[31] - 2*p[9]*p[12]*p[20] + 2*p[9]*p[12]*p[29] + 2*p[9]*p[13]*p[17] - 2*p[9]*p[13]*p[26] + 2*p[10]*p[11]*p[20] - 2*p[10]*p[11]*p[29] - 2*p[10]*p[12]*p[22] + 2*p[10]*p[12]*p[31] + 2*p[10]*p[14]*p[17] - 2*p[10]*p[14]*p[26]; + coeff[49] = 0; + coeff[50] = 2*(p[9]*p[9]*p[15] - p[9]*p[9]*p[24] + p[10]*p[10]*p[15] - p[10]*p[10]*p[24] - p[13]*p[13]*p[15] + p[13]*p[13]*p[24] - p[14]*p[14]*p[15] + p[14]*p[14]*p[24])*p[0]; + coeff[51] = 2*(p[7]*p[10]*p[19] - p[7]*p[10]*p[28] - p[8]*p[9]*p[19] + p[8]*p[9]*p[28] + p[9]*p[9]*p[16] - p[9]*p[9]*p[25] + p[10]*p[10]*p[16] - p[10]*p[10]*p[25] - p[11]*p[14]*p[19] + p[11]*p[14]*p[28] + p[12]*p[13]*p[19] - p[12]*p[13]*p[28] - p[13]*p[13]*p[16] + p[13]*p[13]*p[25] - p[14]*p[14]*p[16] + p[14]*p[14]*p[25])*p[0]; + coeff[52] = 2*(-p[7]*p[9]*p[22] + p[7]*p[9]*p[31] + p[7]*p[10]*p[20] - p[7]*p[10]*p[29] - p[8]*p[9]*p[20] + p[8]*p[9]*p[29] - p[8]*p[10]*p[22] + p[8]*p[10]*p[31] + p[9]*p[9]*p[17] - p[9]*p[9]*p[26] + p[10]*p[10]*p[17] - p[10]*p[10]*p[26] + p[11]*p[13]*p[22] - p[11]*p[13]*p[31] - p[11]*p[14]*p[20] + p[11]*p[14]*p[29] + p[12]*p[13]*p[20] - p[12]*p[13]*p[29] + p[12]*p[14]*p[22] - p[12]*p[14]*p[31] - p[13]*p[13]*p[17] + p[13]*p[13]*p[26] - p[14]*p[14]*p[17] + p[14]*p[14]*p[26])*p[0]; + coeff[53] = 2*(-p[9]*p[9]*p[15] + p[9]*p[9]*p[24] - p[10]*p[10]*p[15] + p[10]*p[10]*p[24] + p[13]*p[13]*p[15] - p[13]*p[13]*p[24] + p[14]*p[14]*p[15] - p[14]*p[14]*p[24])*p[0]; + coeff[54] = 2*(-p[7]*p[10]*p[19] + p[7]*p[10]*p[28] + p[8]*p[9]*p[19] - p[8]*p[9]*p[28] - p[9]*p[9]*p[16] + p[9]*p[9]*p[25] - p[10]*p[10]*p[16] + p[10]*p[10]*p[25] + p[11]*p[14]*p[19] - p[11]*p[14]*p[28] - p[12]*p[13]*p[19] + p[12]*p[13]*p[28] + p[13]*p[13]*p[16] - p[13]*p[13]*p[25] + p[14]*p[14]*p[16] - p[14]*p[14]*p[25])*p[0]; + coeff[55] = 2*(p[7]*p[9]*p[22] - p[7]*p[9]*p[31] - p[7]*p[10]*p[20] + p[7]*p[10]*p[29] + p[8]*p[9]*p[20] - p[8]*p[9]*p[29] + p[8]*p[10]*p[22] - p[8]*p[10]*p[31] - p[9]*p[9]*p[17] + p[9]*p[9]*p[26] - p[10]*p[10]*p[17] + p[10]*p[10]*p[26] - p[11]*p[13]*p[22] + p[11]*p[13]*p[31] + p[11]*p[14]*p[20] - p[11]*p[14]*p[29] - p[12]*p[13]*p[20] + p[12]*p[13]*p[29] - p[12]*p[14]*p[22] + p[12]*p[14]*p[31] + p[13]*p[13]*p[17] - p[13]*p[13]*p[26] + p[14]*p[14]*p[17] - p[14]*p[14]*p[26])*p[0]; + coeff[56] = -p[2] + p[5] + 
p[7]*p[8]*p[23] - p[7]*p[8]*p[32] - p[7]*p[10]*p[18] + p[7]*p[10]*p[27] + p[8]*p[8]*p[21] - p[8]*p[8]*p[30] - p[8]*p[9]*p[18] + p[8]*p[9]*p[27] - p[9]*p[10]*p[23] + p[9]*p[10]*p[32] + p[10]*p[10]*p[21] - p[10]*p[10]*p[30] + p[11]*p[12]*p[23] - p[11]*p[12]*p[32] - p[11]*p[14]*p[18] + p[11]*p[14]*p[27] + p[12]*p[12]*p[21] - p[12]*p[12]*p[30] - p[12]*p[13]*p[18] + p[12]*p[13]*p[27] - p[13]*p[14]*p[23] + p[13]*p[14]*p[32] + p[14]*p[14]*p[21] - p[14]*p[14]*p[30] - p[21] + p[30]; + coeff[57] = -2*p[7]*p[10]*p[15] + p[7]*p[10]*p[24] - 2*p[8]*p[9]*p[15] + p[8]*p[9]*p[24] - 2*p[11]*p[14]*p[15] + p[11]*p[14]*p[24] - 2*p[12]*p[13]*p[15] + p[12]*p[13]*p[24]; + coeff[58] = -2*p[7]*p[10]*p[16] + p[7]*p[10]*p[25] + 2*p[8]*p[8]*p[19] - p[8]*p[8]*p[28] - 2*p[8]*p[9]*p[16] + p[8]*p[9]*p[25] + 2*p[10]*p[10]*p[19] - p[10]*p[10]*p[28] - 2*p[11]*p[14]*p[16] + p[11]*p[14]*p[25] + 2*p[12]*p[12]*p[19] - p[12]*p[12]*p[28] - 2*p[12]*p[13]*p[16] + p[12]*p[13]*p[25] + 2*p[14]*p[14]*p[19] - p[14]*p[14]*p[28] - 2*p[19] + p[28]; + coeff[59] = 2*p[7]*p[8]*p[22] - p[7]*p[8]*p[31] - 2*p[7]*p[10]*p[17] + p[7]*p[10]*p[26] + 2*p[8]*p[8]*p[20] - p[8]*p[8]*p[29] - 2*p[8]*p[9]*p[17] + p[8]*p[9]*p[26] - 2*p[9]*p[10]*p[22] + p[9]*p[10]*p[31] + 2*p[10]*p[10]*p[20] - p[10]*p[10]*p[29] + 2*p[11]*p[12]*p[22] - p[11]*p[12]*p[31] - 2*p[11]*p[14]*p[17] + p[11]*p[14]*p[26] + 2*p[12]*p[12]*p[20] - p[12]*p[12]*p[29] - 2*p[12]*p[13]*p[17] + p[12]*p[13]*p[26] - 2*p[13]*p[14]*p[22] + p[13]*p[14]*p[31] + 2*p[14]*p[14]*p[20] - p[14]*p[14]*p[29] - 2*p[20] + p[29]; + coeff[60] = (p[7]*p[10] + p[8]*p[9] + p[11]*p[14] + p[12]*p[13])*p[15]; + coeff[61] = p[7]*p[10]*p[16] - p[8]*p[8]*p[19] + p[8]*p[9]*p[16] - p[10]*p[10]*p[19] + p[11]*p[14]*p[16] - p[12]*p[12]*p[19] + p[12]*p[13]*p[16] - p[14]*p[14]*p[19] + p[19]; + coeff[62] = -p[7]*p[8]*p[22] + p[7]*p[10]*p[17] - p[8]*p[8]*p[20] + p[8]*p[9]*p[17] + p[9]*p[10]*p[22] - p[10]*p[10]*p[20] - p[11]*p[12]*p[22] + p[11]*p[14]*p[17] - p[12]*p[12]*p[20] + p[12]*p[13]*p[17] + p[13]*p[14]*p[22] - p[14]*p[14]*p[20] + p[20]; + coeff[63] = 0; + coeff[64] = 2*p[7]*p[10]*p[15] - 2*p[7]*p[10]*p[24] + 2*p[8]*p[9]*p[15] - 2*p[8]*p[9]*p[24] + 2*p[11]*p[14]*p[15] - 2*p[11]*p[14]*p[24] + 2*p[12]*p[13]*p[15] - 2*p[12]*p[13]*p[24]; + coeff[65] = 2*p[7]*p[10]*p[16] - 2*p[7]*p[10]*p[25] - 2*p[8]*p[8]*p[19] + 2*p[8]*p[8]*p[28] + 2*p[8]*p[9]*p[16] - 2*p[8]*p[9]*p[25] - 2*p[10]*p[10]*p[19] + 2*p[10]*p[10]*p[28] + 2*p[11]*p[14]*p[16] - 2*p[11]*p[14]*p[25] - 2*p[12]*p[12]*p[19] + 2*p[12]*p[12]*p[28] + 2*p[12]*p[13]*p[16] - 2*p[12]*p[13]*p[25] - 2*p[14]*p[14]*p[19] + 2*p[14]*p[14]*p[28] + 2*p[19] - 2*p[28]; + coeff[66] = -2*p[7]*p[8]*p[22] + 2*p[7]*p[8]*p[31] + 2*p[7]*p[10]*p[17] - 2*p[7]*p[10]*p[26] - 2*p[8]*p[8]*p[20] + 2*p[8]*p[8]*p[29] + 2*p[8]*p[9]*p[17] - 2*p[8]*p[9]*p[26] + 2*p[9]*p[10]*p[22] - 2*p[9]*p[10]*p[31] - 2*p[10]*p[10]*p[20] + 2*p[10]*p[10]*p[29] - 2*p[11]*p[12]*p[22] + 2*p[11]*p[12]*p[31] + 2*p[11]*p[14]*p[17] - 2*p[11]*p[14]*p[26] - 2*p[12]*p[12]*p[20] + 2*p[12]*p[12]*p[29] + 2*p[12]*p[13]*p[17] - 2*p[12]*p[13]*p[26] + 2*p[13]*p[14]*p[22] - 2*p[13]*p[14]*p[31] - 2*p[14]*p[14]*p[20] + 2*p[14]*p[14]*p[29] + 2*p[20] - 2*p[29]; + coeff[67] = -2*p[7]*p[10]*p[15] + 2*p[7]*p[10]*p[24] - 2*p[8]*p[9]*p[15] + 2*p[8]*p[9]*p[24] - 2*p[11]*p[14]*p[15] + 2*p[11]*p[14]*p[24] - 2*p[12]*p[13]*p[15] + 2*p[12]*p[13]*p[24]; + coeff[68] = -2*p[7]*p[10]*p[16] + 2*p[7]*p[10]*p[25] + 2*p[8]*p[8]*p[19] - 2*p[8]*p[8]*p[28] - 2*p[8]*p[9]*p[16] + 2*p[8]*p[9]*p[25] + 2*p[10]*p[10]*p[19] - 2*p[10]*p[10]*p[28] - 2*p[11]*p[14]*p[16] + 
2*p[11]*p[14]*p[25] + 2*p[12]*p[12]*p[19] - 2*p[12]*p[12]*p[28] - 2*p[12]*p[13]*p[16] + 2*p[12]*p[13]*p[25] + 2*p[14]*p[14]*p[19] - 2*p[14]*p[14]*p[28] - 2*p[19] + 2*p[28]; + coeff[69] = 2*p[7]*p[8]*p[22] - 2*p[7]*p[8]*p[31] - 2*p[7]*p[10]*p[17] + 2*p[7]*p[10]*p[26] + 2*p[8]*p[8]*p[20] - 2*p[8]*p[8]*p[29] - 2*p[8]*p[9]*p[17] + 2*p[8]*p[9]*p[26] - 2*p[9]*p[10]*p[22] + 2*p[9]*p[10]*p[31] + 2*p[10]*p[10]*p[20] - 2*p[10]*p[10]*p[29] + 2*p[11]*p[12]*p[22] - 2*p[11]*p[12]*p[31] - 2*p[11]*p[14]*p[17] + 2*p[11]*p[14]*p[26] + 2*p[12]*p[12]*p[20] - 2*p[12]*p[12]*p[29] - 2*p[12]*p[13]*p[17] + 2*p[12]*p[13]*p[26] - 2*p[13]*p[14]*p[22] + 2*p[13]*p[14]*p[31] + 2*p[14]*p[14]*p[20] - 2*p[14]*p[14]*p[29] - 2*p[20] + 2*p[29]; + coeff[70] = 2*p[0]*p[7]*p[11]*p[21] - 2*p[0]*p[7]*p[12]*p[23] + 2*p[0]*p[7]*p[14]*p[18] - 2*p[0]*p[8]*p[11]*p[23] - 2*p[0]*p[8]*p[12]*p[21] + 2*p[0]*p[8]*p[13]*p[18] + 2*p[0]*p[9]*p[12]*p[18] + 2*p[0]*p[9]*p[13]*p[21] + 2*p[0]*p[9]*p[14]*p[23] + 2*p[0]*p[10]*p[11]*p[18] + 2*p[0]*p[10]*p[13]*p[23] - 2*p[0]*p[10]*p[14]*p[21] + p[7]*p[8]*p[23] - p[7]*p[8]*p[32] - p[7]*p[10]*p[18] + p[7]*p[10]*p[27] + p[8]*p[8]*p[21] - p[8]*p[8]*p[30] - p[8]*p[9]*p[18] + p[8]*p[9]*p[27] - p[9]*p[10]*p[23] + p[9]*p[10]*p[32] + p[10]*p[10]*p[21] - p[10]*p[10]*p[30] - p[11]*p[12]*p[23] + p[11]*p[12]*p[32] + p[11]*p[14]*p[18] - p[11]*p[14]*p[27] - p[12]*p[12]*p[21] + p[12]*p[12]*p[30] + p[12]*p[13]*p[18] - p[12]*p[13]*p[27] + p[13]*p[14]*p[23] - p[13]*p[14]*p[32] - p[14]*p[14]*p[21] + p[14]*p[14]*p[30]; + coeff[71] = 2*p[0]*p[7]*p[14]*p[15] + 2*p[0]*p[8]*p[13]*p[15] + 2*p[0]*p[9]*p[12]*p[15] + 2*p[0]*p[10]*p[11]*p[15] - 2*p[7]*p[10]*p[15] + p[7]*p[10]*p[24] - 2*p[8]*p[9]*p[15] + p[8]*p[9]*p[24] + 2*p[11]*p[14]*p[15] - p[11]*p[14]*p[24] + 2*p[12]*p[13]*p[15] - p[12]*p[13]*p[24]; + coeff[72] = 2*p[0]*p[7]*p[11]*p[19] + 2*p[0]*p[7]*p[14]*p[16] - 2*p[0]*p[8]*p[12]*p[19] + 2*p[0]*p[8]*p[13]*p[16] + 2*p[0]*p[9]*p[12]*p[16] + 2*p[0]*p[9]*p[13]*p[19] + 2*p[0]*p[10]*p[11]*p[16] - 2*p[0]*p[10]*p[14]*p[19] - 2*p[7]*p[10]*p[16] + p[7]*p[10]*p[25] + 2*p[8]*p[8]*p[19] - p[8]*p[8]*p[28] - 2*p[8]*p[9]*p[16] + p[8]*p[9]*p[25] + 2*p[10]*p[10]*p[19] - p[10]*p[10]*p[28] + 2*p[11]*p[14]*p[16] - p[11]*p[14]*p[25] - 2*p[12]*p[12]*p[19] + p[12]*p[12]*p[28] + 2*p[12]*p[13]*p[16] - p[12]*p[13]*p[25] - 2*p[14]*p[14]*p[19] + p[14]*p[14]*p[28]; + coeff[73] = 2*p[0]*p[7]*p[11]*p[20] - 2*p[0]*p[7]*p[12]*p[22] + 2*p[0]*p[7]*p[14]*p[17] - 2*p[0]*p[8]*p[11]*p[22] - 2*p[0]*p[8]*p[12]*p[20] + 2*p[0]*p[8]*p[13]*p[17] + 2*p[0]*p[9]*p[12]*p[17] + 2*p[0]*p[9]*p[13]*p[20] + 2*p[0]*p[9]*p[14]*p[22] + 2*p[0]*p[10]*p[11]*p[17] + 2*p[0]*p[10]*p[13]*p[22] - 2*p[0]*p[10]*p[14]*p[20] + 2*p[7]*p[8]*p[22] - p[7]*p[8]*p[31] - 2*p[7]*p[10]*p[17] + p[7]*p[10]*p[26] + 2*p[8]*p[8]*p[20] - p[8]*p[8]*p[29] - 2*p[8]*p[9]*p[17] + p[8]*p[9]*p[26] - 2*p[9]*p[10]*p[22] + p[9]*p[10]*p[31] + 2*p[10]*p[10]*p[20] - p[10]*p[10]*p[29] - 2*p[11]*p[12]*p[22] + p[11]*p[12]*p[31] + 2*p[11]*p[14]*p[17] - p[11]*p[14]*p[26] - 2*p[12]*p[12]*p[20] + p[12]*p[12]*p[29] + 2*p[12]*p[13]*p[17] - p[12]*p[13]*p[26] + 2*p[13]*p[14]*p[22] - p[13]*p[14]*p[31] - 2*p[14]*p[14]*p[20] + p[14]*p[14]*p[29]; + coeff[74] = (p[7]*p[10] + p[8]*p[9] - p[11]*p[14] - p[12]*p[13])*p[15]; + coeff[75] = p[7]*p[10]*p[16] - p[8]*p[8]*p[19] + p[8]*p[9]*p[16] - p[10]*p[10]*p[19] - p[11]*p[14]*p[16] + p[12]*p[12]*p[19] - p[12]*p[13]*p[16] + p[14]*p[14]*p[19]; + coeff[76] = -p[7]*p[8]*p[22] + p[7]*p[10]*p[17] - p[8]*p[8]*p[20] + p[8]*p[9]*p[17] + p[9]*p[10]*p[22] - p[10]*p[10]*p[20] + p[11]*p[12]*p[22] - 
p[11]*p[14]*p[17] + p[12]*p[12]*p[20] - p[12]*p[13]*p[17] - p[13]*p[14]*p[22] + p[14]*p[14]*p[20]; + coeff[77] = 2*(-p[7]*p[11]*p[21] + p[7]*p[11]*p[30] + p[7]*p[12]*p[23] - p[7]*p[12]*p[32] - p[7]*p[14]*p[18] + p[7]*p[14]*p[27] + p[8]*p[11]*p[23] - p[8]*p[11]*p[32] + p[8]*p[12]*p[21] - p[8]*p[12]*p[30] - p[8]*p[13]*p[18] + p[8]*p[13]*p[27] - p[9]*p[12]*p[18] + p[9]*p[12]*p[27] - p[9]*p[13]*p[21] + p[9]*p[13]*p[30] - p[9]*p[14]*p[23] + p[9]*p[14]*p[32] - p[10]*p[11]*p[18] + p[10]*p[11]*p[27] - p[10]*p[13]*p[23] + p[10]*p[13]*p[32] + p[10]*p[14]*p[21] - p[10]*p[14]*p[30])*p[0]; + coeff[78] = -4*p[0]*p[7]*p[14]*p[15] + 2*p[0]*p[7]*p[14]*p[24] - 4*p[0]*p[8]*p[13]*p[15] + 2*p[0]*p[8]*p[13]*p[24] - 4*p[0]*p[9]*p[12]*p[15] + 2*p[0]*p[9]*p[12]*p[24] - 4*p[0]*p[10]*p[11]*p[15] + 2*p[0]*p[10]*p[11]*p[24] + 2*p[7]*p[10]*p[15] - 2*p[7]*p[10]*p[24] + 2*p[8]*p[9]*p[15] - 2*p[8]*p[9]*p[24] - 2*p[11]*p[14]*p[15] + 2*p[11]*p[14]*p[24] - 2*p[12]*p[13]*p[15] + 2*p[12]*p[13]*p[24]; + coeff[79] = -4*p[0]*p[7]*p[11]*p[19] + 2*p[0]*p[7]*p[11]*p[28] - 4*p[0]*p[7]*p[14]*p[16] + 2*p[0]*p[7]*p[14]*p[25] + 4*p[0]*p[8]*p[12]*p[19] - 2*p[0]*p[8]*p[12]*p[28] - 4*p[0]*p[8]*p[13]*p[16] + 2*p[0]*p[8]*p[13]*p[25] - 4*p[0]*p[9]*p[12]*p[16] + 2*p[0]*p[9]*p[12]*p[25] - 4*p[0]*p[9]*p[13]*p[19] + 2*p[0]*p[9]*p[13]*p[28] - 4*p[0]*p[10]*p[11]*p[16] + 2*p[0]*p[10]*p[11]*p[25] + 4*p[0]*p[10]*p[14]*p[19] - 2*p[0]*p[10]*p[14]*p[28] + 2*p[7]*p[10]*p[16] - 2*p[7]*p[10]*p[25] - 2*p[8]*p[8]*p[19] + 2*p[8]*p[8]*p[28] + 2*p[8]*p[9]*p[16] - 2*p[8]*p[9]*p[25] - 2*p[10]*p[10]*p[19] + 2*p[10]*p[10]*p[28] - 2*p[11]*p[14]*p[16] + 2*p[11]*p[14]*p[25] + 2*p[12]*p[12]*p[19] - 2*p[12]*p[12]*p[28] - 2*p[12]*p[13]*p[16] + 2*p[12]*p[13]*p[25] + 2*p[14]*p[14]*p[19] - 2*p[14]*p[14]*p[28]; + coeff[80] = -4*p[0]*p[7]*p[11]*p[20] + 2*p[0]*p[7]*p[11]*p[29] + 4*p[0]*p[7]*p[12]*p[22] - 2*p[0]*p[7]*p[12]*p[31] - 4*p[0]*p[7]*p[14]*p[17] + 2*p[0]*p[7]*p[14]*p[26] + 4*p[0]*p[8]*p[11]*p[22] - 2*p[0]*p[8]*p[11]*p[31] + 4*p[0]*p[8]*p[12]*p[20] - 2*p[0]*p[8]*p[12]*p[29] - 4*p[0]*p[8]*p[13]*p[17] + 2*p[0]*p[8]*p[13]*p[26] - 4*p[0]*p[9]*p[12]*p[17] + 2*p[0]*p[9]*p[12]*p[26] - 4*p[0]*p[9]*p[13]*p[20] + 2*p[0]*p[9]*p[13]*p[29] - 4*p[0]*p[9]*p[14]*p[22] + 2*p[0]*p[9]*p[14]*p[31] - 4*p[0]*p[10]*p[11]*p[17] + 2*p[0]*p[10]*p[11]*p[26] - 4*p[0]*p[10]*p[13]*p[22] + 2*p[0]*p[10]*p[13]*p[31] + 4*p[0]*p[10]*p[14]*p[20] - 2*p[0]*p[10]*p[14]*p[29] - 2*p[7]*p[8]*p[22] + 2*p[7]*p[8]*p[31] + 2*p[7]*p[10]*p[17] - 2*p[7]*p[10]*p[26] - 2*p[8]*p[8]*p[20] + 2*p[8]*p[8]*p[29] + 2*p[8]*p[9]*p[17] - 2*p[8]*p[9]*p[26] + 2*p[9]*p[10]*p[22] - 2*p[9]*p[10]*p[31] - 2*p[10]*p[10]*p[20] + 2*p[10]*p[10]*p[29] + 2*p[11]*p[12]*p[22] - 2*p[11]*p[12]*p[31] - 2*p[11]*p[14]*p[17] + 2*p[11]*p[14]*p[26] + 2*p[12]*p[12]*p[20] - 2*p[12]*p[12]*p[29] - 2*p[12]*p[13]*p[17] + 2*p[12]*p[13]*p[26] - 2*p[13]*p[14]*p[22] + 2*p[13]*p[14]*p[31] + 2*p[14]*p[14]*p[20] - 2*p[14]*p[14]*p[29]; + coeff[81] = 2*p[0]*p[7]*p[14]*p[15] + 2*p[0]*p[8]*p[13]*p[15] + 2*p[0]*p[9]*p[12]*p[15] + 2*p[0]*p[10]*p[11]*p[15] - 2*p[7]*p[10]*p[15] + 2*p[7]*p[10]*p[24] - 2*p[8]*p[9]*p[15] + 2*p[8]*p[9]*p[24] + 2*p[11]*p[14]*p[15] - 2*p[11]*p[14]*p[24] + 2*p[12]*p[13]*p[15] - 2*p[12]*p[13]*p[24]; + coeff[82] = 2*p[0]*p[7]*p[11]*p[19] + 2*p[0]*p[7]*p[14]*p[16] - 2*p[0]*p[8]*p[12]*p[19] + 2*p[0]*p[8]*p[13]*p[16] + 2*p[0]*p[9]*p[12]*p[16] + 2*p[0]*p[9]*p[13]*p[19] + 2*p[0]*p[10]*p[11]*p[16] - 2*p[0]*p[10]*p[14]*p[19] - 2*p[7]*p[10]*p[16] + 2*p[7]*p[10]*p[25] + 2*p[8]*p[8]*p[19] - 2*p[8]*p[8]*p[28] - 2*p[8]*p[9]*p[16] + 2*p[8]*p[9]*p[25] + 
2*p[10]*p[10]*p[19] - 2*p[10]*p[10]*p[28] + 2*p[11]*p[14]*p[16] - 2*p[11]*p[14]*p[25] - 2*p[12]*p[12]*p[19] + 2*p[12]*p[12]*p[28] + 2*p[12]*p[13]*p[16] - 2*p[12]*p[13]*p[25] - 2*p[14]*p[14]*p[19] + 2*p[14]*p[14]*p[28]; + coeff[83] = 2*p[0]*p[7]*p[11]*p[20] - 2*p[0]*p[7]*p[12]*p[22] + 2*p[0]*p[7]*p[14]*p[17] - 2*p[0]*p[8]*p[11]*p[22] - 2*p[0]*p[8]*p[12]*p[20] + 2*p[0]*p[8]*p[13]*p[17] + 2*p[0]*p[9]*p[12]*p[17] + 2*p[0]*p[9]*p[13]*p[20] + 2*p[0]*p[9]*p[14]*p[22] + 2*p[0]*p[10]*p[11]*p[17] + 2*p[0]*p[10]*p[13]*p[22] - 2*p[0]*p[10]*p[14]*p[20] + 2*p[7]*p[8]*p[22] - 2*p[7]*p[8]*p[31] - 2*p[7]*p[10]*p[17] + 2*p[7]*p[10]*p[26] + 2*p[8]*p[8]*p[20] - 2*p[8]*p[8]*p[29] - 2*p[8]*p[9]*p[17] + 2*p[8]*p[9]*p[26] - 2*p[9]*p[10]*p[22] + 2*p[9]*p[10]*p[31] + 2*p[10]*p[10]*p[20] - 2*p[10]*p[10]*p[29] - 2*p[11]*p[12]*p[22] + 2*p[11]*p[12]*p[31] + 2*p[11]*p[14]*p[17] - 2*p[11]*p[14]*p[26] - 2*p[12]*p[12]*p[20] + 2*p[12]*p[12]*p[29] + 2*p[12]*p[13]*p[17] - 2*p[12]*p[13]*p[26] + 2*p[13]*p[14]*p[22] - 2*p[13]*p[14]*p[31] - 2*p[14]*p[14]*p[20] + 2*p[14]*p[14]*p[29]; + coeff[84] = 0; + coeff[85] = 2*(p[7]*p[14]*p[15] - p[7]*p[14]*p[24] + p[8]*p[13]*p[15] - p[8]*p[13]*p[24] + p[9]*p[12]*p[15] - p[9]*p[12]*p[24] + p[10]*p[11]*p[15] - p[10]*p[11]*p[24])*p[0]; + coeff[86] = 2*(p[7]*p[11]*p[19] - p[7]*p[11]*p[28] + p[7]*p[14]*p[16] - p[7]*p[14]*p[25] - p[8]*p[12]*p[19] + p[8]*p[12]*p[28] + p[8]*p[13]*p[16] - p[8]*p[13]*p[25] + p[9]*p[12]*p[16] - p[9]*p[12]*p[25] + p[9]*p[13]*p[19] - p[9]*p[13]*p[28] + p[10]*p[11]*p[16] - p[10]*p[11]*p[25] - p[10]*p[14]*p[19] + p[10]*p[14]*p[28])*p[0]; + coeff[87] = 2*(p[7]*p[11]*p[20] - p[7]*p[11]*p[29] - p[7]*p[12]*p[22] + p[7]*p[12]*p[31] + p[7]*p[14]*p[17] - p[7]*p[14]*p[26] - p[8]*p[11]*p[22] + p[8]*p[11]*p[31] - p[8]*p[12]*p[20] + p[8]*p[12]*p[29] + p[8]*p[13]*p[17] - p[8]*p[13]*p[26] + p[9]*p[12]*p[17] - p[9]*p[12]*p[26] + p[9]*p[13]*p[20] - p[9]*p[13]*p[29] + p[9]*p[14]*p[22] - p[9]*p[14]*p[31] + p[10]*p[11]*p[17] - p[10]*p[11]*p[26] + p[10]*p[13]*p[22] - p[10]*p[13]*p[31] - p[10]*p[14]*p[20] + p[10]*p[14]*p[29])*p[0]; + coeff[88] = 2*(-p[7]*p[14]*p[15] + p[7]*p[14]*p[24] - p[8]*p[13]*p[15] + p[8]*p[13]*p[24] - p[9]*p[12]*p[15] + p[9]*p[12]*p[24] - p[10]*p[11]*p[15] + p[10]*p[11]*p[24])*p[0]; + coeff[89] = 2*(-p[7]*p[11]*p[19] + p[7]*p[11]*p[28] - p[7]*p[14]*p[16] + p[7]*p[14]*p[25] + p[8]*p[12]*p[19] - p[8]*p[12]*p[28] - p[8]*p[13]*p[16] + p[8]*p[13]*p[25] - p[9]*p[12]*p[16] + p[9]*p[12]*p[25] - p[9]*p[13]*p[19] + p[9]*p[13]*p[28] - p[10]*p[11]*p[16] + p[10]*p[11]*p[25] + p[10]*p[14]*p[19] - p[10]*p[14]*p[28])*p[0]; + coeff[90] = 2*(-p[7]*p[11]*p[20] + p[7]*p[11]*p[29] + p[7]*p[12]*p[22] - p[7]*p[12]*p[31] - p[7]*p[14]*p[17] + p[7]*p[14]*p[26] + p[8]*p[11]*p[22] - p[8]*p[11]*p[31] + p[8]*p[12]*p[20] - p[8]*p[12]*p[29] - p[8]*p[13]*p[17] + p[8]*p[13]*p[26] - p[9]*p[12]*p[17] + p[9]*p[12]*p[26] - p[9]*p[13]*p[20] + p[9]*p[13]*p[29] - p[9]*p[14]*p[22] + p[9]*p[14]*p[31] - p[10]*p[11]*p[17] + p[10]*p[11]*p[26] - p[10]*p[13]*p[22] + p[10]*p[13]*p[31] + p[10]*p[14]*p[20] - p[10]*p[14]*p[29])*p[0]; + coeff[91] = 2*p[0]*p[7]*p[8]*p[23] - 2*p[0]*p[7]*p[10]*p[18] + 2*p[0]*p[8]*p[8]*p[21] - 2*p[0]*p[8]*p[9]*p[18] - 2*p[0]*p[9]*p[10]*p[23] + 2*p[0]*p[10]*p[10]*p[21] - 2*p[0]*p[11]*p[12]*p[23] + 2*p[0]*p[11]*p[14]*p[18] - 2*p[0]*p[12]*p[12]*p[21] + 2*p[0]*p[12]*p[13]*p[18] + 2*p[0]*p[13]*p[14]*p[23] - 2*p[0]*p[14]*p[14]*p[21] - p[7]*p[11]*p[21] + p[7]*p[11]*p[30] + p[7]*p[12]*p[23] - p[7]*p[12]*p[32] - p[7]*p[14]*p[18] + p[7]*p[14]*p[27] + p[8]*p[11]*p[23] - p[8]*p[11]*p[32] + 
p[8]*p[12]*p[21] - p[8]*p[12]*p[30] - p[8]*p[13]*p[18] + p[8]*p[13]*p[27] - p[9]*p[12]*p[18] + p[9]*p[12]*p[27] - p[9]*p[13]*p[21] + p[9]*p[13]*p[30] - p[9]*p[14]*p[23] + p[9]*p[14]*p[32] - p[10]*p[11]*p[18] + p[10]*p[11]*p[27] - p[10]*p[13]*p[23] + p[10]*p[13]*p[32] + p[10]*p[14]*p[21] - p[10]*p[14]*p[30]; + coeff[92] = -2*p[0]*p[7]*p[10]*p[15] - 2*p[0]*p[8]*p[9]*p[15] + 2*p[0]*p[11]*p[14]*p[15] + 2*p[0]*p[12]*p[13]*p[15] - 2*p[7]*p[14]*p[15] + p[7]*p[14]*p[24] - 2*p[8]*p[13]*p[15] + p[8]*p[13]*p[24] - 2*p[9]*p[12]*p[15] + p[9]*p[12]*p[24] - 2*p[10]*p[11]*p[15] + p[10]*p[11]*p[24]; + coeff[93] = -2*p[0]*p[7]*p[10]*p[16] + 2*p[0]*p[8]*p[8]*p[19] - 2*p[0]*p[8]*p[9]*p[16] + 2*p[0]*p[10]*p[10]*p[19] + 2*p[0]*p[11]*p[14]*p[16] - 2*p[0]*p[12]*p[12]*p[19] + 2*p[0]*p[12]*p[13]*p[16] - 2*p[0]*p[14]*p[14]*p[19] - 2*p[7]*p[11]*p[19] + p[7]*p[11]*p[28] - 2*p[7]*p[14]*p[16] + p[7]*p[14]*p[25] + 2*p[8]*p[12]*p[19] - p[8]*p[12]*p[28] - 2*p[8]*p[13]*p[16] + p[8]*p[13]*p[25] - 2*p[9]*p[12]*p[16] + p[9]*p[12]*p[25] - 2*p[9]*p[13]*p[19] + p[9]*p[13]*p[28] - 2*p[10]*p[11]*p[16] + p[10]*p[11]*p[25] + 2*p[10]*p[14]*p[19] - p[10]*p[14]*p[28]; + coeff[94] = 2*p[0]*p[7]*p[8]*p[22] - 2*p[0]*p[7]*p[10]*p[17] + 2*p[0]*p[8]*p[8]*p[20] - 2*p[0]*p[8]*p[9]*p[17] - 2*p[0]*p[9]*p[10]*p[22] + 2*p[0]*p[10]*p[10]*p[20] - 2*p[0]*p[11]*p[12]*p[22] + 2*p[0]*p[11]*p[14]*p[17] - 2*p[0]*p[12]*p[12]*p[20] + 2*p[0]*p[12]*p[13]*p[17] + 2*p[0]*p[13]*p[14]*p[22] - 2*p[0]*p[14]*p[14]*p[20] - 2*p[7]*p[11]*p[20] + p[7]*p[11]*p[29] + 2*p[7]*p[12]*p[22] - p[7]*p[12]*p[31] - 2*p[7]*p[14]*p[17] + p[7]*p[14]*p[26] + 2*p[8]*p[11]*p[22] - p[8]*p[11]*p[31] + 2*p[8]*p[12]*p[20] - p[8]*p[12]*p[29] - 2*p[8]*p[13]*p[17] + p[8]*p[13]*p[26] - 2*p[9]*p[12]*p[17] + p[9]*p[12]*p[26] - 2*p[9]*p[13]*p[20] + p[9]*p[13]*p[29] - 2*p[9]*p[14]*p[22] + p[9]*p[14]*p[31] - 2*p[10]*p[11]*p[17] + p[10]*p[11]*p[26] - 2*p[10]*p[13]*p[22] + p[10]*p[13]*p[31] + 2*p[10]*p[14]*p[20] - p[10]*p[14]*p[29]; + coeff[95] = (p[7]*p[14] + p[8]*p[13] + p[9]*p[12] + p[10]*p[11])*p[15]; + coeff[96] = p[7]*p[11]*p[19] + p[7]*p[14]*p[16] - p[8]*p[12]*p[19] + p[8]*p[13]*p[16] + p[9]*p[12]*p[16] + p[9]*p[13]*p[19] + p[10]*p[11]*p[16] - p[10]*p[14]*p[19]; + coeff[97] = p[7]*p[11]*p[20] - p[7]*p[12]*p[22] + p[7]*p[14]*p[17] - p[8]*p[11]*p[22] - p[8]*p[12]*p[20] + p[8]*p[13]*p[17] + p[9]*p[12]*p[17] + p[9]*p[13]*p[20] + p[9]*p[14]*p[22] + p[10]*p[11]*p[17] + p[10]*p[13]*p[22] - p[10]*p[14]*p[20]; + coeff[98] = 2*(-p[7]*p[8]*p[23] + p[7]*p[8]*p[32] + p[7]*p[10]*p[18] - p[7]*p[10]*p[27] - p[8]*p[8]*p[21] + p[8]*p[8]*p[30] + p[8]*p[9]*p[18] - p[8]*p[9]*p[27] + p[9]*p[10]*p[23] - p[9]*p[10]*p[32] - p[10]*p[10]*p[21] + p[10]*p[10]*p[30] + p[11]*p[12]*p[23] - p[11]*p[12]*p[32] - p[11]*p[14]*p[18] + p[11]*p[14]*p[27] + p[12]*p[12]*p[21] - p[12]*p[12]*p[30] - p[12]*p[13]*p[18] + p[12]*p[13]*p[27] - p[13]*p[14]*p[23] + p[13]*p[14]*p[32] + p[14]*p[14]*p[21] - p[14]*p[14]*p[30])*p[0]; + coeff[99] = 4*p[0]*p[7]*p[10]*p[15] - 2*p[0]*p[7]*p[10]*p[24] + 4*p[0]*p[8]*p[9]*p[15] - 2*p[0]*p[8]*p[9]*p[24] - 4*p[0]*p[11]*p[14]*p[15] + 2*p[0]*p[11]*p[14]*p[24] - 4*p[0]*p[12]*p[13]*p[15] + 2*p[0]*p[12]*p[13]*p[24] + 2*p[7]*p[14]*p[15] - 2*p[7]*p[14]*p[24] + 2*p[8]*p[13]*p[15] - 2*p[8]*p[13]*p[24] + 2*p[9]*p[12]*p[15] - 2*p[9]*p[12]*p[24] + 2*p[10]*p[11]*p[15] - 2*p[10]*p[11]*p[24]; + coeff[100] = 4*p[0]*p[7]*p[10]*p[16] - 2*p[0]*p[7]*p[10]*p[25] - 4*p[0]*p[8]*p[8]*p[19] + 2*p[0]*p[8]*p[8]*p[28] + 4*p[0]*p[8]*p[9]*p[16] - 2*p[0]*p[8]*p[9]*p[25] - 4*p[0]*p[10]*p[10]*p[19] + 2*p[0]*p[10]*p[10]*p[28] - 
4*p[0]*p[11]*p[14]*p[16] + 2*p[0]*p[11]*p[14]*p[25] + 4*p[0]*p[12]*p[12]*p[19] - 2*p[0]*p[12]*p[12]*p[28] - 4*p[0]*p[12]*p[13]*p[16] + 2*p[0]*p[12]*p[13]*p[25] + 4*p[0]*p[14]*p[14]*p[19] - 2*p[0]*p[14]*p[14]*p[28] + 2*p[7]*p[11]*p[19] - 2*p[7]*p[11]*p[28] + 2*p[7]*p[14]*p[16] - 2*p[7]*p[14]*p[25] - 2*p[8]*p[12]*p[19] + 2*p[8]*p[12]*p[28] + 2*p[8]*p[13]*p[16] - 2*p[8]*p[13]*p[25] + 2*p[9]*p[12]*p[16] - 2*p[9]*p[12]*p[25] + 2*p[9]*p[13]*p[19] - 2*p[9]*p[13]*p[28] + 2*p[10]*p[11]*p[16] - 2*p[10]*p[11]*p[25] - 2*p[10]*p[14]*p[19] + 2*p[10]*p[14]*p[28]; + coeff[101] = -4*p[0]*p[7]*p[8]*p[22] + 2*p[0]*p[7]*p[8]*p[31] + 4*p[0]*p[7]*p[10]*p[17] - 2*p[0]*p[7]*p[10]*p[26] - 4*p[0]*p[8]*p[8]*p[20] + 2*p[0]*p[8]*p[8]*p[29] + 4*p[0]*p[8]*p[9]*p[17] - 2*p[0]*p[8]*p[9]*p[26] + 4*p[0]*p[9]*p[10]*p[22] - 2*p[0]*p[9]*p[10]*p[31] - 4*p[0]*p[10]*p[10]*p[20] + 2*p[0]*p[10]*p[10]*p[29] + 4*p[0]*p[11]*p[12]*p[22] - 2*p[0]*p[11]*p[12]*p[31] - 4*p[0]*p[11]*p[14]*p[17] + 2*p[0]*p[11]*p[14]*p[26] + 4*p[0]*p[12]*p[12]*p[20] - 2*p[0]*p[12]*p[12]*p[29] - 4*p[0]*p[12]*p[13]*p[17] + 2*p[0]*p[12]*p[13]*p[26] - 4*p[0]*p[13]*p[14]*p[22] + 2*p[0]*p[13]*p[14]*p[31] + 4*p[0]*p[14]*p[14]*p[20] - 2*p[0]*p[14]*p[14]*p[29] + 2*p[7]*p[11]*p[20] - 2*p[7]*p[11]*p[29] - 2*p[7]*p[12]*p[22] + 2*p[7]*p[12]*p[31] + 2*p[7]*p[14]*p[17] - 2*p[7]*p[14]*p[26] - 2*p[8]*p[11]*p[22] + 2*p[8]*p[11]*p[31] - 2*p[8]*p[12]*p[20] + 2*p[8]*p[12]*p[29] + 2*p[8]*p[13]*p[17] - 2*p[8]*p[13]*p[26] + 2*p[9]*p[12]*p[17] - 2*p[9]*p[12]*p[26] + 2*p[9]*p[13]*p[20] - 2*p[9]*p[13]*p[29] + 2*p[9]*p[14]*p[22] - 2*p[9]*p[14]*p[31] + 2*p[10]*p[11]*p[17] - 2*p[10]*p[11]*p[26] + 2*p[10]*p[13]*p[22] - 2*p[10]*p[13]*p[31] - 2*p[10]*p[14]*p[20] + 2*p[10]*p[14]*p[29]; + coeff[102] = -2*p[0]*p[7]*p[10]*p[15] - 2*p[0]*p[8]*p[9]*p[15] + 2*p[0]*p[11]*p[14]*p[15] + 2*p[0]*p[12]*p[13]*p[15] - 2*p[7]*p[14]*p[15] + 2*p[7]*p[14]*p[24] - 2*p[8]*p[13]*p[15] + 2*p[8]*p[13]*p[24] - 2*p[9]*p[12]*p[15] + 2*p[9]*p[12]*p[24] - 2*p[10]*p[11]*p[15] + 2*p[10]*p[11]*p[24]; + coeff[103] = -2*p[0]*p[7]*p[10]*p[16] + 2*p[0]*p[8]*p[8]*p[19] - 2*p[0]*p[8]*p[9]*p[16] + 2*p[0]*p[10]*p[10]*p[19] + 2*p[0]*p[11]*p[14]*p[16] - 2*p[0]*p[12]*p[12]*p[19] + 2*p[0]*p[12]*p[13]*p[16] - 2*p[0]*p[14]*p[14]*p[19] - 2*p[7]*p[11]*p[19] + 2*p[7]*p[11]*p[28] - 2*p[7]*p[14]*p[16] + 2*p[7]*p[14]*p[25] + 2*p[8]*p[12]*p[19] - 2*p[8]*p[12]*p[28] - 2*p[8]*p[13]*p[16] + 2*p[8]*p[13]*p[25] - 2*p[9]*p[12]*p[16] + 2*p[9]*p[12]*p[25] - 2*p[9]*p[13]*p[19] + 2*p[9]*p[13]*p[28] - 2*p[10]*p[11]*p[16] + 2*p[10]*p[11]*p[25] + 2*p[10]*p[14]*p[19] - 2*p[10]*p[14]*p[28]; + coeff[104] = 2*p[0]*p[7]*p[8]*p[22] - 2*p[0]*p[7]*p[10]*p[17] + 2*p[0]*p[8]*p[8]*p[20] - 2*p[0]*p[8]*p[9]*p[17] - 2*p[0]*p[9]*p[10]*p[22] + 2*p[0]*p[10]*p[10]*p[20] - 2*p[0]*p[11]*p[12]*p[22] + 2*p[0]*p[11]*p[14]*p[17] - 2*p[0]*p[12]*p[12]*p[20] + 2*p[0]*p[12]*p[13]*p[17] + 2*p[0]*p[13]*p[14]*p[22] - 2*p[0]*p[14]*p[14]*p[20] - 2*p[7]*p[11]*p[20] + 2*p[7]*p[11]*p[29] + 2*p[7]*p[12]*p[22] - 2*p[7]*p[12]*p[31] - 2*p[7]*p[14]*p[17] + 2*p[7]*p[14]*p[26] + 2*p[8]*p[11]*p[22] - 2*p[8]*p[11]*p[31] + 2*p[8]*p[12]*p[20] - 2*p[8]*p[12]*p[29] - 2*p[8]*p[13]*p[17] + 2*p[8]*p[13]*p[26] - 2*p[9]*p[12]*p[17] + 2*p[9]*p[12]*p[26] - 2*p[9]*p[13]*p[20] + 2*p[9]*p[13]*p[29] - 2*p[9]*p[14]*p[22] + 2*p[9]*p[14]*p[31] - 2*p[10]*p[11]*p[17] + 2*p[10]*p[11]*p[26] - 2*p[10]*p[13]*p[22] + 2*p[10]*p[13]*p[31] + 2*p[10]*p[14]*p[20] - 2*p[10]*p[14]*p[29]; + coeff[105] = 0; + coeff[106] = 2*(-p[7]*p[10]*p[15] + p[7]*p[10]*p[24] - p[8]*p[9]*p[15] + p[8]*p[9]*p[24] + p[11]*p[14]*p[15] - 
p[11]*p[14]*p[24] + p[12]*p[13]*p[15] - p[12]*p[13]*p[24])*p[0]; + coeff[107] = 2*(-p[7]*p[10]*p[16] + p[7]*p[10]*p[25] + p[8]*p[8]*p[19] - p[8]*p[8]*p[28] - p[8]*p[9]*p[16] + p[8]*p[9]*p[25] + p[10]*p[10]*p[19] - p[10]*p[10]*p[28] + p[11]*p[14]*p[16] - p[11]*p[14]*p[25] - p[12]*p[12]*p[19] + p[12]*p[12]*p[28] + p[12]*p[13]*p[16] - p[12]*p[13]*p[25] - p[14]*p[14]*p[19] + p[14]*p[14]*p[28])*p[0]; + coeff[108] = 2*(p[7]*p[8]*p[22] - p[7]*p[8]*p[31] - p[7]*p[10]*p[17] + p[7]*p[10]*p[26] + p[8]*p[8]*p[20] - p[8]*p[8]*p[29] - p[8]*p[9]*p[17] + p[8]*p[9]*p[26] - p[9]*p[10]*p[22] + p[9]*p[10]*p[31] + p[10]*p[10]*p[20] - p[10]*p[10]*p[29] - p[11]*p[12]*p[22] + p[11]*p[12]*p[31] + p[11]*p[14]*p[17] - p[11]*p[14]*p[26] - p[12]*p[12]*p[20] + p[12]*p[12]*p[29] + p[12]*p[13]*p[17] - p[12]*p[13]*p[26] + p[13]*p[14]*p[22] - p[13]*p[14]*p[31] - p[14]*p[14]*p[20] + p[14]*p[14]*p[29])*p[0]; + coeff[109] = 2*(p[7]*p[10]*p[15] - p[7]*p[10]*p[24] + p[8]*p[9]*p[15] - p[8]*p[9]*p[24] - p[11]*p[14]*p[15] + p[11]*p[14]*p[24] - p[12]*p[13]*p[15] + p[12]*p[13]*p[24])*p[0]; + coeff[110] = 2*(p[7]*p[10]*p[16] - p[7]*p[10]*p[25] - p[8]*p[8]*p[19] + p[8]*p[8]*p[28] + p[8]*p[9]*p[16] - p[8]*p[9]*p[25] - p[10]*p[10]*p[19] + p[10]*p[10]*p[28] - p[11]*p[14]*p[16] + p[11]*p[14]*p[25] + p[12]*p[12]*p[19] - p[12]*p[12]*p[28] - p[12]*p[13]*p[16] + p[12]*p[13]*p[25] + p[14]*p[14]*p[19] - p[14]*p[14]*p[28])*p[0]; + coeff[111] = 2*(-p[7]*p[8]*p[22] + p[7]*p[8]*p[31] + p[7]*p[10]*p[17] - p[7]*p[10]*p[26] - p[8]*p[8]*p[20] + p[8]*p[8]*p[29] + p[8]*p[9]*p[17] - p[8]*p[9]*p[26] + p[9]*p[10]*p[22] - p[9]*p[10]*p[31] - p[10]*p[10]*p[20] + p[10]*p[10]*p[29] + p[11]*p[12]*p[22] - p[11]*p[12]*p[31] - p[11]*p[14]*p[17] + p[11]*p[14]*p[26] + p[12]*p[12]*p[20] - p[12]*p[12]*p[29] - p[12]*p[13]*p[17] + p[12]*p[13]*p[26] - p[13]*p[14]*p[22] + p[13]*p[14]*p[31] + p[14]*p[14]*p[20] - p[14]*p[14]*p[29])*p[0]; + coeff[112] = -p[3] + p[6] - p[7]*p[8]*p[21] + p[7]*p[8]*p[30] + p[7]*p[9]*p[18] - p[7]*p[9]*p[27] + p[8]*p[8]*p[23] - p[8]*p[8]*p[32] - p[8]*p[10]*p[18] + p[8]*p[10]*p[27] + p[9]*p[9]*p[23] - p[9]*p[9]*p[32] - p[9]*p[10]*p[21] + p[9]*p[10]*p[30] - p[11]*p[12]*p[21] + p[11]*p[12]*p[30] + p[11]*p[13]*p[18] - p[11]*p[13]*p[27] + p[12]*p[12]*p[23] - p[12]*p[12]*p[32] - p[12]*p[14]*p[18] + p[12]*p[14]*p[27] + p[13]*p[13]*p[23] - p[13]*p[13]*p[32] - p[13]*p[14]*p[21] + p[13]*p[14]*p[30] - p[23] + p[32]; + coeff[113] = 2*p[7]*p[9]*p[15] - p[7]*p[9]*p[24] - 2*p[8]*p[10]*p[15] + p[8]*p[10]*p[24] + 2*p[11]*p[13]*p[15] - p[11]*p[13]*p[24] - 2*p[12]*p[14]*p[15] + p[12]*p[14]*p[24]; + coeff[114] = -2*p[7]*p[8]*p[19] + p[7]*p[8]*p[28] + 2*p[7]*p[9]*p[16] - p[7]*p[9]*p[25] - 2*p[8]*p[10]*p[16] + p[8]*p[10]*p[25] - 2*p[9]*p[10]*p[19] + p[9]*p[10]*p[28] - 2*p[11]*p[12]*p[19] + p[11]*p[12]*p[28] + 2*p[11]*p[13]*p[16] - p[11]*p[13]*p[25] - 2*p[12]*p[14]*p[16] + p[12]*p[14]*p[25] - 2*p[13]*p[14]*p[19] + p[13]*p[14]*p[28]; + coeff[115] = -2*p[7]*p[8]*p[20] + p[7]*p[8]*p[29] + 2*p[7]*p[9]*p[17] - p[7]*p[9]*p[26] + 2*p[8]*p[8]*p[22] - p[8]*p[8]*p[31] - 2*p[8]*p[10]*p[17] + p[8]*p[10]*p[26] + 2*p[9]*p[9]*p[22] - p[9]*p[9]*p[31] - 2*p[9]*p[10]*p[20] + p[9]*p[10]*p[29] - 2*p[11]*p[12]*p[20] + p[11]*p[12]*p[29] + 2*p[11]*p[13]*p[17] - p[11]*p[13]*p[26] + 2*p[12]*p[12]*p[22] - p[12]*p[12]*p[31] - 2*p[12]*p[14]*p[17] + p[12]*p[14]*p[26] + 2*p[13]*p[13]*p[22] - p[13]*p[13]*p[31] - 2*p[13]*p[14]*p[20] + p[13]*p[14]*p[29] - 2*p[22] + p[31]; + coeff[116] = (-p[7]*p[9] + p[8]*p[10] - p[11]*p[13] + p[12]*p[14])*p[15]; + coeff[117] = p[7]*p[8]*p[19] - p[7]*p[9]*p[16] + 
p[8]*p[10]*p[16] + p[9]*p[10]*p[19] + p[11]*p[12]*p[19] - p[11]*p[13]*p[16] + p[12]*p[14]*p[16] + p[13]*p[14]*p[19]; + coeff[118] = p[7]*p[8]*p[20] - p[7]*p[9]*p[17] - p[8]*p[8]*p[22] + p[8]*p[10]*p[17] - p[9]*p[9]*p[22] + p[9]*p[10]*p[20] + p[11]*p[12]*p[20] - p[11]*p[13]*p[17] - p[12]*p[12]*p[22] + p[12]*p[14]*p[17] - p[13]*p[13]*p[22] + p[13]*p[14]*p[20] + p[22]; + coeff[119] = 0; + coeff[120] = -2*p[7]*p[9]*p[15] + 2*p[7]*p[9]*p[24] + 2*p[8]*p[10]*p[15] - 2*p[8]*p[10]*p[24] - 2*p[11]*p[13]*p[15] + 2*p[11]*p[13]*p[24] + 2*p[12]*p[14]*p[15] - 2*p[12]*p[14]*p[24]; + coeff[121] = 2*p[7]*p[8]*p[19] - 2*p[7]*p[8]*p[28] - 2*p[7]*p[9]*p[16] + 2*p[7]*p[9]*p[25] + 2*p[8]*p[10]*p[16] - 2*p[8]*p[10]*p[25] + 2*p[9]*p[10]*p[19] - 2*p[9]*p[10]*p[28] + 2*p[11]*p[12]*p[19] - 2*p[11]*p[12]*p[28] - 2*p[11]*p[13]*p[16] + 2*p[11]*p[13]*p[25] + 2*p[12]*p[14]*p[16] - 2*p[12]*p[14]*p[25] + 2*p[13]*p[14]*p[19] - 2*p[13]*p[14]*p[28]; + coeff[122] = 2*p[7]*p[8]*p[20] - 2*p[7]*p[8]*p[29] - 2*p[7]*p[9]*p[17] + 2*p[7]*p[9]*p[26] - 2*p[8]*p[8]*p[22] + 2*p[8]*p[8]*p[31] + 2*p[8]*p[10]*p[17] - 2*p[8]*p[10]*p[26] - 2*p[9]*p[9]*p[22] + 2*p[9]*p[9]*p[31] + 2*p[9]*p[10]*p[20] - 2*p[9]*p[10]*p[29] + 2*p[11]*p[12]*p[20] - 2*p[11]*p[12]*p[29] - 2*p[11]*p[13]*p[17] + 2*p[11]*p[13]*p[26] - 2*p[12]*p[12]*p[22] + 2*p[12]*p[12]*p[31] + 2*p[12]*p[14]*p[17] - 2*p[12]*p[14]*p[26] - 2*p[13]*p[13]*p[22] + 2*p[13]*p[13]*p[31] + 2*p[13]*p[14]*p[20] - 2*p[13]*p[14]*p[29] + 2*p[22] - 2*p[31]; + coeff[123] = 2*p[7]*p[9]*p[15] - 2*p[7]*p[9]*p[24] - 2*p[8]*p[10]*p[15] + 2*p[8]*p[10]*p[24] + 2*p[11]*p[13]*p[15] - 2*p[11]*p[13]*p[24] - 2*p[12]*p[14]*p[15] + 2*p[12]*p[14]*p[24]; + coeff[124] = -2*p[7]*p[8]*p[19] + 2*p[7]*p[8]*p[28] + 2*p[7]*p[9]*p[16] - 2*p[7]*p[9]*p[25] - 2*p[8]*p[10]*p[16] + 2*p[8]*p[10]*p[25] - 2*p[9]*p[10]*p[19] + 2*p[9]*p[10]*p[28] - 2*p[11]*p[12]*p[19] + 2*p[11]*p[12]*p[28] + 2*p[11]*p[13]*p[16] - 2*p[11]*p[13]*p[25] - 2*p[12]*p[14]*p[16] + 2*p[12]*p[14]*p[25] - 2*p[13]*p[14]*p[19] + 2*p[13]*p[14]*p[28]; + coeff[125] = -2*p[7]*p[8]*p[20] + 2*p[7]*p[8]*p[29] + 2*p[7]*p[9]*p[17] - 2*p[7]*p[9]*p[26] + 2*p[8]*p[8]*p[22] - 2*p[8]*p[8]*p[31] - 2*p[8]*p[10]*p[17] + 2*p[8]*p[10]*p[26] + 2*p[9]*p[9]*p[22] - 2*p[9]*p[9]*p[31] - 2*p[9]*p[10]*p[20] + 2*p[9]*p[10]*p[29] - 2*p[11]*p[12]*p[20] + 2*p[11]*p[12]*p[29] + 2*p[11]*p[13]*p[17] - 2*p[11]*p[13]*p[26] + 2*p[12]*p[12]*p[22] - 2*p[12]*p[12]*p[31] - 2*p[12]*p[14]*p[17] + 2*p[12]*p[14]*p[26] + 2*p[13]*p[13]*p[22] - 2*p[13]*p[13]*p[31] - 2*p[13]*p[14]*p[20] + 2*p[13]*p[14]*p[29] - 2*p[22] + 2*p[31]; + coeff[126] = 2*p[0]*p[7]*p[11]*p[23] + 2*p[0]*p[7]*p[12]*p[21] - 2*p[0]*p[7]*p[13]*p[18] + 2*p[0]*p[8]*p[11]*p[21] - 2*p[0]*p[8]*p[12]*p[23] + 2*p[0]*p[8]*p[14]*p[18] - 2*p[0]*p[9]*p[11]*p[18] - 2*p[0]*p[9]*p[13]*p[23] + 2*p[0]*p[9]*p[14]*p[21] + 2*p[0]*p[10]*p[12]*p[18] + 2*p[0]*p[10]*p[13]*p[21] + 2*p[0]*p[10]*p[14]*p[23] - p[7]*p[8]*p[21] + p[7]*p[8]*p[30] + p[7]*p[9]*p[18] - p[7]*p[9]*p[27] + p[8]*p[8]*p[23] - p[8]*p[8]*p[32] - p[8]*p[10]*p[18] + p[8]*p[10]*p[27] + p[9]*p[9]*p[23] - p[9]*p[9]*p[32] - p[9]*p[10]*p[21] + p[9]*p[10]*p[30] + p[11]*p[12]*p[21] - p[11]*p[12]*p[30] - p[11]*p[13]*p[18] + p[11]*p[13]*p[27] - p[12]*p[12]*p[23] + p[12]*p[12]*p[32] + p[12]*p[14]*p[18] - p[12]*p[14]*p[27] - p[13]*p[13]*p[23] + p[13]*p[13]*p[32] + p[13]*p[14]*p[21] - p[13]*p[14]*p[30]; + coeff[127] = -2*p[0]*p[7]*p[13]*p[15] + 2*p[0]*p[8]*p[14]*p[15] - 2*p[0]*p[9]*p[11]*p[15] + 2*p[0]*p[10]*p[12]*p[15] + 2*p[7]*p[9]*p[15] - p[7]*p[9]*p[24] - 2*p[8]*p[10]*p[15] + p[8]*p[10]*p[24] - 
2*p[11]*p[13]*p[15] + p[11]*p[13]*p[24] + 2*p[12]*p[14]*p[15] - p[12]*p[14]*p[24]; + coeff[128] = 2*p[0]*p[7]*p[12]*p[19] - 2*p[0]*p[7]*p[13]*p[16] + 2*p[0]*p[8]*p[11]*p[19] + 2*p[0]*p[8]*p[14]*p[16] - 2*p[0]*p[9]*p[11]*p[16] + 2*p[0]*p[9]*p[14]*p[19] + 2*p[0]*p[10]*p[12]*p[16] + 2*p[0]*p[10]*p[13]*p[19] - 2*p[7]*p[8]*p[19] + p[7]*p[8]*p[28] + 2*p[7]*p[9]*p[16] - p[7]*p[9]*p[25] - 2*p[8]*p[10]*p[16] + p[8]*p[10]*p[25] - 2*p[9]*p[10]*p[19] + p[9]*p[10]*p[28] + 2*p[11]*p[12]*p[19] - p[11]*p[12]*p[28] - 2*p[11]*p[13]*p[16] + p[11]*p[13]*p[25] + 2*p[12]*p[14]*p[16] - p[12]*p[14]*p[25] + 2*p[13]*p[14]*p[19] - p[13]*p[14]*p[28]; + coeff[129] = 2*p[0]*p[7]*p[11]*p[22] + 2*p[0]*p[7]*p[12]*p[20] - 2*p[0]*p[7]*p[13]*p[17] + 2*p[0]*p[8]*p[11]*p[20] - 2*p[0]*p[8]*p[12]*p[22] + 2*p[0]*p[8]*p[14]*p[17] - 2*p[0]*p[9]*p[11]*p[17] - 2*p[0]*p[9]*p[13]*p[22] + 2*p[0]*p[9]*p[14]*p[20] + 2*p[0]*p[10]*p[12]*p[17] + 2*p[0]*p[10]*p[13]*p[20] + 2*p[0]*p[10]*p[14]*p[22] - 2*p[7]*p[8]*p[20] + p[7]*p[8]*p[29] + 2*p[7]*p[9]*p[17] - p[7]*p[9]*p[26] + 2*p[8]*p[8]*p[22] - p[8]*p[8]*p[31] - 2*p[8]*p[10]*p[17] + p[8]*p[10]*p[26] + 2*p[9]*p[9]*p[22] - p[9]*p[9]*p[31] - 2*p[9]*p[10]*p[20] + p[9]*p[10]*p[29] + 2*p[11]*p[12]*p[20] - p[11]*p[12]*p[29] - 2*p[11]*p[13]*p[17] + p[11]*p[13]*p[26] - 2*p[12]*p[12]*p[22] + p[12]*p[12]*p[31] + 2*p[12]*p[14]*p[17] - p[12]*p[14]*p[26] - 2*p[13]*p[13]*p[22] + p[13]*p[13]*p[31] + 2*p[13]*p[14]*p[20] - p[13]*p[14]*p[29]; + coeff[130] = (-p[7]*p[9] + p[8]*p[10] + p[11]*p[13] - p[12]*p[14])*p[15]; + coeff[131] = p[7]*p[8]*p[19] - p[7]*p[9]*p[16] + p[8]*p[10]*p[16] + p[9]*p[10]*p[19] - p[11]*p[12]*p[19] + p[11]*p[13]*p[16] - p[12]*p[14]*p[16] - p[13]*p[14]*p[19]; + coeff[132] = p[7]*p[8]*p[20] - p[7]*p[9]*p[17] - p[8]*p[8]*p[22] + p[8]*p[10]*p[17] - p[9]*p[9]*p[22] + p[9]*p[10]*p[20] - p[11]*p[12]*p[20] + p[11]*p[13]*p[17] + p[12]*p[12]*p[22] - p[12]*p[14]*p[17] + p[13]*p[13]*p[22] - p[13]*p[14]*p[20]; + coeff[133] = 2*(-p[7]*p[11]*p[23] + p[7]*p[11]*p[32] - p[7]*p[12]*p[21] + p[7]*p[12]*p[30] + p[7]*p[13]*p[18] - p[7]*p[13]*p[27] - p[8]*p[11]*p[21] + p[8]*p[11]*p[30] + p[8]*p[12]*p[23] - p[8]*p[12]*p[32] - p[8]*p[14]*p[18] + p[8]*p[14]*p[27] + p[9]*p[11]*p[18] - p[9]*p[11]*p[27] + p[9]*p[13]*p[23] - p[9]*p[13]*p[32] - p[9]*p[14]*p[21] + p[9]*p[14]*p[30] - p[10]*p[12]*p[18] + p[10]*p[12]*p[27] - p[10]*p[13]*p[21] + p[10]*p[13]*p[30] - p[10]*p[14]*p[23] + p[10]*p[14]*p[32])*p[0]; + coeff[134] = 4*p[0]*p[7]*p[13]*p[15] - 2*p[0]*p[7]*p[13]*p[24] - 4*p[0]*p[8]*p[14]*p[15] + 2*p[0]*p[8]*p[14]*p[24] + 4*p[0]*p[9]*p[11]*p[15] - 2*p[0]*p[9]*p[11]*p[24] - 4*p[0]*p[10]*p[12]*p[15] + 2*p[0]*p[10]*p[12]*p[24] - 2*p[7]*p[9]*p[15] + 2*p[7]*p[9]*p[24] + 2*p[8]*p[10]*p[15] - 2*p[8]*p[10]*p[24] + 2*p[11]*p[13]*p[15] - 2*p[11]*p[13]*p[24] - 2*p[12]*p[14]*p[15] + 2*p[12]*p[14]*p[24]; + coeff[135] = -4*p[0]*p[7]*p[12]*p[19] + 2*p[0]*p[7]*p[12]*p[28] + 4*p[0]*p[7]*p[13]*p[16] - 2*p[0]*p[7]*p[13]*p[25] - 4*p[0]*p[8]*p[11]*p[19] + 2*p[0]*p[8]*p[11]*p[28] - 4*p[0]*p[8]*p[14]*p[16] + 2*p[0]*p[8]*p[14]*p[25] + 4*p[0]*p[9]*p[11]*p[16] - 2*p[0]*p[9]*p[11]*p[25] - 4*p[0]*p[9]*p[14]*p[19] + 2*p[0]*p[9]*p[14]*p[28] - 4*p[0]*p[10]*p[12]*p[16] + 2*p[0]*p[10]*p[12]*p[25] - 4*p[0]*p[10]*p[13]*p[19] + 2*p[0]*p[10]*p[13]*p[28] + 2*p[7]*p[8]*p[19] - 2*p[7]*p[8]*p[28] - 2*p[7]*p[9]*p[16] + 2*p[7]*p[9]*p[25] + 2*p[8]*p[10]*p[16] - 2*p[8]*p[10]*p[25] + 2*p[9]*p[10]*p[19] - 2*p[9]*p[10]*p[28] - 2*p[11]*p[12]*p[19] + 2*p[11]*p[12]*p[28] + 2*p[11]*p[13]*p[16] - 2*p[11]*p[13]*p[25] - 2*p[12]*p[14]*p[16] + 2*p[12]*p[14]*p[25] - 
2*p[13]*p[14]*p[19] + 2*p[13]*p[14]*p[28]; + coeff[136] = -4*p[0]*p[7]*p[11]*p[22] + 2*p[0]*p[7]*p[11]*p[31] - 4*p[0]*p[7]*p[12]*p[20] + 2*p[0]*p[7]*p[12]*p[29] + 4*p[0]*p[7]*p[13]*p[17] - 2*p[0]*p[7]*p[13]*p[26] - 4*p[0]*p[8]*p[11]*p[20] + 2*p[0]*p[8]*p[11]*p[29] + 4*p[0]*p[8]*p[12]*p[22] - 2*p[0]*p[8]*p[12]*p[31] - 4*p[0]*p[8]*p[14]*p[17] + 2*p[0]*p[8]*p[14]*p[26] + 4*p[0]*p[9]*p[11]*p[17] - 2*p[0]*p[9]*p[11]*p[26] + 4*p[0]*p[9]*p[13]*p[22] - 2*p[0]*p[9]*p[13]*p[31] - 4*p[0]*p[9]*p[14]*p[20] + 2*p[0]*p[9]*p[14]*p[29] - 4*p[0]*p[10]*p[12]*p[17] + 2*p[0]*p[10]*p[12]*p[26] - 4*p[0]*p[10]*p[13]*p[20] + 2*p[0]*p[10]*p[13]*p[29] - 4*p[0]*p[10]*p[14]*p[22] + 2*p[0]*p[10]*p[14]*p[31] + 2*p[7]*p[8]*p[20] - 2*p[7]*p[8]*p[29] - 2*p[7]*p[9]*p[17] + 2*p[7]*p[9]*p[26] - 2*p[8]*p[8]*p[22] + 2*p[8]*p[8]*p[31] + 2*p[8]*p[10]*p[17] - 2*p[8]*p[10]*p[26] - 2*p[9]*p[9]*p[22] + 2*p[9]*p[9]*p[31] + 2*p[9]*p[10]*p[20] - 2*p[9]*p[10]*p[29] - 2*p[11]*p[12]*p[20] + 2*p[11]*p[12]*p[29] + 2*p[11]*p[13]*p[17] - 2*p[11]*p[13]*p[26] + 2*p[12]*p[12]*p[22] - 2*p[12]*p[12]*p[31] - 2*p[12]*p[14]*p[17] + 2*p[12]*p[14]*p[26] + 2*p[13]*p[13]*p[22] - 2*p[13]*p[13]*p[31] - 2*p[13]*p[14]*p[20] + 2*p[13]*p[14]*p[29]; + coeff[137] = -2*p[0]*p[7]*p[13]*p[15] + 2*p[0]*p[8]*p[14]*p[15] - 2*p[0]*p[9]*p[11]*p[15] + 2*p[0]*p[10]*p[12]*p[15] + 2*p[7]*p[9]*p[15] - 2*p[7]*p[9]*p[24] - 2*p[8]*p[10]*p[15] + 2*p[8]*p[10]*p[24] - 2*p[11]*p[13]*p[15] + 2*p[11]*p[13]*p[24] + 2*p[12]*p[14]*p[15] - 2*p[12]*p[14]*p[24]; + coeff[138] = 2*p[0]*p[7]*p[12]*p[19] - 2*p[0]*p[7]*p[13]*p[16] + 2*p[0]*p[8]*p[11]*p[19] + 2*p[0]*p[8]*p[14]*p[16] - 2*p[0]*p[9]*p[11]*p[16] + 2*p[0]*p[9]*p[14]*p[19] + 2*p[0]*p[10]*p[12]*p[16] + 2*p[0]*p[10]*p[13]*p[19] - 2*p[7]*p[8]*p[19] + 2*p[7]*p[8]*p[28] + 2*p[7]*p[9]*p[16] - 2*p[7]*p[9]*p[25] - 2*p[8]*p[10]*p[16] + 2*p[8]*p[10]*p[25] - 2*p[9]*p[10]*p[19] + 2*p[9]*p[10]*p[28] + 2*p[11]*p[12]*p[19] - 2*p[11]*p[12]*p[28] - 2*p[11]*p[13]*p[16] + 2*p[11]*p[13]*p[25] + 2*p[12]*p[14]*p[16] - 2*p[12]*p[14]*p[25] + 2*p[13]*p[14]*p[19] - 2*p[13]*p[14]*p[28]; + coeff[139] = 2*p[0]*p[7]*p[11]*p[22] + 2*p[0]*p[7]*p[12]*p[20] - 2*p[0]*p[7]*p[13]*p[17] + 2*p[0]*p[8]*p[11]*p[20] - 2*p[0]*p[8]*p[12]*p[22] + 2*p[0]*p[8]*p[14]*p[17] - 2*p[0]*p[9]*p[11]*p[17] - 2*p[0]*p[9]*p[13]*p[22] + 2*p[0]*p[9]*p[14]*p[20] + 2*p[0]*p[10]*p[12]*p[17] + 2*p[0]*p[10]*p[13]*p[20] + 2*p[0]*p[10]*p[14]*p[22] - 2*p[7]*p[8]*p[20] + 2*p[7]*p[8]*p[29] + 2*p[7]*p[9]*p[17] - 2*p[7]*p[9]*p[26] + 2*p[8]*p[8]*p[22] - 2*p[8]*p[8]*p[31] - 2*p[8]*p[10]*p[17] + 2*p[8]*p[10]*p[26] + 2*p[9]*p[9]*p[22] - 2*p[9]*p[9]*p[31] - 2*p[9]*p[10]*p[20] + 2*p[9]*p[10]*p[29] + 2*p[11]*p[12]*p[20] - 2*p[11]*p[12]*p[29] - 2*p[11]*p[13]*p[17] + 2*p[11]*p[13]*p[26] - 2*p[12]*p[12]*p[22] + 2*p[12]*p[12]*p[31] + 2*p[12]*p[14]*p[17] - 2*p[12]*p[14]*p[26] - 2*p[13]*p[13]*p[22] + 2*p[13]*p[13]*p[31] + 2*p[13]*p[14]*p[20] - 2*p[13]*p[14]*p[29]; + coeff[140] = 0; + coeff[141] = 2*(-p[7]*p[13]*p[15] + p[7]*p[13]*p[24] + p[8]*p[14]*p[15] - p[8]*p[14]*p[24] - p[9]*p[11]*p[15] + p[9]*p[11]*p[24] + p[10]*p[12]*p[15] - p[10]*p[12]*p[24])*p[0]; + coeff[142] = 2*(p[7]*p[12]*p[19] - p[7]*p[12]*p[28] - p[7]*p[13]*p[16] + p[7]*p[13]*p[25] + p[8]*p[11]*p[19] - p[8]*p[11]*p[28] + p[8]*p[14]*p[16] - p[8]*p[14]*p[25] - p[9]*p[11]*p[16] + p[9]*p[11]*p[25] + p[9]*p[14]*p[19] - p[9]*p[14]*p[28] + p[10]*p[12]*p[16] - p[10]*p[12]*p[25] + p[10]*p[13]*p[19] - p[10]*p[13]*p[28])*p[0]; + coeff[143] = 2*(p[7]*p[11]*p[22] - p[7]*p[11]*p[31] + p[7]*p[12]*p[20] - p[7]*p[12]*p[29] - p[7]*p[13]*p[17] + p[7]*p[13]*p[26] + 
p[8]*p[11]*p[20] - p[8]*p[11]*p[29] - p[8]*p[12]*p[22] + p[8]*p[12]*p[31] + p[8]*p[14]*p[17] - p[8]*p[14]*p[26] - p[9]*p[11]*p[17] + p[9]*p[11]*p[26] - p[9]*p[13]*p[22] + p[9]*p[13]*p[31] + p[9]*p[14]*p[20] - p[9]*p[14]*p[29] + p[10]*p[12]*p[17] - p[10]*p[12]*p[26] + p[10]*p[13]*p[20] - p[10]*p[13]*p[29] + p[10]*p[14]*p[22] - p[10]*p[14]*p[31])*p[0]; + coeff[144] = 2*(p[7]*p[13]*p[15] - p[7]*p[13]*p[24] - p[8]*p[14]*p[15] + p[8]*p[14]*p[24] + p[9]*p[11]*p[15] - p[9]*p[11]*p[24] - p[10]*p[12]*p[15] + p[10]*p[12]*p[24])*p[0]; + coeff[145] = 2*(-p[7]*p[12]*p[19] + p[7]*p[12]*p[28] + p[7]*p[13]*p[16] - p[7]*p[13]*p[25] - p[8]*p[11]*p[19] + p[8]*p[11]*p[28] - p[8]*p[14]*p[16] + p[8]*p[14]*p[25] + p[9]*p[11]*p[16] - p[9]*p[11]*p[25] - p[9]*p[14]*p[19] + p[9]*p[14]*p[28] - p[10]*p[12]*p[16] + p[10]*p[12]*p[25] - p[10]*p[13]*p[19] + p[10]*p[13]*p[28])*p[0]; + coeff[146] = 2*(-p[7]*p[11]*p[22] + p[7]*p[11]*p[31] - p[7]*p[12]*p[20] + p[7]*p[12]*p[29] + p[7]*p[13]*p[17] - p[7]*p[13]*p[26] - p[8]*p[11]*p[20] + p[8]*p[11]*p[29] + p[8]*p[12]*p[22] - p[8]*p[12]*p[31] - p[8]*p[14]*p[17] + p[8]*p[14]*p[26] + p[9]*p[11]*p[17] - p[9]*p[11]*p[26] + p[9]*p[13]*p[22] - p[9]*p[13]*p[31] - p[9]*p[14]*p[20] + p[9]*p[14]*p[29] - p[10]*p[12]*p[17] + p[10]*p[12]*p[26] - p[10]*p[13]*p[20] + p[10]*p[13]*p[29] - p[10]*p[14]*p[22] + p[10]*p[14]*p[31])*p[0]; + coeff[147] = -2*p[0]*p[7]*p[8]*p[21] + 2*p[0]*p[7]*p[9]*p[18] + 2*p[0]*p[8]*p[8]*p[23] - 2*p[0]*p[8]*p[10]*p[18] + 2*p[0]*p[9]*p[9]*p[23] - 2*p[0]*p[9]*p[10]*p[21] + 2*p[0]*p[11]*p[12]*p[21] - 2*p[0]*p[11]*p[13]*p[18] - 2*p[0]*p[12]*p[12]*p[23] + 2*p[0]*p[12]*p[14]*p[18] - 2*p[0]*p[13]*p[13]*p[23] + 2*p[0]*p[13]*p[14]*p[21] - p[7]*p[11]*p[23] + p[7]*p[11]*p[32] - p[7]*p[12]*p[21] + p[7]*p[12]*p[30] + p[7]*p[13]*p[18] - p[7]*p[13]*p[27] - p[8]*p[11]*p[21] + p[8]*p[11]*p[30] + p[8]*p[12]*p[23] - p[8]*p[12]*p[32] - p[8]*p[14]*p[18] + p[8]*p[14]*p[27] + p[9]*p[11]*p[18] - p[9]*p[11]*p[27] + p[9]*p[13]*p[23] - p[9]*p[13]*p[32] - p[9]*p[14]*p[21] + p[9]*p[14]*p[30] - p[10]*p[12]*p[18] + p[10]*p[12]*p[27] - p[10]*p[13]*p[21] + p[10]*p[13]*p[30] - p[10]*p[14]*p[23] + p[10]*p[14]*p[32]; + coeff[148] = 2*p[0]*p[7]*p[9]*p[15] - 2*p[0]*p[8]*p[10]*p[15] - 2*p[0]*p[11]*p[13]*p[15] + 2*p[0]*p[12]*p[14]*p[15] + 2*p[7]*p[13]*p[15] - p[7]*p[13]*p[24] - 2*p[8]*p[14]*p[15] + p[8]*p[14]*p[24] + 2*p[9]*p[11]*p[15] - p[9]*p[11]*p[24] - 2*p[10]*p[12]*p[15] + p[10]*p[12]*p[24]; + coeff[149] = -2*p[0]*p[7]*p[8]*p[19] + 2*p[0]*p[7]*p[9]*p[16] - 2*p[0]*p[8]*p[10]*p[16] - 2*p[0]*p[9]*p[10]*p[19] + 2*p[0]*p[11]*p[12]*p[19] - 2*p[0]*p[11]*p[13]*p[16] + 2*p[0]*p[12]*p[14]*p[16] + 2*p[0]*p[13]*p[14]*p[19] - 2*p[7]*p[12]*p[19] + p[7]*p[12]*p[28] + 2*p[7]*p[13]*p[16] - p[7]*p[13]*p[25] - 2*p[8]*p[11]*p[19] + p[8]*p[11]*p[28] - 2*p[8]*p[14]*p[16] + p[8]*p[14]*p[25] + 2*p[9]*p[11]*p[16] - p[9]*p[11]*p[25] - 2*p[9]*p[14]*p[19] + p[9]*p[14]*p[28] - 2*p[10]*p[12]*p[16] + p[10]*p[12]*p[25] - 2*p[10]*p[13]*p[19] + p[10]*p[13]*p[28]; + coeff[150] = -2*p[0]*p[7]*p[8]*p[20] + 2*p[0]*p[7]*p[9]*p[17] + 2*p[0]*p[8]*p[8]*p[22] - 2*p[0]*p[8]*p[10]*p[17] + 2*p[0]*p[9]*p[9]*p[22] - 2*p[0]*p[9]*p[10]*p[20] + 2*p[0]*p[11]*p[12]*p[20] - 2*p[0]*p[11]*p[13]*p[17] - 2*p[0]*p[12]*p[12]*p[22] + 2*p[0]*p[12]*p[14]*p[17] - 2*p[0]*p[13]*p[13]*p[22] + 2*p[0]*p[13]*p[14]*p[20] - 2*p[7]*p[11]*p[22] + p[7]*p[11]*p[31] - 2*p[7]*p[12]*p[20] + p[7]*p[12]*p[29] + 2*p[7]*p[13]*p[17] - p[7]*p[13]*p[26] - 2*p[8]*p[11]*p[20] + p[8]*p[11]*p[29] + 2*p[8]*p[12]*p[22] - p[8]*p[12]*p[31] - 2*p[8]*p[14]*p[17] + p[8]*p[14]*p[26] + 
2*p[9]*p[11]*p[17] - p[9]*p[11]*p[26] + 2*p[9]*p[13]*p[22] - p[9]*p[13]*p[31] - 2*p[9]*p[14]*p[20] + p[9]*p[14]*p[29] - 2*p[10]*p[12]*p[17] + p[10]*p[12]*p[26] - 2*p[10]*p[13]*p[20] + p[10]*p[13]*p[29] - 2*p[10]*p[14]*p[22] + p[10]*p[14]*p[31]; + coeff[151] = (-p[7]*p[13] + p[8]*p[14] - p[9]*p[11] + p[10]*p[12])*p[15]; + coeff[152] = p[7]*p[12]*p[19] - p[7]*p[13]*p[16] + p[8]*p[11]*p[19] + p[8]*p[14]*p[16] - p[9]*p[11]*p[16] + p[9]*p[14]*p[19] + p[10]*p[12]*p[16] + p[10]*p[13]*p[19]; + coeff[153] = p[7]*p[11]*p[22] + p[7]*p[12]*p[20] - p[7]*p[13]*p[17] + p[8]*p[11]*p[20] - p[8]*p[12]*p[22] + p[8]*p[14]*p[17] - p[9]*p[11]*p[17] - p[9]*p[13]*p[22] + p[9]*p[14]*p[20] + p[10]*p[12]*p[17] + p[10]*p[13]*p[20] + p[10]*p[14]*p[22]; + coeff[154] = 2*(p[7]*p[8]*p[21] - p[7]*p[8]*p[30] - p[7]*p[9]*p[18] + p[7]*p[9]*p[27] - p[8]*p[8]*p[23] + p[8]*p[8]*p[32] + p[8]*p[10]*p[18] - p[8]*p[10]*p[27] - p[9]*p[9]*p[23] + p[9]*p[9]*p[32] + p[9]*p[10]*p[21] - p[9]*p[10]*p[30] - p[11]*p[12]*p[21] + p[11]*p[12]*p[30] + p[11]*p[13]*p[18] - p[11]*p[13]*p[27] + p[12]*p[12]*p[23] - p[12]*p[12]*p[32] - p[12]*p[14]*p[18] + p[12]*p[14]*p[27] + p[13]*p[13]*p[23] - p[13]*p[13]*p[32] - p[13]*p[14]*p[21] + p[13]*p[14]*p[30])*p[0]; + coeff[155] = -4*p[0]*p[7]*p[9]*p[15] + 2*p[0]*p[7]*p[9]*p[24] + 4*p[0]*p[8]*p[10]*p[15] - 2*p[0]*p[8]*p[10]*p[24] + 4*p[0]*p[11]*p[13]*p[15] - 2*p[0]*p[11]*p[13]*p[24] - 4*p[0]*p[12]*p[14]*p[15] + 2*p[0]*p[12]*p[14]*p[24] - 2*p[7]*p[13]*p[15] + 2*p[7]*p[13]*p[24] + 2*p[8]*p[14]*p[15] - 2*p[8]*p[14]*p[24] - 2*p[9]*p[11]*p[15] + 2*p[9]*p[11]*p[24] + 2*p[10]*p[12]*p[15] - 2*p[10]*p[12]*p[24]; + coeff[156] = 4*p[0]*p[7]*p[8]*p[19] - 2*p[0]*p[7]*p[8]*p[28] - 4*p[0]*p[7]*p[9]*p[16] + 2*p[0]*p[7]*p[9]*p[25] + 4*p[0]*p[8]*p[10]*p[16] - 2*p[0]*p[8]*p[10]*p[25] + 4*p[0]*p[9]*p[10]*p[19] - 2*p[0]*p[9]*p[10]*p[28] - 4*p[0]*p[11]*p[12]*p[19] + 2*p[0]*p[11]*p[12]*p[28] + 4*p[0]*p[11]*p[13]*p[16] - 2*p[0]*p[11]*p[13]*p[25] - 4*p[0]*p[12]*p[14]*p[16] + 2*p[0]*p[12]*p[14]*p[25] - 4*p[0]*p[13]*p[14]*p[19] + 2*p[0]*p[13]*p[14]*p[28] + 2*p[7]*p[12]*p[19] - 2*p[7]*p[12]*p[28] - 2*p[7]*p[13]*p[16] + 2*p[7]*p[13]*p[25] + 2*p[8]*p[11]*p[19] - 2*p[8]*p[11]*p[28] + 2*p[8]*p[14]*p[16] - 2*p[8]*p[14]*p[25] - 2*p[9]*p[11]*p[16] + 2*p[9]*p[11]*p[25] + 2*p[9]*p[14]*p[19] - 2*p[9]*p[14]*p[28] + 2*p[10]*p[12]*p[16] - 2*p[10]*p[12]*p[25] + 2*p[10]*p[13]*p[19] - 2*p[10]*p[13]*p[28]; + coeff[157] = 4*p[0]*p[7]*p[8]*p[20] - 2*p[0]*p[7]*p[8]*p[29] - 4*p[0]*p[7]*p[9]*p[17] + 2*p[0]*p[7]*p[9]*p[26] - 4*p[0]*p[8]*p[8]*p[22] + 2*p[0]*p[8]*p[8]*p[31] + 4*p[0]*p[8]*p[10]*p[17] - 2*p[0]*p[8]*p[10]*p[26] - 4*p[0]*p[9]*p[9]*p[22] + 2*p[0]*p[9]*p[9]*p[31] + 4*p[0]*p[9]*p[10]*p[20] - 2*p[0]*p[9]*p[10]*p[29] - 4*p[0]*p[11]*p[12]*p[20] + 2*p[0]*p[11]*p[12]*p[29] + 4*p[0]*p[11]*p[13]*p[17] - 2*p[0]*p[11]*p[13]*p[26] + 4*p[0]*p[12]*p[12]*p[22] - 2*p[0]*p[12]*p[12]*p[31] - 4*p[0]*p[12]*p[14]*p[17] + 2*p[0]*p[12]*p[14]*p[26] + 4*p[0]*p[13]*p[13]*p[22] - 2*p[0]*p[13]*p[13]*p[31] - 4*p[0]*p[13]*p[14]*p[20] + 2*p[0]*p[13]*p[14]*p[29] + 2*p[7]*p[11]*p[22] - 2*p[7]*p[11]*p[31] + 2*p[7]*p[12]*p[20] - 2*p[7]*p[12]*p[29] - 2*p[7]*p[13]*p[17] + 2*p[7]*p[13]*p[26] + 2*p[8]*p[11]*p[20] - 2*p[8]*p[11]*p[29] - 2*p[8]*p[12]*p[22] + 2*p[8]*p[12]*p[31] + 2*p[8]*p[14]*p[17] - 2*p[8]*p[14]*p[26] - 2*p[9]*p[11]*p[17] + 2*p[9]*p[11]*p[26] - 2*p[9]*p[13]*p[22] + 2*p[9]*p[13]*p[31] + 2*p[9]*p[14]*p[20] - 2*p[9]*p[14]*p[29] + 2*p[10]*p[12]*p[17] - 2*p[10]*p[12]*p[26] + 2*p[10]*p[13]*p[20] - 2*p[10]*p[13]*p[29] + 2*p[10]*p[14]*p[22] - 2*p[10]*p[14]*p[31]; + coeff[158] = 
2*p[0]*p[7]*p[9]*p[15] - 2*p[0]*p[8]*p[10]*p[15] - 2*p[0]*p[11]*p[13]*p[15] + 2*p[0]*p[12]*p[14]*p[15] + 2*p[7]*p[13]*p[15] - 2*p[7]*p[13]*p[24] - 2*p[8]*p[14]*p[15] + 2*p[8]*p[14]*p[24] + 2*p[9]*p[11]*p[15] - 2*p[9]*p[11]*p[24] - 2*p[10]*p[12]*p[15] + 2*p[10]*p[12]*p[24]; + coeff[159] = -2*p[0]*p[7]*p[8]*p[19] + 2*p[0]*p[7]*p[9]*p[16] - 2*p[0]*p[8]*p[10]*p[16] - 2*p[0]*p[9]*p[10]*p[19] + 2*p[0]*p[11]*p[12]*p[19] - 2*p[0]*p[11]*p[13]*p[16] + 2*p[0]*p[12]*p[14]*p[16] + 2*p[0]*p[13]*p[14]*p[19] - 2*p[7]*p[12]*p[19] + 2*p[7]*p[12]*p[28] + 2*p[7]*p[13]*p[16] - 2*p[7]*p[13]*p[25] - 2*p[8]*p[11]*p[19] + 2*p[8]*p[11]*p[28] - 2*p[8]*p[14]*p[16] + 2*p[8]*p[14]*p[25] + 2*p[9]*p[11]*p[16] - 2*p[9]*p[11]*p[25] - 2*p[9]*p[14]*p[19] + 2*p[9]*p[14]*p[28] - 2*p[10]*p[12]*p[16] + 2*p[10]*p[12]*p[25] - 2*p[10]*p[13]*p[19] + 2*p[10]*p[13]*p[28]; + coeff[160] = -2*p[0]*p[7]*p[8]*p[20] + 2*p[0]*p[7]*p[9]*p[17] + 2*p[0]*p[8]*p[8]*p[22] - 2*p[0]*p[8]*p[10]*p[17] + 2*p[0]*p[9]*p[9]*p[22] - 2*p[0]*p[9]*p[10]*p[20] + 2*p[0]*p[11]*p[12]*p[20] - 2*p[0]*p[11]*p[13]*p[17] - 2*p[0]*p[12]*p[12]*p[22] + 2*p[0]*p[12]*p[14]*p[17] - 2*p[0]*p[13]*p[13]*p[22] + 2*p[0]*p[13]*p[14]*p[20] - 2*p[7]*p[11]*p[22] + 2*p[7]*p[11]*p[31] - 2*p[7]*p[12]*p[20] + 2*p[7]*p[12]*p[29] + 2*p[7]*p[13]*p[17] - 2*p[7]*p[13]*p[26] - 2*p[8]*p[11]*p[20] + 2*p[8]*p[11]*p[29] + 2*p[8]*p[12]*p[22] - 2*p[8]*p[12]*p[31] - 2*p[8]*p[14]*p[17] + 2*p[8]*p[14]*p[26] + 2*p[9]*p[11]*p[17] - 2*p[9]*p[11]*p[26] + 2*p[9]*p[13]*p[22] - 2*p[9]*p[13]*p[31] - 2*p[9]*p[14]*p[20] + 2*p[9]*p[14]*p[29] - 2*p[10]*p[12]*p[17] + 2*p[10]*p[12]*p[26] - 2*p[10]*p[13]*p[20] + 2*p[10]*p[13]*p[29] - 2*p[10]*p[14]*p[22] + 2*p[10]*p[14]*p[31]; + coeff[161] = 0; + coeff[162] = 2*(p[7]*p[9]*p[15] - p[7]*p[9]*p[24] - p[8]*p[10]*p[15] + p[8]*p[10]*p[24] - p[11]*p[13]*p[15] + p[11]*p[13]*p[24] + p[12]*p[14]*p[15] - p[12]*p[14]*p[24])*p[0]; + coeff[163] = 2*(-p[7]*p[8]*p[19] + p[7]*p[8]*p[28] + p[7]*p[9]*p[16] - p[7]*p[9]*p[25] - p[8]*p[10]*p[16] + p[8]*p[10]*p[25] - p[9]*p[10]*p[19] + p[9]*p[10]*p[28] + p[11]*p[12]*p[19] - p[11]*p[12]*p[28] - p[11]*p[13]*p[16] + p[11]*p[13]*p[25] + p[12]*p[14]*p[16] - p[12]*p[14]*p[25] + p[13]*p[14]*p[19] - p[13]*p[14]*p[28])*p[0]; + coeff[164] = 2*(-p[7]*p[8]*p[20] + p[7]*p[8]*p[29] + p[7]*p[9]*p[17] - p[7]*p[9]*p[26] + p[8]*p[8]*p[22] - p[8]*p[8]*p[31] - p[8]*p[10]*p[17] + p[8]*p[10]*p[26] + p[9]*p[9]*p[22] - p[9]*p[9]*p[31] - p[9]*p[10]*p[20] + p[9]*p[10]*p[29] + p[11]*p[12]*p[20] - p[11]*p[12]*p[29] - p[11]*p[13]*p[17] + p[11]*p[13]*p[26] - p[12]*p[12]*p[22] + p[12]*p[12]*p[31] + p[12]*p[14]*p[17] - p[12]*p[14]*p[26] - p[13]*p[13]*p[22] + p[13]*p[13]*p[31] + p[13]*p[14]*p[20] - p[13]*p[14]*p[29])*p[0]; + coeff[165] = 2*(-p[7]*p[9]*p[15] + p[7]*p[9]*p[24] + p[8]*p[10]*p[15] - p[8]*p[10]*p[24] + p[11]*p[13]*p[15] - p[11]*p[13]*p[24] - p[12]*p[14]*p[15] + p[12]*p[14]*p[24])*p[0]; + coeff[166] = 2*(p[7]*p[8]*p[19] - p[7]*p[8]*p[28] - p[7]*p[9]*p[16] + p[7]*p[9]*p[25] + p[8]*p[10]*p[16] - p[8]*p[10]*p[25] + p[9]*p[10]*p[19] - p[9]*p[10]*p[28] - p[11]*p[12]*p[19] + p[11]*p[12]*p[28] + p[11]*p[13]*p[16] - p[11]*p[13]*p[25] - p[12]*p[14]*p[16] + p[12]*p[14]*p[25] - p[13]*p[14]*p[19] + p[13]*p[14]*p[28])*p[0]; + coeff[167] = 2*(p[7]*p[8]*p[20] - p[7]*p[8]*p[29] - p[7]*p[9]*p[17] + p[7]*p[9]*p[26] - p[8]*p[8]*p[22] + p[8]*p[8]*p[31] + p[8]*p[10]*p[17] - p[8]*p[10]*p[26] - p[9]*p[9]*p[22] + p[9]*p[9]*p[31] + p[9]*p[10]*p[20] - p[9]*p[10]*p[29] - p[11]*p[12]*p[20] + p[11]*p[12]*p[29] + p[11]*p[13]*p[17] - p[11]*p[13]*p[26] + p[12]*p[12]*p[22] - p[12]*p[12]*p[31] 
- p[12]*p[14]*p[17] + p[12]*p[14]*p[26] + p[13]*p[13]*p[22] - p[13]*p[13]*p[31] - p[13]*p[14]*p[20] + p[13]*p[14]*p[29])*p[0]; +} + +} // namespace embree diff --git a/thirdparty/embree/kernels/common/point_query.h b/thirdparty/embree/kernels/common/point_query.h new file mode 100644 index 000000000000..27d158ca3ab0 --- /dev/null +++ b/thirdparty/embree/kernels/common/point_query.h @@ -0,0 +1,136 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" + +namespace embree +{ + /* Point query structure for closest point query */ + template + struct RTC_ALIGN(16) PointQueryK + { + /* Default construction does nothing */ + __forceinline PointQueryK() {} + + /* Constructs a ray from origin, direction, and ray segment. Near + * has to be smaller than far */ + __forceinline PointQueryK(const Vec3vf& p, const vfloat& radius = inf, const vfloat& time = zero) + : p(p), time(time), radius(radius) {} + + /* Returns the size of the ray */ + static __forceinline size_t size() { return K; } + + /* Calculates if this is a valid ray that does not cause issues during traversal */ + __forceinline vbool valid() const + { + const vbool vx = (abs(p.x) <= vfloat(FLT_LARGE)); + const vbool vy = (abs(p.y) <= vfloat(FLT_LARGE)); + const vbool vz = (abs(p.z) <= vfloat(FLT_LARGE)); + const vbool vn = radius >= vfloat(0); + const vbool vf = abs(time) < vfloat(inf); + return vx & vy & vz & vn & vf; + } + + __forceinline void get(PointQueryK<1>* ray) const; + __forceinline void get(size_t i, PointQueryK<1>& ray) const; + __forceinline void set(const PointQueryK<1>* ray); + __forceinline void set(size_t i, const PointQueryK<1>& ray); + + Vec3vf p; // location of the query point + vfloat time; // time for motion blur + vfloat radius; // radius for the point query + }; + + /* Specialization for a single point query */ + template<> + struct RTC_ALIGN(16) PointQueryK<1> + { + /* Default construction does nothing */ + __forceinline PointQueryK() {} + + /* Constructs a ray from origin, direction, and ray segment. 
Near + * has to be smaller than far */ + __forceinline PointQueryK(const Vec3fa& p, float radius = inf, float time = zero) + : p(p), time(time), radius(radius) {} + + /* Calculates if this is a valid ray that does not cause issues during traversal */ + __forceinline bool valid() const { + return all(le_mask(abs(Vec3fa(p)), Vec3fa(FLT_LARGE)) & le_mask(Vec3fa(0.f), Vec3fa(radius))) && abs(time) < float(inf); + } + + Vec3f p; + float time; + float radius; + }; + + /* Converts point query packet to single point query */ + template + __forceinline void PointQueryK::get(PointQueryK<1>* query) const + { + for (size_t i = 0; i < K; i++) // FIXME: use SIMD transpose + { + query[i].p.x = p.x[i]; + query[i].p.y = p.y[i]; + query[i].p.z = p.z[i]; + query[i].time = time[i]; + query[i].radius = radius[i]; + } + } + + /* Extracts a single point query out of a point query packet*/ + template + __forceinline void PointQueryK::get(size_t i, PointQueryK<1>& query) const + { + query.p.x = p.x[i]; + query.p.y = p.y[i]; + query.p.z = p.z[i]; + query.radius = radius[i]; + query.time = time[i]; + } + + /* Converts single point query to point query packet */ + template + __forceinline void PointQueryK::set(const PointQueryK<1>* query) + { + for (size_t i = 0; i < K; i++) + { + p.x[i] = query[i].p.x; + p.y[i] = query[i].p.y; + p.z[i] = query[i].p.z; + radius[i] = query[i].radius; + time[i] = query[i].time; + } + } + + /* inserts a single point query into a point query packet element */ + template + __forceinline void PointQueryK::set(size_t i, const PointQueryK<1>& query) + { + p.x[i] = query.p.x; + p.y[i] = query.p.y; + p.z[i] = query.p.z; + radius[i] = query.radius; + time[i] = query.time; + } + + /* Shortcuts */ + typedef PointQueryK<1> PointQuery; + typedef PointQueryK<4> PointQuery4; + typedef PointQueryK<8> PointQuery8; + typedef PointQueryK<16> PointQuery16; + struct PointQueryN; + + /* Outputs point query to stream */ + template + __forceinline embree_ostream operator <<(embree_ostream cout, const PointQueryK& query) + { + cout << "{ " << embree_endl + << " p = " << query.p << embree_endl + << " r = " << query.radius << embree_endl + << " time = " << query.time << embree_endl + << "}"; + return cout; + } +} diff --git a/thirdparty/embree/kernels/common/primref.h b/thirdparty/embree/kernels/common/primref.h new file mode 100644 index 000000000000..3d4f9c0d44d8 --- /dev/null +++ b/thirdparty/embree/kernels/common/primref.h @@ -0,0 +1,138 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" + +namespace embree +{ + /*! A primitive reference stores the bounds of the primitive and its ID. */ + struct __aligned(32) PrimRef + { + __forceinline PrimRef () {} + +#if defined(__AVX__) + __forceinline PrimRef(const PrimRef& v) { + vfloat8::store((float*)this,vfloat8::load((float*)&v)); + } + __forceinline PrimRef& operator=(const PrimRef& v) { + vfloat8::store((float*)this,vfloat8::load((float*)&v)); return *this; + } +#endif + + __forceinline PrimRef (const BBox3fa& bounds, unsigned int geomID, unsigned int primID) + { + lower = Vec3fx(bounds.lower, geomID); + upper = Vec3fx(bounds.upper, primID); + } + + __forceinline PrimRef (const BBox3fa& bounds, size_t id) + { +#if defined(__X86_64__) + lower = Vec3fx(bounds.lower, (unsigned)(id & 0xFFFFFFFF)); + upper = Vec3fx(bounds.upper, (unsigned)((id >> 32) & 0xFFFFFFFF)); +#else + lower = Vec3fx(bounds.lower, (unsigned)id); + upper = Vec3fx(bounds.upper, (unsigned)0); +#endif + } + + /*! 
calculates twice the center of the primitive */ + __forceinline const Vec3fa center2() const { + return lower+upper; + } + + /*! return the bounding box of the primitive */ + __forceinline const BBox3fa bounds() const { + return BBox3fa(lower,upper); + } + + /*! size for bin heuristic is 1 */ + __forceinline unsigned size() const { + return 1; + } + + /*! returns bounds and centroid used for binning */ + __forceinline void binBoundsAndCenter(BBox3fa& bounds_o, Vec3fa& center_o) const + { + bounds_o = bounds(); + center_o = embree::center2(bounds_o); + } + + __forceinline unsigned& geomIDref() { // FIXME: remove !!!!!!! + return lower.u; + } + __forceinline unsigned& primIDref() { // FIXME: remove !!!!!!! + return upper.u; + } + + /*! returns the geometry ID */ + __forceinline unsigned geomID() const { + return lower.a; + } + + /*! returns the primitive ID */ + __forceinline unsigned primID() const { + return upper.a; + } + + /*! returns an size_t sized ID */ + __forceinline size_t ID() const { +#if defined(__X86_64__) + return size_t(lower.u) + (size_t(upper.u) << 32); +#else + return size_t(lower.u); +#endif + } + + /*! special function for operator< */ + __forceinline uint64_t ID64() const { + return (((uint64_t)primID()) << 32) + (uint64_t)geomID(); + } + + /*! allows sorting the primrefs by ID */ + friend __forceinline bool operator<(const PrimRef& p0, const PrimRef& p1) { + return p0.ID64() < p1.ID64(); + } + + /*! Outputs primitive reference to a stream. */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const PrimRef& ref) { + return cout << "{ lower = " << ref.lower << ", upper = " << ref.upper << ", geomID = " << ref.geomID() << ", primID = " << ref.primID() << " }"; + } + + public: + Vec3fx lower; //!< lower bounds and geomID + Vec3fx upper; //!< upper bounds and primID + }; + + /*! fast exchange for PrimRefs */ + __forceinline void xchg(PrimRef& a, PrimRef& b) + { +#if defined(__AVX__) + const vfloat8 aa = vfloat8::load((float*)&a); + const vfloat8 bb = vfloat8::load((float*)&b); + vfloat8::store((float*)&a,bb); + vfloat8::store((float*)&b,aa); +#else + std::swap(a,b); +#endif + } + + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + /************************************************************************************/ + + struct SubGridBuildData { + unsigned short sx,sy; + unsigned int primID; + + __forceinline SubGridBuildData() {}; + __forceinline SubGridBuildData(const unsigned int sx, const unsigned int sy, const unsigned int primID) : sx(sx), sy(sy), primID(primID) {}; + + __forceinline size_t x() const { return (size_t)sx & 0x7fff; } + __forceinline size_t y() const { return (size_t)sy & 0x7fff; } + + }; +} diff --git a/thirdparty/embree/kernels/common/primref_mb.h b/thirdparty/embree/kernels/common/primref_mb.h new file mode 100644 index 000000000000..97a3fbe4e6b7 --- /dev/null +++ b/thirdparty/embree/kernels/common/primref_mb.h @@ -0,0 +1,262 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" + +#define MBLUR_BIN_LBBOX 1 + +namespace embree +{ +#if MBLUR_BIN_LBBOX + + /*! A primitive reference stores the bounds of the primitive and its ID. 
*/ + struct PrimRefMB + { + typedef LBBox3fa BBox; + + __forceinline PrimRefMB () {} + + __forceinline PrimRefMB (const LBBox3fa& lbounds_i, unsigned int activeTimeSegments, BBox1f time_range, unsigned int totalTimeSegments, unsigned int geomID, unsigned int primID) + : lbounds((LBBox3fx)lbounds_i), time_range(time_range) + { + assert(activeTimeSegments > 0); + lbounds.bounds0.lower.a = geomID; + lbounds.bounds0.upper.a = primID; + lbounds.bounds1.lower.a = activeTimeSegments; + lbounds.bounds1.upper.a = totalTimeSegments; + } + + __forceinline PrimRefMB (EmptyTy empty, const LBBox3fa& lbounds_i, unsigned int activeTimeSegments, BBox1f time_range, unsigned int totalTimeSegments, size_t id) + : lbounds((LBBox3fx)lbounds_i), time_range(time_range) + { + assert(activeTimeSegments > 0); +#if defined(__X86_64__) + lbounds.bounds0.lower.a = id & 0xFFFFFFFF; + lbounds.bounds0.upper.a = (id >> 32) & 0xFFFFFFFF; +#else + lbounds.bounds0.lower.a = id; + lbounds.bounds0.upper.a = 0; +#endif + lbounds.bounds1.lower.a = activeTimeSegments; + lbounds.bounds1.upper.a = totalTimeSegments; + } + + __forceinline PrimRefMB (const LBBox3fa& lbounds_i, unsigned int activeTimeSegments, BBox1f time_range, unsigned int totalTimeSegments, size_t id) + : lbounds((LBBox3fx)lbounds_i), time_range(time_range) + { + assert(activeTimeSegments > 0); +#if defined(__X86_64__) + lbounds.bounds0.lower.u = id & 0xFFFFFFFF; + lbounds.bounds0.upper.u = (id >> 32) & 0xFFFFFFFF; +#else + lbounds.bounds0.lower.u = id; + lbounds.bounds0.upper.u = 0; +#endif + lbounds.bounds1.lower.a = activeTimeSegments; + lbounds.bounds1.upper.a = totalTimeSegments; + } + + /*! returns bounds for binning */ + __forceinline LBBox3fa bounds() const { + return lbounds; + } + + /*! returns the number of time segments of this primref */ + __forceinline unsigned size() const { + return lbounds.bounds1.lower.a; + } + + __forceinline unsigned totalTimeSegments() const { + return lbounds.bounds1.upper.a; + } + + /* calculate overlapping time segment range */ + __forceinline range timeSegmentRange(const BBox1f& range) const { + return getTimeSegmentRange(range,time_range,float(totalTimeSegments())); + } + + /* returns time that corresponds to time step */ + __forceinline float timeStep(const int i) const { + assert(i>=0 && i<=(int)totalTimeSegments()); + return time_range.lower + time_range.size()*float(i)/float(totalTimeSegments()); + } + + /*! checks if time range overlaps */ + __forceinline bool time_range_overlap(const BBox1f& range) const + { + if (0.9999f*time_range.upper <= range.lower) return false; + if (1.0001f*time_range.lower >= range.upper) return false; + return true; + } + + /*! returns center for binning */ + __forceinline Vec3fa binCenter() const { + return center2(lbounds.interpolate(0.5f)); + } + + /*! returns bounds and centroid used for binning */ + __forceinline void binBoundsAndCenter(LBBox3fa& bounds_o, Vec3fa& center_o) const + { + bounds_o = bounds(); + center_o = binCenter(); + } + + /*! returns the geometry ID */ + __forceinline unsigned geomID() const { + return lbounds.bounds0.lower.a; + } + + /*! returns the primitive ID */ + __forceinline unsigned primID() const { + return lbounds.bounds0.upper.a; + } + + /*! returns an size_t sized ID */ + __forceinline size_t ID() const { +#if defined(__X86_64__) + return size_t(lbounds.bounds0.lower.u) + (size_t(lbounds.bounds0.upper.u) << 32); +#else + return size_t(lbounds.bounds0.lower.u); +#endif + } + + /*! 
special function for operator< */ + __forceinline uint64_t ID64() const { + return (((uint64_t)primID()) << 32) + (uint64_t)geomID(); + } + + /*! allows sorting the primrefs by ID */ + friend __forceinline bool operator<(const PrimRefMB& p0, const PrimRefMB& p1) { + return p0.ID64() < p1.ID64(); + } + + /*! Outputs primitive reference to a stream. */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const PrimRefMB& ref) { + return cout << "{ time_range = " << ref.time_range << ", bounds = " << ref.bounds() << ", geomID = " << ref.geomID() << ", primID = " << ref.primID() << ", active_segments = " << ref.size() << ", total_segments = " << ref.totalTimeSegments() << " }"; + } + + public: + LBBox3fx lbounds; + BBox1f time_range; // entire geometry time range + }; + +#else + + /*! A primitive reference stores the bounds of the primitive and its ID. */ + struct __aligned(16) PrimRefMB + { + typedef BBox3fa BBox; + + __forceinline PrimRefMB () {} + + __forceinline PrimRefMB (const LBBox3fa& bounds, unsigned int activeTimeSegments, BBox1f time_range, unsigned int totalTimeSegments, unsigned int geomID, unsigned int primID) + : bbox(bounds.interpolate(0.5f)), _activeTimeSegments(activeTimeSegments), _totalTimeSegments(totalTimeSegments), time_range(time_range) + { + assert(activeTimeSegments > 0); + bbox.lower.a = geomID; + bbox.upper.a = primID; + } + + __forceinline PrimRefMB (EmptyTy empty, const LBBox3fa& bounds, unsigned int activeTimeSegments, BBox1f time_range, unsigned int totalTimeSegments, size_t id) + : bbox(bounds.interpolate(0.5f)), _activeTimeSegments(activeTimeSegments), _totalTimeSegments(totalTimeSegments), time_range(time_range) + { + assert(activeTimeSegments > 0); +#if defined(__X86_64__) + bbox.lower.u = id & 0xFFFFFFFF; + bbox.upper.u = (id >> 32) & 0xFFFFFFFF; +#else + bbox.lower.u = id; + bbox.upper.u = 0; +#endif + } + + /*! returns bounds for binning */ + __forceinline BBox3fa bounds() const { + return bbox; + } + + /*! returns the number of time segments of this primref */ + __forceinline unsigned int size() const { + return _activeTimeSegments; + } + + __forceinline unsigned int totalTimeSegments() const { + return _totalTimeSegments; + } + + /* calculate overlapping time segment range */ + __forceinline range timeSegmentRange(const BBox1f& range) const { + return getTimeSegmentRange(range,time_range,float(_totalTimeSegments)); + } + + /* returns time that corresponds to time step */ + __forceinline float timeStep(const int i) const { + assert(i>=0 && i<=(int)_totalTimeSegments); + return time_range.lower + time_range.size()*float(i)/float(_totalTimeSegments); + } + + /*! checks if time range overlaps */ + __forceinline bool time_range_overlap(const BBox1f& range) const + { + if (0.9999f*time_range.upper <= range.lower) return false; + if (1.0001f*time_range.lower >= range.upper) return false; + return true; + } + + /*! returns center for binning */ + __forceinline Vec3fa binCenter() const { + return center2(bounds()); + } + + /*! returns bounds and centroid used for binning */ + __forceinline void binBoundsAndCenter(BBox3fa& bounds_o, Vec3fa& center_o) const + { + bounds_o = bounds(); + center_o = center2(bounds()); + } + + /*! returns the geometry ID */ + __forceinline unsigned int geomID() const { + return bbox.lower.a; + } + + /*! returns the primitive ID */ + __forceinline unsigned int primID() const { + return bbox.upper.a; + } + + /*! 
returns an size_t sized ID */ + __forceinline size_t ID() const { +#if defined(__X86_64__) + return size_t(bbox.lower.u) + (size_t(bbox.upper.u) << 32); +#else + return size_t(bbox.lower.u); +#endif + } + + /*! special function for operator< */ + __forceinline uint64_t ID64() const { + return (((uint64_t)primID()) << 32) + (uint64_t)geomID(); + } + + /*! allows sorting the primrefs by ID */ + friend __forceinline bool operator<(const PrimRefMB& p0, const PrimRefMB& p1) { + return p0.ID64() < p1.ID64(); + } + + /*! Outputs primitive reference to a stream. */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const PrimRefMB& ref) { + return cout << "{ bounds = " << ref.bounds() << ", geomID = " << ref.geomID() << ", primID = " << ref.primID() << ", active_segments = " << ref.size() << ", total_segments = " << ref.totalTimeSegments() << " }"; + } + + public: + BBox3fa bbox; // bounds, geomID, primID + unsigned int _activeTimeSegments; + unsigned int _totalTimeSegments; + BBox1f time_range; // entire geometry time range + }; + +#endif +} diff --git a/thirdparty/embree/kernels/common/profile.h b/thirdparty/embree/kernels/common/profile.h new file mode 100644 index 000000000000..a7de36414de2 --- /dev/null +++ b/thirdparty/embree/kernels/common/profile.h @@ -0,0 +1,159 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" + +namespace embree +{ + /*! helper structure for the implementation of the profile functions below */ + struct ProfileTimer + { + static const size_t N = 20; + + ProfileTimer () {} + + ProfileTimer (const size_t numSkip) : i(0), j(0), maxJ(0), numSkip(numSkip), t0(0) + { + for (size_t i=0; i=numSkip) { + dt_min[j] = min(dt_min[j],dt); + dt_avg[j] = dt_avg[j] + dt; + dt_max[j] = max(dt_max[j],dt); + } + j++; + maxJ = max(maxJ,j); + } + + __forceinline void relative (const char* name) + { + const double t1 = getSeconds(); + const double dt = t1-tj; + tj = t1; + assert(names[j] == nullptr || names[j] == name); + names[j] = name; + if (i == 0) dt_fst[j] = dt; + if (i>=numSkip) { + dt_min[j] = min(dt_min[j],dt); + dt_avg[j] = dt_avg[j] + dt; + dt_max[j] = max(dt_max[j],dt); + } + j++; + maxJ = max(maxJ,j); + } + + void print(size_t numElements) + { + for (size_t k=0; k + void profile(const size_t numSkip, const size_t numIter, const size_t numElements, const Closure& closure) + { + ProfileTimer timer(numSkip); + + for (size_t i=0; i + void profile(ProfileTimer& timer, const size_t numSkip, const size_t numIter, const size_t numElements, const Closure& closure) + { + timer = ProfileTimer(numSkip); + + for (size_t i=0; i + struct RayK + { + /* Default construction does nothing */ + __forceinline RayK() {} + + /* Constructs a ray from origin, direction, and ray segment. 
Near + * has to be smaller than far */ + __forceinline RayK(const Vec3vf& org, const Vec3vf& dir, + const vfloat& tnear = zero, const vfloat& tfar = inf, + const vfloat& time = zero, const vint& mask = -1, const vint& id = 0, const vint& flags = 0) + : org(org), dir(dir), _tnear(tnear), tfar(tfar), _time(time), mask(mask), id(id), flags(flags) {} + + /* Returns the size of the ray */ + static __forceinline size_t size() { return K; } + + /* Calculates if this is a valid ray that does not cause issues during traversal */ + __forceinline vbool valid() const + { + const vbool vx = (abs(org.x) <= vfloat(FLT_LARGE)) & (abs(dir.x) <= vfloat(FLT_LARGE)); + const vbool vy = (abs(org.y) <= vfloat(FLT_LARGE)) & (abs(dir.y) <= vfloat(FLT_LARGE)); + const vbool vz = (abs(org.z) <= vfloat(FLT_LARGE)) & (abs(dir.z) <= vfloat(FLT_LARGE)); + const vbool vn = abs(tnear()) <= vfloat(inf); + const vbool vf = abs(tfar) <= vfloat(inf); + return vx & vy & vz & vn & vf; + } + + __forceinline void get(RayK<1>* ray) const; + __forceinline void get(size_t i, RayK<1>& ray) const; + __forceinline void set(const RayK<1>* ray); + __forceinline void set(size_t i, const RayK<1>& ray); + + __forceinline void copy(size_t dest, size_t source); + + __forceinline vint octant() const + { + return select(dir.x < 0.0f, vint(1), vint(zero)) | + select(dir.y < 0.0f, vint(2), vint(zero)) | + select(dir.z < 0.0f, vint(4), vint(zero)); + } + + /* Ray data */ + Vec3vf org; // ray origin + vfloat _tnear; // start of ray segment + Vec3vf dir; // ray direction + vfloat _time; // time of this ray for motion blur + vfloat tfar; // end of ray segment + vint mask; // used to mask out objects during traversal + vint id; + vint flags; + + __forceinline vfloat& tnear() { return _tnear; } + __forceinline vfloat& time() { return _time; } + __forceinline const vfloat& tnear() const { return _tnear; } + __forceinline const vfloat& time() const { return _time; } + }; + + /* Ray+hit structure for K rays */ + template + struct RayHitK : RayK + { + using RayK::org; + using RayK::_tnear; + using RayK::dir; + using RayK::_time; + using RayK::tfar; + using RayK::mask; + using RayK::id; + using RayK::flags; + + using RayK::tnear; + using RayK::time; + + /* Default construction does nothing */ + __forceinline RayHitK() {} + + /* Constructs a ray from origin, direction, and ray segment. 
Near + * has to be smaller than far */ + __forceinline RayHitK(const Vec3vf& org, const Vec3vf& dir, + const vfloat& tnear = zero, const vfloat& tfar = inf, + const vfloat& time = zero, const vint& mask = -1, const vint& id = 0, const vint& flags = 0) + : RayK(org, dir, tnear, tfar, time, mask, id, flags), + geomID(RTC_INVALID_GEOMETRY_ID) + { + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + instID[l] = RTC_INVALID_GEOMETRY_ID; + } + + __forceinline RayHitK(const RayK& ray) + : RayK(ray), + geomID(RTC_INVALID_GEOMETRY_ID) + { + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + instID[l] = RTC_INVALID_GEOMETRY_ID; + } + + __forceinline RayHitK& operator =(const RayK& ray) + { + org = ray.org; + _tnear = ray._tnear; + dir = ray.dir; + _time = ray._time; + tfar = ray.tfar; + mask = ray.mask; + id = ray.id; + flags = ray.flags; + + geomID = RTC_INVALID_GEOMETRY_ID; + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + instID[l] = RTC_INVALID_GEOMETRY_ID; + + return *this; + } + + /* Calculates if the hit is valid */ + __forceinline void verifyHit(const vbool& valid0) const + { + vbool valid = valid0 & geomID != vuint(RTC_INVALID_GEOMETRY_ID); + const vbool vt = (abs(tfar) <= vfloat(FLT_LARGE)) | (tfar == vfloat(neg_inf)); + const vbool vu = (abs(u) <= vfloat(FLT_LARGE)); + const vbool vv = (abs(u) <= vfloat(FLT_LARGE)); + const vbool vnx = abs(Ng.x) <= vfloat(FLT_LARGE); + const vbool vny = abs(Ng.y) <= vfloat(FLT_LARGE); + const vbool vnz = abs(Ng.z) <= vfloat(FLT_LARGE); + if (any(valid & !vt)) throw_RTCError(RTC_ERROR_UNKNOWN,"invalid t"); + if (any(valid & !vu)) throw_RTCError(RTC_ERROR_UNKNOWN,"invalid u"); + if (any(valid & !vv)) throw_RTCError(RTC_ERROR_UNKNOWN,"invalid v"); + if (any(valid & !vnx)) throw_RTCError(RTC_ERROR_UNKNOWN,"invalid Ng.x"); + if (any(valid & !vny)) throw_RTCError(RTC_ERROR_UNKNOWN,"invalid Ng.y"); + if (any(valid & !vnz)) throw_RTCError(RTC_ERROR_UNKNOWN,"invalid Ng.z"); + } + + __forceinline void get(RayHitK<1>* ray) const; + __forceinline void get(size_t i, RayHitK<1>& ray) const; + __forceinline void set(const RayHitK<1>* ray); + __forceinline void set(size_t i, const RayHitK<1>& ray); + + __forceinline void copy(size_t dest, size_t source); + + /* Hit data */ + Vec3vf Ng; // geometry normal + vfloat u; // barycentric u coordinate of hit + vfloat v; // barycentric v coordinate of hit + vuint primID; // primitive ID + vuint geomID; // geometry ID + vuint instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // instance ID + }; + + /* Specialization for a single ray */ + template<> + struct RayK<1> + { + /* Default construction does nothing */ + __forceinline RayK() {} + + /* Constructs a ray from origin, direction, and ray segment. 
Near + * has to be smaller than far */ + __forceinline RayK(const Vec3fa& org, const Vec3fa& dir, float tnear = zero, float tfar = inf, float time = zero, int mask = -1, int id = 0, int flags = 0) + : org(org,tnear), dir(dir,time), tfar(tfar), mask(mask), id(id), flags(flags) {} + + /* Calculates if this is a valid ray that does not cause issues during traversal */ + __forceinline bool valid() const { + return all(le_mask(abs(Vec3fa(org)), Vec3fa(FLT_LARGE)) & le_mask(abs(Vec3fa(dir)), Vec3fa(FLT_LARGE))) && abs(tnear()) <= float(inf) && abs(tfar) <= float(inf); + } + + /* Ray data */ + Vec3ff org; // 3 floats for ray origin, 1 float for tnear + //float tnear; // start of ray segment + Vec3ff dir; // 3 floats for ray direction, 1 float for time + // float time; + float tfar; // end of ray segment + int mask; // used to mask out objects during traversal + int id; // ray ID + int flags; // ray flags + + __forceinline float& tnear() { return org.w; }; + __forceinline const float& tnear() const { return org.w; }; + + __forceinline float& time() { return dir.w; }; + __forceinline const float& time() const { return dir.w; }; + + }; + + template<> + struct RayHitK<1> : RayK<1> + { + /* Default construction does nothing */ + __forceinline RayHitK() {} + + /* Constructs a ray from origin, direction, and ray segment. Near + * has to be smaller than far */ + __forceinline RayHitK(const Vec3fa& org, const Vec3fa& dir, float tnear = zero, float tfar = inf, float time = zero, int mask = -1, int id = 0, int flags = 0) + : RayK<1>(org, dir, tnear, tfar, time, mask, id, flags), + geomID(RTC_INVALID_GEOMETRY_ID) {} + + __forceinline RayHitK(const RayK<1>& ray) + : RayK<1>(ray), + geomID(RTC_INVALID_GEOMETRY_ID) {} + + __forceinline RayHitK<1>& operator =(const RayK<1>& ray) + { + org = ray.org; + dir = ray.dir; + tfar = ray.tfar; + mask = ray.mask; + id = ray.id; + flags = ray.flags; + + geomID = RTC_INVALID_GEOMETRY_ID; + + return *this; + } + + /* Calculates if the hit is valid */ + __forceinline void verifyHit() const + { + if (geomID == RTC_INVALID_GEOMETRY_ID) return; + const bool vt = (abs(tfar) <= FLT_LARGE) || (tfar == float(neg_inf)); + const bool vu = (abs(u) <= FLT_LARGE); + const bool vv = (abs(u) <= FLT_LARGE); + const bool vnx = abs(Ng.x) <= FLT_LARGE; + const bool vny = abs(Ng.y) <= FLT_LARGE; + const bool vnz = abs(Ng.z) <= FLT_LARGE; + if (!vt) throw_RTCError(RTC_ERROR_UNKNOWN, "invalid t"); + if (!vu) throw_RTCError(RTC_ERROR_UNKNOWN, "invalid u"); + if (!vv) throw_RTCError(RTC_ERROR_UNKNOWN, "invalid v"); + if (!vnx) throw_RTCError(RTC_ERROR_UNKNOWN, "invalid Ng.x"); + if (!vny) throw_RTCError(RTC_ERROR_UNKNOWN, "invalid Ng.y"); + if (!vnz) throw_RTCError(RTC_ERROR_UNKNOWN, "invalid Ng.z"); + } + + /* Hit data */ + Vec3f Ng; // not normalized geometry normal + float u; // barycentric u coordinate of hit + float v; // barycentric v coordinate of hit + unsigned int primID; // primitive ID + unsigned int geomID; // geometry ID + unsigned int instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // instance ID + }; + + /* Converts ray packet to single rays */ + template + __forceinline void RayK::get(RayK<1>* ray) const + { + for (size_t i = 0; i < K; i++) // FIXME: use SIMD transpose + { + ray[i].org.x = org.x[i]; ray[i].org.y = org.y[i]; ray[i].org.z = org.z[i]; ray[i].tnear() = tnear()[i]; + ray[i].dir.x = dir.x[i]; ray[i].dir.y = dir.y[i]; ray[i].dir.z = dir.z[i]; ray[i].time() = time()[i]; + ray[i].tfar = tfar[i]; ray[i].mask = mask[i]; ray[i].id = id[i]; ray[i].flags = flags[i]; + } + } + + template + 
__forceinline void RayHitK::get(RayHitK<1>* ray) const + { + // FIXME: use SIMD transpose + for (size_t i = 0; i < K; i++) + get(i, ray[i]); + } + + /* Extracts a single ray out of a ray packet*/ + template + __forceinline void RayK::get(size_t i, RayK<1>& ray) const + { + ray.org.x = org.x[i]; ray.org.y = org.y[i]; ray.org.z = org.z[i]; ray.tnear() = tnear()[i]; + ray.dir.x = dir.x[i]; ray.dir.y = dir.y[i]; ray.dir.z = dir.z[i]; ray.time() = time()[i]; + ray.tfar = tfar[i]; ray.mask = mask[i]; ray.id = id[i]; ray.flags = flags[i]; + } + + template + __forceinline void RayHitK::get(size_t i, RayHitK<1>& ray) const + { + ray.org.x = org.x[i]; ray.org.y = org.y[i]; ray.org.z = org.z[i]; ray.tnear() = tnear()[i]; + ray.dir.x = dir.x[i]; ray.dir.y = dir.y[i]; ray.dir.z = dir.z[i]; ray.tfar = tfar[i]; ray.time() = time()[i]; + ray.mask = mask[i]; ray.id = id[i]; ray.flags = flags[i]; + ray.Ng.x = Ng.x[i]; ray.Ng.y = Ng.y[i]; ray.Ng.z = Ng.z[i]; + ray.u = u[i]; ray.v = v[i]; + ray.primID = primID[i]; ray.geomID = geomID[i]; + + instance_id_stack::copy(instID, ray.instID, i); + } + + /* Converts single rays to ray packet */ + template + __forceinline void RayK::set(const RayK<1>* ray) + { + // FIXME: use SIMD transpose + for (size_t i = 0; i < K; i++) + set(i, ray[i]); + } + + template + __forceinline void RayHitK::set(const RayHitK<1>* ray) + { + // FIXME: use SIMD transpose + for (size_t i = 0; i < K; i++) + set(i, ray[i]); + } + + /* inserts a single ray into a ray packet element */ + template + __forceinline void RayK::set(size_t i, const RayK<1>& ray) + { + org.x[i] = ray.org.x; org.y[i] = ray.org.y; org.z[i] = ray.org.z; tnear()[i] = ray.tnear(); + dir.x[i] = ray.dir.x; dir.y[i] = ray.dir.y; dir.z[i] = ray.dir.z; time()[i] = ray.time(); + tfar[i] = ray.tfar; mask[i] = ray.mask; id[i] = ray.id; flags[i] = ray.flags; + } + + template + __forceinline void RayHitK::set(size_t i, const RayHitK<1>& ray) + { + org.x[i] = ray.org.x; org.y[i] = ray.org.y; org.z[i] = ray.org.z; tnear()[i] = ray.tnear(); + dir.x[i] = ray.dir.x; dir.y[i] = ray.dir.y; dir.z[i] = ray.dir.z; time()[i] = ray.time(); + tfar[i] = ray.tfar; mask[i] = ray.mask; id[i] = ray.id; flags[i] = ray.flags; + Ng.x[i] = ray.Ng.x; Ng.y[i] = ray.Ng.y; Ng.z[i] = ray.Ng.z; + u[i] = ray.u; v[i] = ray.v; + primID[i] = ray.primID; geomID[i] = ray.geomID; + + instance_id_stack::copy(ray.instID, instID, i); + } + + /* copies a ray packet element into another element*/ + template + __forceinline void RayK::copy(size_t dest, size_t source) + { + org.x[dest] = org.x[source]; org.y[dest] = org.y[source]; org.z[dest] = org.z[source]; tnear()[dest] = tnear()[source]; + dir.x[dest] = dir.x[source]; dir.y[dest] = dir.y[source]; dir.z[dest] = dir.z[source]; time()[dest] = time()[source]; + tfar [dest] = tfar[source]; mask[dest] = mask[source]; id[dest] = id[source]; flags[dest] = flags[source]; + } + + template + __forceinline void RayHitK::copy(size_t dest, size_t source) + { + org.x[dest] = org.x[source]; org.y[dest] = org.y[source]; org.z[dest] = org.z[source]; tnear()[dest] = tnear()[source]; + dir.x[dest] = dir.x[source]; dir.y[dest] = dir.y[source]; dir.z[dest] = dir.z[source]; time()[dest] = time()[source]; + tfar [dest] = tfar[source]; mask[dest] = mask[source]; id[dest] = id[source]; flags[dest] = flags[source]; + Ng.x[dest] = Ng.x[source]; Ng.y[dest] = Ng.y[source]; Ng.z[dest] = Ng.z[source]; + u[dest] = u[source]; v[dest] = v[source]; + primID[dest] = primID[source]; geomID[dest] = geomID[source]; + + instance_id_stack::copy(instID, instID, 
source, dest); + } + + /* Shortcuts */ + typedef RayK<1> Ray; + typedef RayK<4> Ray4; + typedef RayK<8> Ray8; + typedef RayK<16> Ray16; + struct RayN; + + typedef RayHitK<1> RayHit; + typedef RayHitK<4> RayHit4; + typedef RayHitK<8> RayHit8; + typedef RayHitK<16> RayHit16; + struct RayHitN; + + template + struct RayTypeHelper; + + template + struct RayTypeHelper + { + typedef RayHitK Ty; + }; + + template + struct RayTypeHelper + { + typedef RayK Ty; + }; + + template + using RayType = typename RayTypeHelper<1, intersect>::Ty; + + template + using RayTypeK = typename RayTypeHelper::Ty; + + /* Outputs ray to stream */ + template + __forceinline embree_ostream operator <<(embree_ostream cout, const RayK& ray) + { + return cout << "{ " << embree_endl + << " org = " << ray.org << embree_endl + << " dir = " << ray.dir << embree_endl + << " near = " << ray.tnear() << embree_endl + << " far = " << ray.tfar << embree_endl + << " time = " << ray.time() << embree_endl + << " mask = " << ray.mask << embree_endl + << " id = " << ray.id << embree_endl + << " flags = " << ray.flags << embree_endl + << "}"; + } + + template + __forceinline embree_ostream operator <<(embree_ostream cout, const RayHitK& ray) + { + cout << "{ " << embree_endl + << " org = " << ray.org << embree_endl + << " dir = " << ray.dir << embree_endl + << " near = " << ray.tnear() << embree_endl + << " far = " << ray.tfar << embree_endl + << " time = " << ray.time() << embree_endl + << " mask = " << ray.mask << embree_endl + << " id = " << ray.id << embree_endl + << " flags = " << ray.flags << embree_endl + << " Ng = " << ray.Ng + << " u = " << ray.u << embree_endl + << " v = " << ray.v << embree_endl + << " primID = " << ray.primID << embree_endl + << " geomID = " << ray.geomID << embree_endl + << " instID ="; + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + { + cout << " " << ray.instID[l]; + } + cout << embree_endl; + return cout << "}"; + } + + struct RayStreamSOA + { + __forceinline RayStreamSOA(void* rays, size_t N) + : ptr((char*)rays), N(N) {} + + /* ray data access functions */ + __forceinline float* org_x(size_t offset = 0) { return (float*)&ptr[0*4*N+offset]; } // x coordinate of ray origin + __forceinline float* org_y(size_t offset = 0) { return (float*)&ptr[1*4*N+offset]; } // y coordinate of ray origin + __forceinline float* org_z(size_t offset = 0) { return (float*)&ptr[2*4*N+offset]; }; // z coordinate of ray origin + __forceinline float* tnear(size_t offset = 0) { return (float*)&ptr[3*4*N+offset]; }; // start of ray segment + + __forceinline float* dir_x(size_t offset = 0) { return (float*)&ptr[4*4*N+offset]; }; // x coordinate of ray direction + __forceinline float* dir_y(size_t offset = 0) { return (float*)&ptr[5*4*N+offset]; }; // y coordinate of ray direction + __forceinline float* dir_z(size_t offset = 0) { return (float*)&ptr[6*4*N+offset]; }; // z coordinate of ray direction + __forceinline float* time (size_t offset = 0) { return (float*)&ptr[7*4*N+offset]; }; // time of this ray for motion blur + + __forceinline float* tfar (size_t offset = 0) { return (float*)&ptr[8*4*N+offset]; }; // end of ray segment (set to hit distance) + __forceinline int* mask (size_t offset = 0) { return (int*)&ptr[9*4*N+offset]; }; // used to mask out objects during traversal (optional) + __forceinline int* id (size_t offset = 0) { return (int*)&ptr[10*4*N+offset]; }; // id + __forceinline int* flags(size_t offset = 0) { return (int*)&ptr[11*4*N+offset]; }; // flags + + /* hit data access functions */ + __forceinline 
float* Ng_x(size_t offset = 0) { return (float*)&ptr[12*4*N+offset]; }; // x coordinate of geometry normal + __forceinline float* Ng_y(size_t offset = 0) { return (float*)&ptr[13*4*N+offset]; }; // y coordinate of geometry normal + __forceinline float* Ng_z(size_t offset = 0) { return (float*)&ptr[14*4*N+offset]; }; // z coordinate of geometry normal + + __forceinline float* u(size_t offset = 0) { return (float*)&ptr[15*4*N+offset]; }; // barycentric u coordinate of hit + __forceinline float* v(size_t offset = 0) { return (float*)&ptr[16*4*N+offset]; }; // barycentric v coordinate of hit + + __forceinline unsigned int* primID(size_t offset = 0) { return (unsigned int*)&ptr[17*4*N+offset]; }; // primitive ID + __forceinline unsigned int* geomID(size_t offset = 0) { return (unsigned int*)&ptr[18*4*N+offset]; }; // geometry ID + __forceinline unsigned int* instID(size_t level, size_t offset = 0) { return (unsigned int*)&ptr[19*4*N+level*4*N+offset]; }; // instance ID + + __forceinline Ray getRayByOffset(size_t offset) + { + Ray ray; + ray.org.x = org_x(offset)[0]; + ray.org.y = org_y(offset)[0]; + ray.org.z = org_z(offset)[0]; + ray.tnear() = tnear(offset)[0]; + ray.dir.x = dir_x(offset)[0]; + ray.dir.y = dir_y(offset)[0]; + ray.dir.z = dir_z(offset)[0]; + ray.time() = time(offset)[0]; + ray.tfar = tfar(offset)[0]; + ray.mask = mask(offset)[0]; + ray.id = id(offset)[0]; + ray.flags = flags(offset)[0]; + return ray; + } + + template + __forceinline RayK getRayByOffset(size_t offset) + { + RayK ray; + ray.org.x = vfloat::loadu(org_x(offset)); + ray.org.y = vfloat::loadu(org_y(offset)); + ray.org.z = vfloat::loadu(org_z(offset)); + ray.tnear = vfloat::loadu(tnear(offset)); + ray.dir.x = vfloat::loadu(dir_x(offset)); + ray.dir.y = vfloat::loadu(dir_y(offset)); + ray.dir.z = vfloat::loadu(dir_z(offset)); + ray.time = vfloat::loadu(time(offset)); + ray.tfar = vfloat::loadu(tfar(offset)); + ray.mask = vint::loadu(mask(offset)); + ray.id = vint::loadu(id(offset)); + ray.flags = vint::loadu(flags(offset)); + return ray; + } + + template + __forceinline RayK getRayByOffset(const vbool& valid, size_t offset) + { + RayK ray; + ray.org.x = vfloat::loadu(valid, org_x(offset)); + ray.org.y = vfloat::loadu(valid, org_y(offset)); + ray.org.z = vfloat::loadu(valid, org_z(offset)); + ray.tnear() = vfloat::loadu(valid, tnear(offset)); + ray.dir.x = vfloat::loadu(valid, dir_x(offset)); + ray.dir.y = vfloat::loadu(valid, dir_y(offset)); + ray.dir.z = vfloat::loadu(valid, dir_z(offset)); + ray.time() = vfloat::loadu(valid, time(offset)); + ray.tfar = vfloat::loadu(valid, tfar(offset)); + +#if !defined(__AVX__) + /* SSE: some ray members must be loaded with scalar instructions to ensure that we don't cause memory faults, + because the SSE masked loads always access the entire vector */ + if (unlikely(!all(valid))) + { + ray.mask = zero; + ray.id = zero; + ray.flags = zero; + + for (size_t k = 0; k < K; k++) + { + if (likely(valid[k])) + { + ray.mask[k] = mask(offset)[k]; + ray.id[k] = id(offset)[k]; + ray.flags[k] = flags(offset)[k]; + } + } + } + else +#endif + { + ray.mask = vint::loadu(valid, mask(offset)); + ray.id = vint::loadu(valid, id(offset)); + ray.flags = vint::loadu(valid, flags(offset)); + } + + return ray; + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, size_t offset, const RayHitK& ray) + { + /* + * valid_i: stores which of the input rays exist (do not access nonexistent rays!) + * valid: stores which of the rays actually hit something. 
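+       * Only lanes set in both masks get their hit fields (tfar, Ng, u/v, primID/geomID/instID) written back to the stream below.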
+ */ + vbool valid = valid_i; + valid &= (ray.geomID != RTC_INVALID_GEOMETRY_ID); + + if (likely(any(valid))) + { + vfloat::storeu(valid, tfar(offset), ray.tfar); + vfloat::storeu(valid, Ng_x(offset), ray.Ng.x); + vfloat::storeu(valid, Ng_y(offset), ray.Ng.y); + vfloat::storeu(valid, Ng_z(offset), ray.Ng.z); + vfloat::storeu(valid, u(offset), ray.u); + vfloat::storeu(valid, v(offset), ray.v); + +#if !defined(__AVX__) + /* SSE: some ray members must be stored with scalar instructions to ensure that we don't cause memory faults, + because the SSE masked stores always access the entire vector */ + if (unlikely(!all(valid_i))) + { + for (size_t k = 0; k < K; k++) + { + if (likely(valid[k])) + { + primID(offset)[k] = ray.primID[k]; + geomID(offset)[k] = ray.geomID[k]; + + instID(0, offset)[k] = ray.instID[0][k]; +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && ray.instID[l-1][k] != RTC_INVALID_GEOMETRY_ID; ++l) + instID(l, offset)[k] = ray.instID[l][k]; +#endif + } + } + } + else +#endif + { + vuint::storeu(valid, primID(offset), ray.primID); + vuint::storeu(valid, geomID(offset), ray.geomID); + + vuint::storeu(valid, instID(0, offset), ray.instID[0]); +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && any(valid & (ray.instID[l-1] != RTC_INVALID_GEOMETRY_ID)); ++l) + vuint::storeu(valid, instID(l, offset), ray.instID[l]); +#endif + } + } + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, size_t offset, const RayK& ray) + { + vbool valid = valid_i; + valid &= (ray.tfar < 0.0f); + + if (likely(any(valid))) + vfloat::storeu(valid, tfar(offset), ray.tfar); + } + + __forceinline size_t getOctantByOffset(size_t offset) + { + const float dx = dir_x(offset)[0]; + const float dy = dir_y(offset)[0]; + const float dz = dir_z(offset)[0]; + const size_t octantID = (dx < 0.0f ? 1 : 0) + (dy < 0.0f ? 2 : 0) + (dz < 0.0f ? 
4 : 0); + return octantID; + } + + __forceinline bool isValidByOffset(size_t offset) + { + const float nnear = tnear(offset)[0]; + const float ffar = tfar(offset)[0]; + return nnear <= ffar; + } + + template + __forceinline RayK getRayByOffset(const vbool& valid, const vint& offset) + { + RayK ray; + +#if defined(__AVX2__) + ray.org.x = vfloat::template gather<1>(valid, org_x(), offset); + ray.org.y = vfloat::template gather<1>(valid, org_y(), offset); + ray.org.z = vfloat::template gather<1>(valid, org_z(), offset); + ray.tnear() = vfloat::template gather<1>(valid, tnear(), offset); + ray.dir.x = vfloat::template gather<1>(valid, dir_x(), offset); + ray.dir.y = vfloat::template gather<1>(valid, dir_y(), offset); + ray.dir.z = vfloat::template gather<1>(valid, dir_z(), offset); + ray.time() = vfloat::template gather<1>(valid, time(), offset); + ray.tfar = vfloat::template gather<1>(valid, tfar(), offset); + ray.mask = vint::template gather<1>(valid, mask(), offset); + ray.id = vint::template gather<1>(valid, id(), offset); + ray.flags = vint::template gather<1>(valid, flags(), offset); +#else + ray.org = zero; + ray.tnear() = zero; + ray.dir = zero; + ray.time() = zero; + ray.tfar = zero; + ray.mask = zero; + ray.id = zero; + ray.flags = zero; + + for (size_t k = 0; k < K; k++) + { + if (likely(valid[k])) + { + const size_t ofs = offset[k]; + + ray.org.x[k] = *org_x(ofs); + ray.org.y[k] = *org_y(ofs); + ray.org.z[k] = *org_z(ofs); + ray.tnear()[k] = *tnear(ofs); + ray.dir.x[k] = *dir_x(ofs); + ray.dir.y[k] = *dir_y(ofs); + ray.dir.z[k] = *dir_z(ofs); + ray.time()[k] = *time(ofs); + ray.tfar[k] = *tfar(ofs); + ray.mask[k] = *mask(ofs); + ray.id[k] = *id(ofs); + ray.flags[k] = *flags(ofs); + } + } +#endif + + return ray; + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, const vint& offset, const RayHitK& ray) + { + vbool valid = valid_i; + valid &= (ray.geomID != RTC_INVALID_GEOMETRY_ID); + + if (likely(any(valid))) + { +#if defined(__AVX512F__) + vfloat::template scatter<1>(valid, tfar(), offset, ray.tfar); + vfloat::template scatter<1>(valid, Ng_x(), offset, ray.Ng.x); + vfloat::template scatter<1>(valid, Ng_y(), offset, ray.Ng.y); + vfloat::template scatter<1>(valid, Ng_z(), offset, ray.Ng.z); + vfloat::template scatter<1>(valid, u(), offset, ray.u); + vfloat::template scatter<1>(valid, v(), offset, ray.v); + vuint::template scatter<1>(valid, primID(), offset, ray.primID); + vuint::template scatter<1>(valid, geomID(), offset, ray.geomID); + + vuint::template scatter<1>(valid, instID(0), offset, ray.instID[0]); +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && any(valid & (ray.instID[l-1] != RTC_INVALID_GEOMETRY_ID)); ++l) + vuint::template scatter<1>(valid, instID(l), offset, ray.instID[l]); +#endif +#else + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + const size_t ofs = offset[k]; + + *tfar(ofs) = ray.tfar[k]; + + *Ng_x(ofs) = ray.Ng.x[k]; + *Ng_y(ofs) = ray.Ng.y[k]; + *Ng_z(ofs) = ray.Ng.z[k]; + *u(ofs) = ray.u[k]; + *v(ofs) = ray.v[k]; + *primID(ofs) = ray.primID[k]; + *geomID(ofs) = ray.geomID[k]; + + *instID(0, ofs) = ray.instID[0][k]; +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && ray.instID[l-1][k] != RTC_INVALID_GEOMETRY_ID; ++l) + *instID(l, ofs) = ray.instID[l][k]; +#endif + } +#endif + } + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, const vint& offset, const RayK& 
ray) + { + vbool valid = valid_i; + valid &= (ray.tfar < 0.0f); + + if (likely(any(valid))) + { +#if defined(__AVX512F__) + vfloat::template scatter<1>(valid, tfar(), offset, ray.tfar); +#else + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + const size_t ofs = offset[k]; + + *tfar(ofs) = ray.tfar[k]; + } +#endif + } + } + + char* __restrict__ ptr; + size_t N; + }; + + template + struct StackRayStreamSOA : public RayStreamSOA + { + __forceinline StackRayStreamSOA(size_t K) + : RayStreamSOA(data, K) { assert(K <= MAX_K); } + + char data[MAX_K / 4 * sizeof(RayHit4)]; + }; + + + struct RayStreamSOP + { + template + __forceinline void init(T& t) + { + org_x = (float*)&t.org.x; + org_y = (float*)&t.org.y; + org_z = (float*)&t.org.z; + tnear = (float*)&t.tnear; + dir_x = (float*)&t.dir.x; + dir_y = (float*)&t.dir.y; + dir_z = (float*)&t.dir.z; + time = (float*)&t.time; + tfar = (float*)&t.tfar; + mask = (unsigned int*)&t.mask; + id = (unsigned int*)&t.id; + flags = (unsigned int*)&t.flags; + + Ng_x = (float*)&t.Ng.x; + Ng_y = (float*)&t.Ng.y; + Ng_z = (float*)&t.Ng.z; + u = (float*)&t.u; + v = (float*)&t.v; + primID = (unsigned int*)&t.primID; + geomID = (unsigned int*)&t.geomID; + + for (unsigned l = 0; l < RTC_MAX_INSTANCE_LEVEL_COUNT; ++l) + instID[l] = (unsigned int*)&t.instID[l]; + } + + __forceinline Ray getRayByOffset(size_t offset) + { + Ray ray; + ray.org.x = *(float* __restrict__)((char*)org_x + offset); + ray.org.y = *(float* __restrict__)((char*)org_y + offset); + ray.org.z = *(float* __restrict__)((char*)org_z + offset); + ray.dir.x = *(float* __restrict__)((char*)dir_x + offset); + ray.dir.y = *(float* __restrict__)((char*)dir_y + offset); + ray.dir.z = *(float* __restrict__)((char*)dir_z + offset); + ray.tfar = *(float* __restrict__)((char*)tfar + offset); + ray.tnear() = tnear ? *(float* __restrict__)((char*)tnear + offset) : 0.0f; + ray.time() = time ? *(float* __restrict__)((char*)time + offset) : 0.0f; + ray.mask = mask ? *(unsigned int* __restrict__)((char*)mask + offset) : -1; + ray.id = id ? *(unsigned int* __restrict__)((char*)id + offset) : -1; + ray.flags = flags ? *(unsigned int* __restrict__)((char*)flags + offset) : -1; + return ray; + } + + template + __forceinline RayK getRayByOffset(const vbool& valid, size_t offset) + { + RayK ray; + ray.org.x = vfloat::loadu(valid, (float* __restrict__)((char*)org_x + offset)); + ray.org.y = vfloat::loadu(valid, (float* __restrict__)((char*)org_y + offset)); + ray.org.z = vfloat::loadu(valid, (float* __restrict__)((char*)org_z + offset)); + ray.dir.x = vfloat::loadu(valid, (float* __restrict__)((char*)dir_x + offset)); + ray.dir.y = vfloat::loadu(valid, (float* __restrict__)((char*)dir_y + offset)); + ray.dir.z = vfloat::loadu(valid, (float* __restrict__)((char*)dir_z + offset)); + ray.tfar = vfloat::loadu(valid, (float* __restrict__)((char*)tfar + offset)); + ray.tnear() = tnear ? vfloat::loadu(valid, (float* __restrict__)((char*)tnear + offset)) : 0.0f; + ray.time() = time ? vfloat::loadu(valid, (float* __restrict__)((char*)time + offset)) : 0.0f; + ray.mask = mask ? vint::loadu(valid, (const void* __restrict__)((char*)mask + offset)) : -1; + ray.id = id ? vint::loadu(valid, (const void* __restrict__)((char*)id + offset)) : -1; + ray.flags = flags ? 
vint::loadu(valid, (const void* __restrict__)((char*)flags + offset)) : -1; + return ray; + } + + template + __forceinline Vec3vf getDirByOffset(const vbool& valid, size_t offset) + { + Vec3vf dir; + dir.x = vfloat::loadu(valid, (float* __restrict__)((char*)dir_x + offset)); + dir.y = vfloat::loadu(valid, (float* __restrict__)((char*)dir_y + offset)); + dir.z = vfloat::loadu(valid, (float* __restrict__)((char*)dir_z + offset)); + return dir; + } + + __forceinline void setHitByOffset(size_t offset, const RayHit& ray) + { + if (ray.geomID != RTC_INVALID_GEOMETRY_ID) + { + *(float* __restrict__)((char*)tfar + offset) = ray.tfar; + + if (likely(Ng_x)) *(float* __restrict__)((char*)Ng_x + offset) = ray.Ng.x; + if (likely(Ng_y)) *(float* __restrict__)((char*)Ng_y + offset) = ray.Ng.y; + if (likely(Ng_z)) *(float* __restrict__)((char*)Ng_z + offset) = ray.Ng.z; + *(float* __restrict__)((char*)u + offset) = ray.u; + *(float* __restrict__)((char*)v + offset) = ray.v; + *(unsigned int* __restrict__)((char*)geomID + offset) = ray.geomID; + *(unsigned int* __restrict__)((char*)primID + offset) = ray.primID; + + if (likely(instID[0])) { + *(unsigned int* __restrict__)((char*)instID[0] + offset) = ray.instID[0]; +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && ray.instID[l-1] != RTC_INVALID_GEOMETRY_ID; ++l) + *(unsigned int* __restrict__)((char*)instID[l] + offset) = ray.instID[l]; +#endif + } + } + } + + __forceinline void setHitByOffset(size_t offset, const Ray& ray) + { + *(float* __restrict__)((char*)tfar + offset) = ray.tfar; + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, size_t offset, const RayHitK& ray) + { + vbool valid = valid_i; + valid &= (ray.geomID != RTC_INVALID_GEOMETRY_ID); + + if (likely(any(valid))) + { + vfloat::storeu(valid, (float* __restrict__)((char*)tfar + offset), ray.tfar); + + if (likely(Ng_x)) vfloat::storeu(valid, (float* __restrict__)((char*)Ng_x + offset), ray.Ng.x); + if (likely(Ng_y)) vfloat::storeu(valid, (float* __restrict__)((char*)Ng_y + offset), ray.Ng.y); + if (likely(Ng_z)) vfloat::storeu(valid, (float* __restrict__)((char*)Ng_z + offset), ray.Ng.z); + vfloat::storeu(valid, (float* __restrict__)((char*)u + offset), ray.u); + vfloat::storeu(valid, (float* __restrict__)((char*)v + offset), ray.v); + vuint::storeu(valid, (unsigned int* __restrict__)((char*)primID + offset), ray.primID); + vuint::storeu(valid, (unsigned int* __restrict__)((char*)geomID + offset), ray.geomID); + + if (likely(instID[0])) { + vuint::storeu(valid, (unsigned int* __restrict__)((char*)instID[0] + offset), ray.instID[0]); +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && any(valid & (ray.instID[l-1] != RTC_INVALID_GEOMETRY_ID)); ++l) + vuint::storeu(valid, (unsigned int* __restrict__)((char*)instID[l] + offset), ray.instID[l]); +#endif + } + } + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, size_t offset, const RayK& ray) + { + vbool valid = valid_i; + valid &= (ray.tfar < 0.0f); + + if (likely(any(valid))) + vfloat::storeu(valid, (float* __restrict__)((char*)tfar + offset), ray.tfar); + } + + __forceinline size_t getOctantByOffset(size_t offset) + { + const float dx = *(float* __restrict__)((char*)dir_x + offset); + const float dy = *(float* __restrict__)((char*)dir_y + offset); + const float dz = *(float* __restrict__)((char*)dir_z + offset); + const size_t octantID = (dx < 0.0f ? 1 : 0) + (dy < 0.0f ? 2 : 0) + (dz < 0.0f ? 
4 : 0); + return octantID; + } + + __forceinline bool isValidByOffset(size_t offset) + { + const float nnear = tnear ? *(float* __restrict__)((char*)tnear + offset) : 0.0f; + const float ffar = *(float* __restrict__)((char*)tfar + offset); + return nnear <= ffar; + } + + template + __forceinline vbool isValidByOffset(const vbool& valid, size_t offset) + { + const vfloat nnear = tnear ? vfloat::loadu(valid, (float* __restrict__)((char*)tnear + offset)) : 0.0f; + const vfloat ffar = vfloat::loadu(valid, (float* __restrict__)((char*)tfar + offset)); + return nnear <= ffar; + } + + template + __forceinline RayK getRayByOffset(const vbool& valid, const vint& offset) + { + RayK ray; + +#if defined(__AVX2__) + ray.org.x = vfloat::template gather<1>(valid, org_x, offset); + ray.org.y = vfloat::template gather<1>(valid, org_y, offset); + ray.org.z = vfloat::template gather<1>(valid, org_z, offset); + ray.dir.x = vfloat::template gather<1>(valid, dir_x, offset); + ray.dir.y = vfloat::template gather<1>(valid, dir_y, offset); + ray.dir.z = vfloat::template gather<1>(valid, dir_z, offset); + ray.tfar = vfloat::template gather<1>(valid, tfar, offset); + ray.tnear() = tnear ? vfloat::template gather<1>(valid, tnear, offset) : vfloat(zero); + ray.time() = time ? vfloat::template gather<1>(valid, time, offset) : vfloat(zero); + ray.mask = mask ? vint::template gather<1>(valid, (int*)mask, offset) : vint(-1); + ray.id = id ? vint::template gather<1>(valid, (int*)id, offset) : vint(-1); + ray.flags = flags ? vint::template gather<1>(valid, (int*)flags, offset) : vint(-1); +#else + ray.org = zero; + ray.tnear() = zero; + ray.dir = zero; + ray.tfar = zero; + ray.time() = zero; + ray.mask = zero; + ray.id = zero; + ray.flags = zero; + + for (size_t k = 0; k < K; k++) + { + if (likely(valid[k])) + { + const size_t ofs = offset[k]; + + ray.org.x[k] = *(float* __restrict__)((char*)org_x + ofs); + ray.org.y[k] = *(float* __restrict__)((char*)org_y + ofs); + ray.org.z[k] = *(float* __restrict__)((char*)org_z + ofs); + ray.dir.x[k] = *(float* __restrict__)((char*)dir_x + ofs); + ray.dir.y[k] = *(float* __restrict__)((char*)dir_y + ofs); + ray.dir.z[k] = *(float* __restrict__)((char*)dir_z + ofs); + ray.tfar[k] = *(float* __restrict__)((char*)tfar + ofs); + ray.tnear()[k] = tnear ? *(float* __restrict__)((char*)tnear + ofs) : 0.0f; + ray.time()[k] = time ? *(float* __restrict__)((char*)time + ofs) : 0.0f; + ray.mask[k] = mask ? *(int* __restrict__)((char*)mask + ofs) : -1; + ray.id[k] = id ? *(int* __restrict__)((char*)id + ofs) : -1; + ray.flags[k] = flags ? 
*(int* __restrict__)((char*)flags + ofs) : -1; + } + } +#endif + + return ray; + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, const vint& offset, const RayHitK& ray) + { + vbool valid = valid_i; + valid &= (ray.geomID != RTC_INVALID_GEOMETRY_ID); + + if (likely(any(valid))) + { +#if defined(__AVX512F__) + vfloat::template scatter<1>(valid, tfar, offset, ray.tfar); + + if (likely(Ng_x)) vfloat::template scatter<1>(valid, Ng_x, offset, ray.Ng.x); + if (likely(Ng_y)) vfloat::template scatter<1>(valid, Ng_y, offset, ray.Ng.y); + if (likely(Ng_z)) vfloat::template scatter<1>(valid, Ng_z, offset, ray.Ng.z); + vfloat::template scatter<1>(valid, u, offset, ray.u); + vfloat::template scatter<1>(valid, v, offset, ray.v); + vuint::template scatter<1>(valid, (unsigned int*)geomID, offset, ray.geomID); + vuint::template scatter<1>(valid, (unsigned int*)primID, offset, ray.primID); + + if (likely(instID[0])) { + vuint::template scatter<1>(valid, (unsigned int*)instID[0], offset, ray.instID[0]); +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && any(valid & (ray.instID[l-1] != RTC_INVALID_GEOMETRY_ID)); ++l) + vuint::template scatter<1>(valid, (unsigned int*)instID[l], offset, ray.instID[l]); +#endif + } +#else + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + const size_t ofs = offset[k]; + + *(float* __restrict__)((char*)tfar + ofs) = ray.tfar[k]; + + if (likely(Ng_x)) *(float* __restrict__)((char*)Ng_x + ofs) = ray.Ng.x[k]; + if (likely(Ng_y)) *(float* __restrict__)((char*)Ng_y + ofs) = ray.Ng.y[k]; + if (likely(Ng_z)) *(float* __restrict__)((char*)Ng_z + ofs) = ray.Ng.z[k]; + *(float* __restrict__)((char*)u + ofs) = ray.u[k]; + *(float* __restrict__)((char*)v + ofs) = ray.v[k]; + *(unsigned int* __restrict__)((char*)primID + ofs) = ray.primID[k]; + *(unsigned int* __restrict__)((char*)geomID + ofs) = ray.geomID[k]; + + if (likely(instID[0])) { + *(unsigned int* __restrict__)((char*)instID[0] + ofs) = ray.instID[0][k]; +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && ray.instID[l-1][k] != RTC_INVALID_GEOMETRY_ID; ++l) + *(unsigned int* __restrict__)((char*)instID[l] + ofs) = ray.instID[l][k]; +#endif + } + } +#endif + } + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, const vint& offset, const RayK& ray) + { + vbool valid = valid_i; + valid &= (ray.tfar < 0.0f); + + if (likely(any(valid))) + { +#if defined(__AVX512F__) + vfloat::template scatter<1>(valid, tfar, offset, ray.tfar); +#else + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + const size_t ofs = offset[k]; + + *(float* __restrict__)((char*)tfar + ofs) = ray.tfar[k]; + } +#endif + } + } + + /* ray data */ + float* __restrict__ org_x; // x coordinate of ray origin + float* __restrict__ org_y; // y coordinate of ray origin + float* __restrict__ org_z; // z coordinate of ray origin + float* __restrict__ tnear; // start of ray segment (optional) + + float* __restrict__ dir_x; // x coordinate of ray direction + float* __restrict__ dir_y; // y coordinate of ray direction + float* __restrict__ dir_z; // z coordinate of ray direction + float* __restrict__ time; // time of this ray for motion blur (optional) + + float* __restrict__ tfar; // end of ray segment (set to hit distance) + unsigned int* __restrict__ mask; // used to mask out objects during traversal (optional) + unsigned int* 
__restrict__ id; // ray ID + unsigned int* __restrict__ flags; // ray flags + + /* hit data */ + float* __restrict__ Ng_x; // x coordinate of geometry normal (optional) + float* __restrict__ Ng_y; // y coordinate of geometry normal (optional) + float* __restrict__ Ng_z; // z coordinate of geometry normal (optional) + + float* __restrict__ u; // barycentric u coordinate of hit + float* __restrict__ v; // barycentric v coordinate of hit + + unsigned int* __restrict__ primID; // primitive ID + unsigned int* __restrict__ geomID; // geometry ID + unsigned int* __restrict__ instID[RTC_MAX_INSTANCE_LEVEL_COUNT]; // instance ID (optional) + }; + + + struct RayStreamAOS + { + __forceinline RayStreamAOS(void* rays) + : ptr((Ray*)rays) {} + + __forceinline Ray& getRayByOffset(size_t offset) + { + return *(Ray*)((char*)ptr + offset); + } + + template + __forceinline RayK getRayByOffset(const vint& offset); + + template + __forceinline RayK getRayByOffset(const vbool& valid, const vint& offset) + { + const vint valid_offset = select(valid, offset, vintx(zero)); + return getRayByOffset(valid_offset); + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, const vint& offset, const RayHitK& ray) + { + vbool valid = valid_i; + valid &= (ray.geomID != RTC_INVALID_GEOMETRY_ID); + + if (likely(any(valid))) + { +#if defined(__AVX512F__) + vfloat::template scatter<1>(valid, &ptr->tfar, offset, ray.tfar); + vfloat::template scatter<1>(valid, &((RayHit*)ptr)->Ng.x, offset, ray.Ng.x); + vfloat::template scatter<1>(valid, &((RayHit*)ptr)->Ng.y, offset, ray.Ng.y); + vfloat::template scatter<1>(valid, &((RayHit*)ptr)->Ng.z, offset, ray.Ng.z); + vfloat::template scatter<1>(valid, &((RayHit*)ptr)->u, offset, ray.u); + vfloat::template scatter<1>(valid, &((RayHit*)ptr)->v, offset, ray.v); + vuint::template scatter<1>(valid, (unsigned int*)&((RayHit*)ptr)->primID, offset, ray.primID); + vuint::template scatter<1>(valid, (unsigned int*)&((RayHit*)ptr)->geomID, offset, ray.geomID); + + vuint::template scatter<1>(valid, (unsigned int*)&((RayHit*)ptr)->instID[0], offset, ray.instID[0]); +#if (RTC_MAX_INSTANCE_LEVEL_COUNT > 1) + for (unsigned l = 1; l < RTC_MAX_INSTANCE_LEVEL_COUNT && any(valid & (ray.instID[l-1] != RTC_INVALID_GEOMETRY_ID)); ++l) + vuint::template scatter<1>(valid, (unsigned int*)&((RayHit*)ptr)->instID[l], offset, ray.instID[l]); +#endif +#else + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + RayHit* __restrict__ ray_k = (RayHit*)((char*)ptr + offset[k]); + ray_k->tfar = ray.tfar[k]; + ray_k->Ng.x = ray.Ng.x[k]; + ray_k->Ng.y = ray.Ng.y[k]; + ray_k->Ng.z = ray.Ng.z[k]; + ray_k->u = ray.u[k]; + ray_k->v = ray.v[k]; + ray_k->primID = ray.primID[k]; + ray_k->geomID = ray.geomID[k]; + + instance_id_stack::copy(ray.instID, ray_k->instID, k); + } +#endif + } + } + + template + __forceinline void setHitByOffset(const vbool& valid_i, const vint& offset, const RayK& ray) + { + vbool valid = valid_i; + valid &= (ray.tfar < 0.0f); + + if (likely(any(valid))) + { +#if defined(__AVX512F__) + vfloat::template scatter<1>(valid, &ptr->tfar, offset, ray.tfar); +#else + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + Ray* __restrict__ ray_k = (Ray*)((char*)ptr + offset[k]); + ray_k->tfar = ray.tfar[k]; + } +#endif + } + } + + Ray* __restrict__ ptr; + }; + + template<> + __forceinline Ray4 RayStreamAOS::getRayByOffset(const vint4& offset) + { + Ray4 ray; + + /* load and transpose: org.x, 
org.y, org.z, tnear */ + const vfloat4 a0 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[0]))->org); + const vfloat4 a1 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[1]))->org); + const vfloat4 a2 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[2]))->org); + const vfloat4 a3 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[3]))->org); + + transpose(a0,a1,a2,a3, ray.org.x, ray.org.y, ray.org.z, ray.tnear()); + + /* load and transpose: dir.x, dir.y, dir.z, time */ + const vfloat4 b0 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[0]))->dir); + const vfloat4 b1 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[1]))->dir); + const vfloat4 b2 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[2]))->dir); + const vfloat4 b3 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[3]))->dir); + + transpose(b0,b1,b2,b3, ray.dir.x, ray.dir.y, ray.dir.z, ray.time()); + + /* load and transpose: tfar, mask, id, flags */ + const vfloat4 c0 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[0]))->tfar); + const vfloat4 c1 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[1]))->tfar); + const vfloat4 c2 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[2]))->tfar); + const vfloat4 c3 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[3]))->tfar); + + vfloat4 maskf, idf, flagsf; + transpose(c0,c1,c2,c3, ray.tfar, maskf, idf, flagsf); + ray.mask = asInt(maskf); + ray.id = asInt(idf); + ray.flags = asInt(flagsf); + + return ray; + } + +#if defined(__AVX__) + template<> + __forceinline Ray8 RayStreamAOS::getRayByOffset(const vint8& offset) + { + Ray8 ray; + + /* load and transpose: org.x, org.y, org.z, tnear, dir.x, dir.y, dir.z, time */ + const vfloat8 ab0 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[0]))->org); + const vfloat8 ab1 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[1]))->org); + const vfloat8 ab2 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[2]))->org); + const vfloat8 ab3 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[3]))->org); + const vfloat8 ab4 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[4]))->org); + const vfloat8 ab5 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[5]))->org); + const vfloat8 ab6 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[6]))->org); + const vfloat8 ab7 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[7]))->org); + + transpose(ab0,ab1,ab2,ab3,ab4,ab5,ab6,ab7, ray.org.x, ray.org.y, ray.org.z, ray.tnear(), ray.dir.x, ray.dir.y, ray.dir.z, ray.time()); + + /* load and transpose: tfar, mask, id, flags */ + const vfloat4 c0 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[0]))->tfar); + const vfloat4 c1 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[1]))->tfar); + const vfloat4 c2 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[2]))->tfar); + const vfloat4 c3 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[3]))->tfar); + const vfloat4 c4 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[4]))->tfar); + const vfloat4 c5 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[5]))->tfar); + const vfloat4 c6 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[6]))->tfar); + const vfloat4 c7 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[7]))->tfar); + + vfloat8 maskf, idf, flagsf; + transpose(c0,c1,c2,c3,c4,c5,c6,c7, ray.tfar, maskf, idf, flagsf); + ray.mask = asInt(maskf); + ray.id = asInt(idf); + ray.flags = asInt(flagsf); + + return ray; + } +#endif + +#if defined(__AVX512F__) + template<> + __forceinline Ray16 RayStreamAOS::getRayByOffset(const vint16& offset) + { + Ray16 ray; + + /* load and transpose: org.x, org.y, org.z, tnear, dir.x, dir.y, dir.z, time */ + const vfloat8 ab0 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 0]))->org); + const 
vfloat8 ab1 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 1]))->org); + const vfloat8 ab2 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 2]))->org); + const vfloat8 ab3 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 3]))->org); + const vfloat8 ab4 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 4]))->org); + const vfloat8 ab5 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 5]))->org); + const vfloat8 ab6 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 6]))->org); + const vfloat8 ab7 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 7]))->org); + const vfloat8 ab8 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 8]))->org); + const vfloat8 ab9 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[ 9]))->org); + const vfloat8 ab10 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[10]))->org); + const vfloat8 ab11 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[11]))->org); + const vfloat8 ab12 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[12]))->org); + const vfloat8 ab13 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[13]))->org); + const vfloat8 ab14 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[14]))->org); + const vfloat8 ab15 = vfloat8::loadu(&((Ray*)((char*)ptr + offset[15]))->org); + + transpose(ab0,ab1,ab2,ab3,ab4,ab5,ab6,ab7,ab8,ab9,ab10,ab11,ab12,ab13,ab14,ab15, + ray.org.x, ray.org.y, ray.org.z, ray.tnear(), ray.dir.x, ray.dir.y, ray.dir.z, ray.time()); + + /* load and transpose: tfar, mask, id, flags */ + const vfloat4 c0 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 0]))->tfar); + const vfloat4 c1 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 1]))->tfar); + const vfloat4 c2 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 2]))->tfar); + const vfloat4 c3 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 3]))->tfar); + const vfloat4 c4 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 4]))->tfar); + const vfloat4 c5 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 5]))->tfar); + const vfloat4 c6 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 6]))->tfar); + const vfloat4 c7 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 7]))->tfar); + const vfloat4 c8 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 8]))->tfar); + const vfloat4 c9 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[ 9]))->tfar); + const vfloat4 c10 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[10]))->tfar); + const vfloat4 c11 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[11]))->tfar); + const vfloat4 c12 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[12]))->tfar); + const vfloat4 c13 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[13]))->tfar); + const vfloat4 c14 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[14]))->tfar); + const vfloat4 c15 = vfloat4::loadu(&((Ray*)((char*)ptr + offset[15]))->tfar); + + vfloat16 maskf, idf, flagsf; + transpose(c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13,c14,c15, + ray.tfar, maskf, idf, flagsf); + ray.mask = asInt(maskf); + ray.id = asInt(idf); + ray.flags = asInt(flagsf); + + return ray; + } +#endif + + + struct RayStreamAOP + { + __forceinline RayStreamAOP(void* rays) + : ptr((Ray**)rays) {} + + __forceinline Ray& getRayByIndex(size_t index) + { + return *ptr[index]; + } + + template + __forceinline RayK getRayByIndex(const vint& index); + + template + __forceinline RayK getRayByIndex(const vbool& valid, const vint& index) + { + const vint valid_index = select(valid, index, vintx(zero)); + return getRayByIndex(valid_index); + } + + template + __forceinline void setHitByIndex(const vbool& valid_i, const vint& index, const RayHitK& ray) + { + vbool valid = valid_i; + valid &= (ray.geomID != RTC_INVALID_GEOMETRY_ID); + 
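+      // Scatter hit data only for lanes that were active on input and whose ray actually found geometry.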
+ if (likely(any(valid))) + { + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + RayHit* __restrict__ ray_k = (RayHit*)ptr[index[k]]; + + ray_k->tfar = ray.tfar[k]; + ray_k->Ng.x = ray.Ng.x[k]; + ray_k->Ng.y = ray.Ng.y[k]; + ray_k->Ng.z = ray.Ng.z[k]; + ray_k->u = ray.u[k]; + ray_k->v = ray.v[k]; + ray_k->primID = ray.primID[k]; + ray_k->geomID = ray.geomID[k]; + instance_id_stack::copy(ray.instID, ray_k->instID, k); + } + } + } + + template + __forceinline void setHitByIndex(const vbool& valid_i, const vint& index, const RayK& ray) + { + vbool valid = valid_i; + valid &= (ray.tfar < 0.0f); + + if (likely(any(valid))) + { + size_t valid_bits = movemask(valid); + while (valid_bits != 0) + { + const size_t k = bscf(valid_bits); + Ray* __restrict__ ray_k = ptr[index[k]]; + + ray_k->tfar = ray.tfar[k]; + } + } + } + + Ray** __restrict__ ptr; + }; + + template<> + __forceinline Ray4 RayStreamAOP::getRayByIndex(const vint4& index) + { + Ray4 ray; + + /* load and transpose: org.x, org.y, org.z, tnear */ + const vfloat4 a0 = vfloat4::loadu(&ptr[index[0]]->org); + const vfloat4 a1 = vfloat4::loadu(&ptr[index[1]]->org); + const vfloat4 a2 = vfloat4::loadu(&ptr[index[2]]->org); + const vfloat4 a3 = vfloat4::loadu(&ptr[index[3]]->org); + + transpose(a0,a1,a2,a3, ray.org.x, ray.org.y, ray.org.z, ray.tnear()); + + /* load and transpose: dir.x, dir.y, dir.z, time */ + const vfloat4 b0 = vfloat4::loadu(&ptr[index[0]]->dir); + const vfloat4 b1 = vfloat4::loadu(&ptr[index[1]]->dir); + const vfloat4 b2 = vfloat4::loadu(&ptr[index[2]]->dir); + const vfloat4 b3 = vfloat4::loadu(&ptr[index[3]]->dir); + + transpose(b0,b1,b2,b3, ray.dir.x, ray.dir.y, ray.dir.z, ray.time()); + + /* load and transpose: tfar, mask, id, flags */ + const vfloat4 c0 = vfloat4::loadu(&ptr[index[0]]->tfar); + const vfloat4 c1 = vfloat4::loadu(&ptr[index[1]]->tfar); + const vfloat4 c2 = vfloat4::loadu(&ptr[index[2]]->tfar); + const vfloat4 c3 = vfloat4::loadu(&ptr[index[3]]->tfar); + + vfloat4 maskf, idf, flagsf; + transpose(c0,c1,c2,c3, ray.tfar, maskf, idf, flagsf); + ray.mask = asInt(maskf); + ray.id = asInt(idf); + ray.flags = asInt(flagsf); + + return ray; + } + +#if defined(__AVX__) + template<> + __forceinline Ray8 RayStreamAOP::getRayByIndex(const vint8& index) + { + Ray8 ray; + + /* load and transpose: org.x, org.y, org.z, tnear, dir.x, dir.y, dir.z, time */ + const vfloat8 ab0 = vfloat8::loadu(&ptr[index[0]]->org); + const vfloat8 ab1 = vfloat8::loadu(&ptr[index[1]]->org); + const vfloat8 ab2 = vfloat8::loadu(&ptr[index[2]]->org); + const vfloat8 ab3 = vfloat8::loadu(&ptr[index[3]]->org); + const vfloat8 ab4 = vfloat8::loadu(&ptr[index[4]]->org); + const vfloat8 ab5 = vfloat8::loadu(&ptr[index[5]]->org); + const vfloat8 ab6 = vfloat8::loadu(&ptr[index[6]]->org); + const vfloat8 ab7 = vfloat8::loadu(&ptr[index[7]]->org); + + transpose(ab0,ab1,ab2,ab3,ab4,ab5,ab6,ab7, ray.org.x, ray.org.y, ray.org.z, ray.tnear(), ray.dir.x, ray.dir.y, ray.dir.z, ray.time()); + + /* load and transpose: tfar, mask, id, flags */ + const vfloat4 c0 = vfloat4::loadu(&ptr[index[0]]->tfar); + const vfloat4 c1 = vfloat4::loadu(&ptr[index[1]]->tfar); + const vfloat4 c2 = vfloat4::loadu(&ptr[index[2]]->tfar); + const vfloat4 c3 = vfloat4::loadu(&ptr[index[3]]->tfar); + const vfloat4 c4 = vfloat4::loadu(&ptr[index[4]]->tfar); + const vfloat4 c5 = vfloat4::loadu(&ptr[index[5]]->tfar); + const vfloat4 c6 = vfloat4::loadu(&ptr[index[6]]->tfar); + const vfloat4 c7 = vfloat4::loadu(&ptr[index[7]]->tfar); + 
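+      // Transpose the eight (tfar, mask, id, flags) tuples into SOA form; mask/id/flags are reinterpreted from float lanes via asInt().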
+ vfloat8 maskf, idf, flagsf; + transpose(c0,c1,c2,c3,c4,c5,c6,c7, ray.tfar, maskf, idf, flagsf); + ray.mask = asInt(maskf); + ray.id = asInt(idf); + ray.flags = asInt(flagsf); + + return ray; + } +#endif + +#if defined(__AVX512F__) + template<> + __forceinline Ray16 RayStreamAOP::getRayByIndex(const vint16& index) + { + Ray16 ray; + + /* load and transpose: org.x, org.y, org.z, tnear, dir.x, dir.y, dir.z, time */ + const vfloat8 ab0 = vfloat8::loadu(&ptr[index[0]]->org); + const vfloat8 ab1 = vfloat8::loadu(&ptr[index[1]]->org); + const vfloat8 ab2 = vfloat8::loadu(&ptr[index[2]]->org); + const vfloat8 ab3 = vfloat8::loadu(&ptr[index[3]]->org); + const vfloat8 ab4 = vfloat8::loadu(&ptr[index[4]]->org); + const vfloat8 ab5 = vfloat8::loadu(&ptr[index[5]]->org); + const vfloat8 ab6 = vfloat8::loadu(&ptr[index[6]]->org); + const vfloat8 ab7 = vfloat8::loadu(&ptr[index[7]]->org); + const vfloat8 ab8 = vfloat8::loadu(&ptr[index[8]]->org); + const vfloat8 ab9 = vfloat8::loadu(&ptr[index[9]]->org); + const vfloat8 ab10 = vfloat8::loadu(&ptr[index[10]]->org); + const vfloat8 ab11 = vfloat8::loadu(&ptr[index[11]]->org); + const vfloat8 ab12 = vfloat8::loadu(&ptr[index[12]]->org); + const vfloat8 ab13 = vfloat8::loadu(&ptr[index[13]]->org); + const vfloat8 ab14 = vfloat8::loadu(&ptr[index[14]]->org); + const vfloat8 ab15 = vfloat8::loadu(&ptr[index[15]]->org); + + transpose(ab0,ab1,ab2,ab3,ab4,ab5,ab6,ab7,ab8,ab9,ab10,ab11,ab12,ab13,ab14,ab15, + ray.org.x, ray.org.y, ray.org.z, ray.tnear(), ray.dir.x, ray.dir.y, ray.dir.z, ray.time()); + + /* load and transpose: tfar, mask, id, flags */ + const vfloat4 c0 = vfloat4::loadu(&ptr[index[0]]->tfar); + const vfloat4 c1 = vfloat4::loadu(&ptr[index[1]]->tfar); + const vfloat4 c2 = vfloat4::loadu(&ptr[index[2]]->tfar); + const vfloat4 c3 = vfloat4::loadu(&ptr[index[3]]->tfar); + const vfloat4 c4 = vfloat4::loadu(&ptr[index[4]]->tfar); + const vfloat4 c5 = vfloat4::loadu(&ptr[index[5]]->tfar); + const vfloat4 c6 = vfloat4::loadu(&ptr[index[6]]->tfar); + const vfloat4 c7 = vfloat4::loadu(&ptr[index[7]]->tfar); + const vfloat4 c8 = vfloat4::loadu(&ptr[index[8]]->tfar); + const vfloat4 c9 = vfloat4::loadu(&ptr[index[9]]->tfar); + const vfloat4 c10 = vfloat4::loadu(&ptr[index[10]]->tfar); + const vfloat4 c11 = vfloat4::loadu(&ptr[index[11]]->tfar); + const vfloat4 c12 = vfloat4::loadu(&ptr[index[12]]->tfar); + const vfloat4 c13 = vfloat4::loadu(&ptr[index[13]]->tfar); + const vfloat4 c14 = vfloat4::loadu(&ptr[index[14]]->tfar); + const vfloat4 c15 = vfloat4::loadu(&ptr[index[15]]->tfar); + + vfloat16 maskf, idf, flagsf; + transpose(c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13,c14,c15, + ray.tfar, maskf, idf, flagsf); + + ray.mask = asInt(maskf); + ray.id = asInt(idf); + ray.flags = asInt(flagsf); + + return ray; + } +#endif +} diff --git a/thirdparty/embree/kernels/common/rtcore.cpp b/thirdparty/embree/kernels/common/rtcore.cpp new file mode 100644 index 000000000000..2cfb466a60c5 --- /dev/null +++ b/thirdparty/embree/kernels/common/rtcore.cpp @@ -0,0 +1,1760 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#define RTC_EXPORT_API + +#include "default.h" +#include "device.h" +#include "scene.h" +#include "context.h" +#include "../../include/embree3/rtcore_ray.h" +using namespace embree; + +RTC_NAMESPACE_BEGIN; + + /* mutex to make API thread safe */ + static MutexSys g_mutex; + + RTC_API RTCDevice rtcNewDevice(const char* config) + { + RTC_CATCH_BEGIN; + RTC_TRACE(rtcNewDevice); + Lock lock(g_mutex); + Device* device = new 
Device(config); + return (RTCDevice) device->refInc(); + RTC_CATCH_END(nullptr); + return (RTCDevice) nullptr; + } + + RTC_API void rtcRetainDevice(RTCDevice hdevice) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcRetainDevice); + RTC_VERIFY_HANDLE(hdevice); + Lock lock(g_mutex); + device->refInc(); + RTC_CATCH_END(nullptr); + } + + RTC_API void rtcReleaseDevice(RTCDevice hdevice) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcReleaseDevice); + RTC_VERIFY_HANDLE(hdevice); + Lock lock(g_mutex); + device->refDec(); + RTC_CATCH_END(nullptr); + } + + RTC_API ssize_t rtcGetDeviceProperty(RTCDevice hdevice, RTCDeviceProperty prop) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetDeviceProperty); + RTC_VERIFY_HANDLE(hdevice); + Lock lock(g_mutex); + return device->getProperty(prop); + RTC_CATCH_END(device); + return 0; + } + + RTC_API void rtcSetDeviceProperty(RTCDevice hdevice, const RTCDeviceProperty prop, ssize_t val) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetDeviceProperty); + const bool internal_prop = (size_t)prop >= 1000000 && (size_t)prop < 1000004; + if (!internal_prop) RTC_VERIFY_HANDLE(hdevice); // allow NULL device for special internal settings + Lock lock(g_mutex); + device->setProperty(prop,val); + RTC_CATCH_END(device); + } + + RTC_API RTCError rtcGetDeviceError(RTCDevice hdevice) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetDeviceError); + if (device == nullptr) return Device::getThreadErrorCode(); + else return device->getDeviceErrorCode(); + RTC_CATCH_END(device); + return RTC_ERROR_UNKNOWN; + } + + RTC_API void rtcSetDeviceErrorFunction(RTCDevice hdevice, RTCErrorFunction error, void* userPtr) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetDeviceErrorFunction); + RTC_VERIFY_HANDLE(hdevice); + device->setErrorFunction(error, userPtr); + RTC_CATCH_END(device); + } + + RTC_API void rtcSetDeviceMemoryMonitorFunction(RTCDevice hdevice, RTCMemoryMonitorFunction memoryMonitor, void* userPtr) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetDeviceMemoryMonitorFunction); + device->setMemoryMonitorFunction(memoryMonitor, userPtr); + RTC_CATCH_END(device); + } + + RTC_API RTCBuffer rtcNewBuffer(RTCDevice hdevice, size_t byteSize) + { + RTC_CATCH_BEGIN; + RTC_TRACE(rtcNewBuffer); + RTC_VERIFY_HANDLE(hdevice); + Buffer* buffer = new Buffer((Device*)hdevice, byteSize); + return (RTCBuffer)buffer->refInc(); + RTC_CATCH_END((Device*)hdevice); + return nullptr; + } + + RTC_API RTCBuffer rtcNewSharedBuffer(RTCDevice hdevice, void* ptr, size_t byteSize) + { + RTC_CATCH_BEGIN; + RTC_TRACE(rtcNewSharedBuffer); + RTC_VERIFY_HANDLE(hdevice); + Buffer* buffer = new Buffer((Device*)hdevice, byteSize, ptr); + return (RTCBuffer)buffer->refInc(); + RTC_CATCH_END((Device*)hdevice); + return nullptr; + } + + RTC_API void* rtcGetBufferData(RTCBuffer hbuffer) + { + Buffer* buffer = (Buffer*)hbuffer; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetBufferData); + RTC_VERIFY_HANDLE(hbuffer); + return buffer->data(); + RTC_CATCH_END2(buffer); + return nullptr; + } + + RTC_API void rtcRetainBuffer(RTCBuffer hbuffer) + { + Buffer* buffer = (Buffer*)hbuffer; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcRetainBuffer); + RTC_VERIFY_HANDLE(hbuffer); + buffer->refInc(); + RTC_CATCH_END2(buffer); + } + + RTC_API void rtcReleaseBuffer(RTCBuffer hbuffer) + { + Buffer* buffer = (Buffer*)hbuffer; + RTC_CATCH_BEGIN; + 
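+    // Drop one user reference; the buffer is freed once its reference count reaches zero.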
RTC_TRACE(rtcReleaseBuffer); + RTC_VERIFY_HANDLE(hbuffer); + buffer->refDec(); + RTC_CATCH_END2(buffer); + } + + RTC_API RTCScene rtcNewScene (RTCDevice hdevice) + { + RTC_CATCH_BEGIN; + RTC_TRACE(rtcNewScene); + RTC_VERIFY_HANDLE(hdevice); + Scene* scene = new Scene((Device*)hdevice); + return (RTCScene) scene->refInc(); + RTC_CATCH_END((Device*)hdevice); + return nullptr; + } + + RTC_API RTCDevice rtcGetSceneDevice(RTCScene hscene) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetSceneDevice); + RTC_VERIFY_HANDLE(hscene); + return (RTCDevice)scene->device->refInc(); // user will own one additional device reference + RTC_CATCH_END2(scene); + return (RTCDevice)nullptr; + } + + RTC_API void rtcSetSceneProgressMonitorFunction(RTCScene hscene, RTCProgressMonitorFunction progress, void* ptr) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetSceneProgressMonitorFunction); + RTC_VERIFY_HANDLE(hscene); + Lock lock(g_mutex); + scene->setProgressMonitorFunction(progress,ptr); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcSetSceneBuildQuality (RTCScene hscene, RTCBuildQuality quality) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetSceneBuildQuality); + RTC_VERIFY_HANDLE(hscene); + if (quality != RTC_BUILD_QUALITY_LOW && + quality != RTC_BUILD_QUALITY_MEDIUM && + quality != RTC_BUILD_QUALITY_HIGH) + throw std::runtime_error("invalid build quality"); + scene->setBuildQuality(quality); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcSetSceneFlags (RTCScene hscene, RTCSceneFlags flags) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetSceneFlags); + RTC_VERIFY_HANDLE(hscene); + scene->setSceneFlags(flags); + RTC_CATCH_END2(scene); + } + + RTC_API RTCSceneFlags rtcGetSceneFlags(RTCScene hscene) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetSceneFlags); + RTC_VERIFY_HANDLE(hscene); + return scene->getSceneFlags(); + RTC_CATCH_END2(scene); + return RTC_SCENE_FLAG_NONE; + } + + RTC_API void rtcCommitScene (RTCScene hscene) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcCommitScene); + RTC_VERIFY_HANDLE(hscene); + scene->commit(false); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcJoinCommitScene (RTCScene hscene) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcJoinCommitScene); + RTC_VERIFY_HANDLE(hscene); + scene->commit(true); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcGetSceneBounds(RTCScene hscene, RTCBounds* bounds_o) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetSceneBounds); + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + BBox3fa bounds = scene->bounds.bounds(); + bounds_o->lower_x = bounds.lower.x; + bounds_o->lower_y = bounds.lower.y; + bounds_o->lower_z = bounds.lower.z; + bounds_o->align0 = 0; + bounds_o->upper_x = bounds.upper.x; + bounds_o->upper_y = bounds.upper.y; + bounds_o->upper_z = bounds.upper.z; + bounds_o->align1 = 0; + RTC_CATCH_END2(scene); + } + + RTC_API void rtcGetSceneLinearBounds(RTCScene hscene, RTCLinearBounds* bounds_o) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetSceneBounds); + RTC_VERIFY_HANDLE(hscene); + if (bounds_o == nullptr) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"invalid destination pointer"); + if (scene->isModified()) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + + bounds_o->bounds0.lower_x = 
scene->bounds.bounds0.lower.x; + bounds_o->bounds0.lower_y = scene->bounds.bounds0.lower.y; + bounds_o->bounds0.lower_z = scene->bounds.bounds0.lower.z; + bounds_o->bounds0.align0 = 0; + bounds_o->bounds0.upper_x = scene->bounds.bounds0.upper.x; + bounds_o->bounds0.upper_y = scene->bounds.bounds0.upper.y; + bounds_o->bounds0.upper_z = scene->bounds.bounds0.upper.z; + bounds_o->bounds0.align1 = 0; + bounds_o->bounds1.lower_x = scene->bounds.bounds1.lower.x; + bounds_o->bounds1.lower_y = scene->bounds.bounds1.lower.y; + bounds_o->bounds1.lower_z = scene->bounds.bounds1.lower.z; + bounds_o->bounds1.align0 = 0; + bounds_o->bounds1.upper_x = scene->bounds.bounds1.upper.x; + bounds_o->bounds1.upper_y = scene->bounds.bounds1.upper.y; + bounds_o->bounds1.upper_z = scene->bounds.bounds1.upper.z; + bounds_o->bounds1.align1 = 0; + RTC_CATCH_END2(scene); + } + + RTC_API void rtcCollide (RTCScene hscene0, RTCScene hscene1, RTCCollideFunc callback, void* userPtr) + { + Scene* scene0 = (Scene*) hscene0; + Scene* scene1 = (Scene*) hscene1; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcCollide); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene0); + RTC_VERIFY_HANDLE(hscene1); + if (scene0->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene got not committed"); + if (scene1->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene got not committed"); + if (scene0->device != scene1->device) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scenes are from different devices"); + auto nUserPrims0 = scene0->getNumPrimitives (Geometry::MTY_USER_GEOMETRY, false); + auto nUserPrims1 = scene1->getNumPrimitives (Geometry::MTY_USER_GEOMETRY, false); + if (scene0->numPrimitives() != nUserPrims0 && scene1->numPrimitives() != nUserPrims1) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scenes must only contain user geometries with a single timestep"); +#endif + scene0->intersectors.collide(scene0,scene1,callback,userPtr); + RTC_CATCH_END(scene0->device); + } + + inline bool pointQuery(Scene* scene, RTCPointQuery* query, RTCPointQueryContext* userContext, RTCPointQueryFunction queryFunc, void* userPtr) + { + bool changed = false; + if (userContext->instStackSize > 0) + { + const AffineSpace3fa transform = AffineSpace3fa_load_unaligned((AffineSpace3fa*)userContext->world2inst[userContext->instStackSize-1]); + + float similarityScale = 0.f; + const bool similtude = similarityTransform(transform, &similarityScale); + assert((similtude && similarityScale > 0) || (!similtude && similarityScale == 0.f)); + + PointQuery query_inst; + query_inst.p = xfmPoint(transform, Vec3fa(query->x, query->y, query->z)); + query_inst.radius = query->radius * similarityScale; + query_inst.time = query->time; + + PointQueryContext context_inst(scene, (PointQuery*)query, + similtude ? 
POINT_QUERY_TYPE_SPHERE : POINT_QUERY_TYPE_AABB, + queryFunc, userContext, similarityScale, userPtr); + changed = scene->intersectors.pointQuery((PointQuery*)&query_inst, &context_inst); + } + else + { + PointQueryContext context(scene, (PointQuery*)query, + POINT_QUERY_TYPE_SPHERE, queryFunc, userContext, 1.f, userPtr); + changed = scene->intersectors.pointQuery((PointQuery*)query, &context); + } + return changed; + } + + RTC_API bool rtcPointQuery(RTCScene hscene, RTCPointQuery* query, RTCPointQueryContext* userContext, RTCPointQueryFunction queryFunc, void* userPtr) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcPointQuery); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + RTC_VERIFY_HANDLE(userContext); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene got not committed"); + if (((size_t)query) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "query not aligned to 16 bytes"); + if (((size_t)userContext) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "context not aligned to 16 bytes"); +#endif + + return pointQuery(scene, query, userContext, queryFunc, userPtr); + RTC_CATCH_END2_FALSE(scene); + } + + RTC_API bool rtcPointQuery4 (const int* valid, RTCScene hscene, RTCPointQuery4* query, struct RTCPointQueryContext* userContext, RTCPointQueryFunction queryFunc, void** userPtrN) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcPointQuery4); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene got not committed"); + if (((size_t)valid) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 16 bytes"); + if (((size_t)query) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "query not aligned to 16 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<4; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(point_query.travs,cnt,cnt,cnt); + + bool changed = false; + PointQuery4* query4 = (PointQuery4*)query; + PointQuery query1; + for (size_t i=0; i<4; i++) { + if (!valid[i]) continue; + query4->get(i,query1); + changed |= pointQuery(scene, (RTCPointQuery*)&query1, userContext, queryFunc, userPtrN?userPtrN[i]:NULL); + query4->set(i,query1); + } + return changed; + RTC_CATCH_END2_FALSE(scene); + } + + RTC_API bool rtcPointQuery8 (const int* valid, RTCScene hscene, RTCPointQuery8* query, struct RTCPointQueryContext* userContext, RTCPointQueryFunction queryFunc, void** userPtrN) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcPointQuery8); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene got not committed"); + if (((size_t)valid) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 16 bytes"); + if (((size_t)query) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "query not aligned to 16 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<4; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(point_query.travs,cnt,cnt,cnt); + + bool changed = false; + PointQuery8* query8 = (PointQuery8*)query; + PointQuery query1; + for (size_t i=0; i<8; i++) { + if (!valid[i]) continue; + query8->get(i,query1); + changed |= pointQuery(scene, (RTCPointQuery*)&query1, userContext, queryFunc, userPtrN?userPtrN[i]:NULL); + query8->set(i,query1); + } + return changed; + RTC_CATCH_END2_FALSE(scene); + } + + RTC_API bool rtcPointQuery16 (const int* valid, RTCScene hscene, RTCPointQuery16* query, struct RTCPointQueryContext* 
userContext, RTCPointQueryFunction queryFunc, void** userPtrN) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcPointQuery16); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene got not committed"); + if (((size_t)valid) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 16 bytes"); + if (((size_t)query) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "query not aligned to 16 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<4; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(point_query.travs,cnt,cnt,cnt); + + bool changed = false; + PointQuery16* query16 = (PointQuery16*)query; + PointQuery query1; + for (size_t i=0; i<16; i++) { + if (!valid[i]) continue; + PointQuery query1; query16->get(i,query1); + changed |= pointQuery(scene, (RTCPointQuery*)&query1, userContext, queryFunc, userPtrN?userPtrN[i]:NULL); + query16->set(i,query1); + } + return changed; + RTC_CATCH_END2_FALSE(scene); + } + + RTC_API void rtcIntersect1 (RTCScene hscene, RTCIntersectContext* user_context, RTCRayHit* rayhit) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersect1); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)rayhit) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 16 bytes"); +#endif + STAT3(normal.travs,1,1,1); + IntersectContext context(scene,user_context); + scene->intersectors.intersect(*rayhit,&context); +#if defined(DEBUG) + ((RayHit*)rayhit)->verifyHit(); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersect4 (const int* valid, RTCScene hscene, RTCIntersectContext* user_context, RTCRayHit4* rayhit) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersect4); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)valid) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 16 bytes"); + if (((size_t)rayhit) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit not aligned to 16 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<4; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(normal.travs,cnt,cnt,cnt); + + IntersectContext context(scene,user_context); +#if !defined(EMBREE_RAY_PACKETS) + Ray4* ray4 = (Ray4*) rayhit; + for (size_t i=0; i<4; i++) { + if (!valid[i]) continue; + RayHit ray1; ray4->get(i,ray1); + scene->intersectors.intersect((RTCRayHit&)ray1,&context); + ray4->set(i,ray1); + } +#else + scene->intersectors.intersect4(valid,*rayhit,&context); +#endif + + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersect8 (const int* valid, RTCScene hscene, RTCIntersectContext* user_context, RTCRayHit8* rayhit) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersect8); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)valid) & 0x1F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 32 bytes"); + if (((size_t)rayhit) & 0x1F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit not aligned to 32 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<8; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(normal.travs,cnt,cnt,cnt); + + IntersectContext context(scene,user_context); +#if 
!defined(EMBREE_RAY_PACKETS) + Ray8* ray8 = (Ray8*) rayhit; + for (size_t i=0; i<8; i++) { + if (!valid[i]) continue; + RayHit ray1; ray8->get(i,ray1); + scene->intersectors.intersect((RTCRayHit&)ray1,&context); + ray8->set(i,ray1); + } +#else + if (likely(scene->intersectors.intersector8)) + scene->intersectors.intersect8(valid,*rayhit,&context); + else + scene->device->rayStreamFilters.intersectSOA(scene,(char*)rayhit,8,1,sizeof(RTCRayHit8),&context); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersect16 (const int* valid, RTCScene hscene, RTCIntersectContext* user_context, RTCRayHit16* rayhit) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersect16); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)valid) & 0x3F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 64 bytes"); + if (((size_t)rayhit) & 0x3F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit not aligned to 64 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<16; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(normal.travs,cnt,cnt,cnt); + + IntersectContext context(scene,user_context); +#if !defined(EMBREE_RAY_PACKETS) + Ray16* ray16 = (Ray16*) rayhit; + for (size_t i=0; i<16; i++) { + if (!valid[i]) continue; + RayHit ray1; ray16->get(i,ray1); + scene->intersectors.intersect((RTCRayHit&)ray1,&context); + ray16->set(i,ray1); + } +#else + if (likely(scene->intersectors.intersector16)) + scene->intersectors.intersect16(valid,*rayhit,&context); + else + scene->device->rayStreamFilters.intersectSOA(scene,(char*)rayhit,16,1,sizeof(RTCRayHit16),&context); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersect1M (RTCScene hscene, RTCIntersectContext* user_context, RTCRayHit* rayhit, unsigned int M, size_t byteStride) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersect1M); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)rayhit ) & 0x03) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 4 bytes"); +#endif + STAT3(normal.travs,M,M,M); + IntersectContext context(scene,user_context); + + /* fast codepath for single rays */ + if (likely(M == 1)) { + if (likely(rayhit->ray.tnear <= rayhit->ray.tfar)) + scene->intersectors.intersect(*rayhit,&context); + } + + /* codepath for streams */ + else { + scene->device->rayStreamFilters.intersectAOS(scene,rayhit,M,byteStride,&context); + } +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersect1M not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersect1Mp (RTCScene hscene, RTCIntersectContext* user_context, RTCRayHit** rn, unsigned int M) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersect1Mp); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)rn) & 0x03) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 4 bytes"); +#endif + STAT3(normal.travs,M,M,M); + IntersectContext context(scene,user_context); + + /* fast codepath for single rays */ + if (likely(M == 1)) { + if (likely(rn[0]->ray.tnear <= rn[0]->ray.tfar)) + scene->intersectors.intersect(*rn[0],&context); + } + + /* codepath for streams */ + else { + 
scene->device->rayStreamFilters.intersectAOP(scene,rn,M,&context); + } +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersect1Mp not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersectNM (RTCScene hscene, RTCIntersectContext* user_context, struct RTCRayHitN* rayhit, unsigned int N, unsigned int M, size_t byteStride) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersectNM); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)rayhit) & 0x03) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 4 bytes"); +#endif + STAT3(normal.travs,N*M,N*M,N*M); + IntersectContext context(scene,user_context); + + /* code path for single ray streams */ + if (likely(N == 1)) + { + /* fast code path for streams of size 1 */ + if (likely(M == 1)) { + if (likely(((RTCRayHit*)rayhit)->ray.tnear <= ((RTCRayHit*)rayhit)->ray.tfar)) + scene->intersectors.intersect(*(RTCRayHit*)rayhit,&context); + } + /* normal codepath for single ray streams */ + else { + scene->device->rayStreamFilters.intersectAOS(scene,(RTCRayHit*)rayhit,M,byteStride,&context); + } + } + /* code path for ray packet streams */ + else { + scene->device->rayStreamFilters.intersectSOA(scene,(char*)rayhit,N,M,byteStride,&context); + } +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersectNM not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcIntersectNp (RTCScene hscene, RTCIntersectContext* user_context, const RTCRayHitNp* rayhit, unsigned int N) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcIntersectNp); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)rayhit->ray.org_x ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.org_x not aligned to 4 bytes"); + if (((size_t)rayhit->ray.org_y ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.org_y not aligned to 4 bytes"); + if (((size_t)rayhit->ray.org_z ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.org_z not aligned to 4 bytes"); + if (((size_t)rayhit->ray.dir_x ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.dir_x not aligned to 4 bytes"); + if (((size_t)rayhit->ray.dir_y ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.dir_y not aligned to 4 bytes"); + if (((size_t)rayhit->ray.dir_z ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.dir_z not aligned to 4 bytes"); + if (((size_t)rayhit->ray.tnear ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.dir_x not aligned to 4 bytes"); + if (((size_t)rayhit->ray.tfar ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.tnear not aligned to 4 bytes"); + if (((size_t)rayhit->ray.time ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.time not aligned to 4 bytes"); + if (((size_t)rayhit->ray.mask ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->ray.mask not aligned to 4 bytes"); + if (((size_t)rayhit->hit.Ng_x ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.Ng_x not aligned to 4 bytes"); + if (((size_t)rayhit->hit.Ng_y ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.Ng_y not aligned to 4 bytes"); + if (((size_t)rayhit->hit.Ng_z ) & 0x03 ) 
throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.Ng_z not aligned to 4 bytes"); + if (((size_t)rayhit->hit.u ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.u not aligned to 4 bytes"); + if (((size_t)rayhit->hit.v ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.v not aligned to 4 bytes"); + if (((size_t)rayhit->hit.geomID) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.geomID not aligned to 4 bytes"); + if (((size_t)rayhit->hit.primID) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.primID not aligned to 4 bytes"); + if (((size_t)rayhit->hit.instID) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "rayhit->hit.instID not aligned to 4 bytes"); +#endif + STAT3(normal.travs,N,N,N); + IntersectContext context(scene,user_context); + scene->device->rayStreamFilters.intersectSOP(scene,rayhit,N,&context); +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersectNp not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccluded1 (RTCScene hscene, RTCIntersectContext* user_context, RTCRay* ray) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccluded1); + STAT3(shadow.travs,1,1,1); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)ray) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 16 bytes"); +#endif + IntersectContext context(scene,user_context); + scene->intersectors.occluded(*ray,&context); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccluded4 (const int* valid, RTCScene hscene, RTCIntersectContext* user_context, RTCRay4* ray) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccluded4); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)valid) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 16 bytes"); + if (((size_t)ray) & 0x0F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 16 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<4; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(shadow.travs,cnt,cnt,cnt); + + IntersectContext context(scene,user_context); +#if !defined(EMBREE_RAY_PACKETS) + RayHit4* ray4 = (RayHit4*) ray; + for (size_t i=0; i<4; i++) { + if (!valid[i]) continue; + RayHit ray1; ray4->get(i,ray1); + scene->intersectors.occluded((RTCRay&)ray1,&context); + ray4->geomID[i] = ray1.geomID; + } +#else + scene->intersectors.occluded4(valid,*ray,&context); +#endif + + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccluded8 (const int* valid, RTCScene hscene, RTCIntersectContext* user_context, RTCRay8* ray) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccluded8); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)valid) & 0x1F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 32 bytes"); + if (((size_t)ray) & 0x1F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 32 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<8; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(shadow.travs,cnt,cnt,cnt); + + IntersectContext context(scene,user_context); +#if !defined(EMBREE_RAY_PACKETS) + RayHit8* ray8 = (RayHit8*) ray; + for (size_t i=0; i<8; i++) { + if (!valid[i]) continue; + RayHit ray1; 
ray8->get(i,ray1); + scene->intersectors.occluded((RTCRay&)ray1,&context); + ray8->set(i,ray1); + } +#else + if (likely(scene->intersectors.intersector8)) + scene->intersectors.occluded8(valid,*ray,&context); + else + scene->device->rayStreamFilters.occludedSOA(scene,(char*)ray,8,1,sizeof(RTCRay8),&context); +#endif + + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccluded16 (const int* valid, RTCScene hscene, RTCIntersectContext* user_context, RTCRay16* ray) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccluded16); + +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)valid) & 0x3F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 64 bytes"); + if (((size_t)ray) & 0x3F) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 64 bytes"); +#endif + STAT(size_t cnt=0; for (size_t i=0; i<16; i++) cnt += ((int*)valid)[i] == -1;); + STAT3(shadow.travs,cnt,cnt,cnt); + + IntersectContext context(scene,user_context); +#if !defined(EMBREE_RAY_PACKETS) + RayHit16* ray16 = (RayHit16*) ray; + for (size_t i=0; i<16; i++) { + if (!valid[i]) continue; + RayHit ray1; ray16->get(i,ray1); + scene->intersectors.occluded((RTCRay&)ray1,&context); + ray16->set(i,ray1); + } +#else + if (likely(scene->intersectors.intersector16)) + scene->intersectors.occluded16(valid,*ray,&context); + else + scene->device->rayStreamFilters.occludedSOA(scene,(char*)ray,16,1,sizeof(RTCRay16),&context); +#endif + + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccluded1M(RTCScene hscene, RTCIntersectContext* user_context, RTCRay* ray, unsigned int M, size_t byteStride) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccluded1M); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)ray) & 0x03) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 4 bytes"); +#endif + STAT3(shadow.travs,M,M,M); + IntersectContext context(scene,user_context); + /* fast codepath for streams of size 1 */ + if (likely(M == 1)) { + if (likely(ray->tnear <= ray->tfar)) + scene->intersectors.occluded (*ray,&context); + } + /* codepath for normal streams */ + else { + scene->device->rayStreamFilters.occludedAOS(scene,ray,M,byteStride,&context); + } +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcOccluded1M not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccluded1Mp(RTCScene hscene, RTCIntersectContext* user_context, RTCRay** ray, unsigned int M) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccluded1Mp); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)ray) & 0x03) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 4 bytes"); +#endif + STAT3(shadow.travs,M,M,M); + IntersectContext context(scene,user_context); + + /* fast codepath for streams of size 1 */ + if (likely(M == 1)) { + if (likely(ray[0]->tnear <= ray[0]->tfar)) + scene->intersectors.occluded (*ray[0],&context); + } + /* codepath for normal streams */ + else { + scene->device->rayStreamFilters.occludedAOP(scene,ray,M,&context); + } +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcOccluded1Mp not supported"); +#endif + RTC_CATCH_END2(scene); + } + + 
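 /* A minimal caller-side sketch of the single-ray occlusion entry point
    implemented above (rtcOccluded1), shown here for illustration only and not
    part of the Embree sources being patched in. It assumes a scene that has
    already been built and committed via rtcNewScene / rtcAttachGeometry /
    rtcCommitScene; the helper name and the ray values are placeholder choices.
    A hit anywhere on the segment [tnear, tfar] is reported by the query
    setting ray.tfar to -inf. */
 static bool exampleSegmentOccluded(RTCScene scene)
 {
   RTCIntersectContext context;
   rtcInitIntersectContext(&context);                      // default flags, empty instancing stack

   RTCRay ray = {};                                        // RTCRay is declared 16-byte aligned, satisfying the debug alignment checks above
   ray.org_x = 0.0f; ray.org_y = 0.0f; ray.org_z = 0.0f;   // segment origin
   ray.dir_x = 0.0f; ray.dir_y = 0.0f; ray.dir_z = 1.0f;   // segment direction
   ray.tnear = 0.0f; ray.tfar  = 100.0f;                   // valid range of the segment
   ray.mask  = (unsigned int)-1; ray.time = 0.0f; ray.flags = 0;

   rtcOccluded1(scene, &context, &ray);                    // dispatches into the implementation above
   return ray.tfar < 0.0f;                                 // tfar becomes -inf when the segment is occluded
 }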
RTC_API void rtcOccludedNM(RTCScene hscene, RTCIntersectContext* user_context, RTCRayN* ray, unsigned int N, unsigned int M, size_t byteStride) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccludedNM); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (byteStride < sizeof(RTCRayHit)) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"byteStride too small"); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)ray) & 0x03) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "ray not aligned to 4 bytes"); +#endif + STAT3(shadow.travs,N*M,N*N,N*N); + IntersectContext context(scene,user_context); + + /* codepath for single rays */ + if (likely(N == 1)) + { + /* fast path for streams of size 1 */ + if (likely(M == 1)) { + if (likely(((RTCRay*)ray)->tnear <= ((RTCRay*)ray)->tfar)) + scene->intersectors.occluded (*(RTCRay*)ray,&context); + } + /* codepath for normal ray streams */ + else { + scene->device->rayStreamFilters.occludedAOS(scene,(RTCRay*)ray,M,byteStride,&context); + } + } + /* code path for ray packet streams */ + else { + scene->device->rayStreamFilters.occludedSOA(scene,(char*)ray,N,M,byteStride,&context); + } +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcOccludedNM not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcOccludedNp(RTCScene hscene, RTCIntersectContext* user_context, const RTCRayNp* ray, unsigned int N) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcOccludedNp); + +#if defined (EMBREE_RAY_PACKETS) +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + if (scene->isModified()) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); + if (((size_t)ray->org_x ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "org_x not aligned to 4 bytes"); + if (((size_t)ray->org_y ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "org_y not aligned to 4 bytes"); + if (((size_t)ray->org_z ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "org_z not aligned to 4 bytes"); + if (((size_t)ray->dir_x ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "dir_x not aligned to 4 bytes"); + if (((size_t)ray->dir_y ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "dir_y not aligned to 4 bytes"); + if (((size_t)ray->dir_z ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "dir_z not aligned to 4 bytes"); + if (((size_t)ray->tnear ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "dir_x not aligned to 4 bytes"); + if (((size_t)ray->tfar ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "tnear not aligned to 4 bytes"); + if (((size_t)ray->time ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "time not aligned to 4 bytes"); + if (((size_t)ray->mask ) & 0x03 ) throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "mask not aligned to 4 bytes"); +#endif + STAT3(shadow.travs,N,N,N); + IntersectContext context(scene,user_context); + scene->device->rayStreamFilters.occludedSOP(scene,ray,N,&context); +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcOccludedNp not supported"); +#endif + RTC_CATCH_END2(scene); + } + + RTC_API void rtcRetainScene (RTCScene hscene) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcRetainScene); + RTC_VERIFY_HANDLE(hscene); + scene->refInc(); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcReleaseScene (RTCScene hscene) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcReleaseScene); + RTC_VERIFY_HANDLE(hscene); + scene->refDec(); + 
RTC_CATCH_END2(scene); + } + + RTC_API void rtcSetGeometryInstancedScene(RTCGeometry hgeometry, RTCScene hscene) + { + Geometry* geometry = (Geometry*) hgeometry; + Ref scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryInstancedScene); + RTC_VERIFY_HANDLE(hgeometry); + RTC_VERIFY_HANDLE(hscene); + geometry->setInstancedScene(scene); + RTC_CATCH_END2(geometry); + } + + AffineSpace3fa loadTransform(RTCFormat format, const float* xfm) + { + AffineSpace3fa space = one; + switch (format) + { + case RTC_FORMAT_FLOAT3X4_ROW_MAJOR: + space = AffineSpace3fa(Vec3fa(xfm[ 0], xfm[ 4], xfm[ 8]), + Vec3fa(xfm[ 1], xfm[ 5], xfm[ 9]), + Vec3fa(xfm[ 2], xfm[ 6], xfm[10]), + Vec3fa(xfm[ 3], xfm[ 7], xfm[11])); + break; + + case RTC_FORMAT_FLOAT3X4_COLUMN_MAJOR: + space = AffineSpace3fa(Vec3fa(xfm[ 0], xfm[ 1], xfm[ 2]), + Vec3fa(xfm[ 3], xfm[ 4], xfm[ 5]), + Vec3fa(xfm[ 6], xfm[ 7], xfm[ 8]), + Vec3fa(xfm[ 9], xfm[10], xfm[11])); + break; + + case RTC_FORMAT_FLOAT4X4_COLUMN_MAJOR: + space = AffineSpace3fa(Vec3fa(xfm[ 0], xfm[ 1], xfm[ 2]), + Vec3fa(xfm[ 4], xfm[ 5], xfm[ 6]), + Vec3fa(xfm[ 8], xfm[ 9], xfm[10]), + Vec3fa(xfm[12], xfm[13], xfm[14])); + break; + + default: + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "invalid matrix format"); + break; + } + return space; + } + + void storeTransform(const AffineSpace3fa& space, RTCFormat format, float* xfm) + { + switch (format) + { + case RTC_FORMAT_FLOAT3X4_ROW_MAJOR: + xfm[ 0] = space.l.vx.x; xfm[ 1] = space.l.vy.x; xfm[ 2] = space.l.vz.x; xfm[ 3] = space.p.x; + xfm[ 4] = space.l.vx.y; xfm[ 5] = space.l.vy.y; xfm[ 6] = space.l.vz.y; xfm[ 7] = space.p.y; + xfm[ 8] = space.l.vx.z; xfm[ 9] = space.l.vy.z; xfm[10] = space.l.vz.z; xfm[11] = space.p.z; + break; + + case RTC_FORMAT_FLOAT3X4_COLUMN_MAJOR: + xfm[ 0] = space.l.vx.x; xfm[ 1] = space.l.vx.y; xfm[ 2] = space.l.vx.z; + xfm[ 3] = space.l.vy.x; xfm[ 4] = space.l.vy.y; xfm[ 5] = space.l.vy.z; + xfm[ 6] = space.l.vz.x; xfm[ 7] = space.l.vz.y; xfm[ 8] = space.l.vz.z; + xfm[ 9] = space.p.x; xfm[10] = space.p.y; xfm[11] = space.p.z; + break; + + case RTC_FORMAT_FLOAT4X4_COLUMN_MAJOR: + xfm[ 0] = space.l.vx.x; xfm[ 1] = space.l.vx.y; xfm[ 2] = space.l.vx.z; xfm[ 3] = 0.f; + xfm[ 4] = space.l.vy.x; xfm[ 5] = space.l.vy.y; xfm[ 6] = space.l.vy.z; xfm[ 7] = 0.f; + xfm[ 8] = space.l.vz.x; xfm[ 9] = space.l.vz.y; xfm[10] = space.l.vz.z; xfm[11] = 0.f; + xfm[12] = space.p.x; xfm[13] = space.p.y; xfm[14] = space.p.z; xfm[15] = 1.f; + break; + + default: + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "invalid matrix format"); + break; + } + } + + RTC_API void rtcSetGeometryTransform(RTCGeometry hgeometry, unsigned int timeStep, RTCFormat format, const void* xfm) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryTransform); + RTC_VERIFY_HANDLE(hgeometry); + RTC_VERIFY_HANDLE(xfm); + const AffineSpace3fa transform = loadTransform(format, (const float*)xfm); + geometry->setTransform(transform, timeStep); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryTransformQuaternion(RTCGeometry hgeometry, unsigned int timeStep, const RTCQuaternionDecomposition* qd) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryTransformQuaternion); + RTC_VERIFY_HANDLE(hgeometry); + RTC_VERIFY_HANDLE(qd); + + AffineSpace3fx transform; + transform.l.vx.x = qd->scale_x; + transform.l.vy.y = qd->scale_y; + transform.l.vz.z = qd->scale_z; + transform.l.vy.x = qd->skew_xy; + transform.l.vz.x = qd->skew_xz; + transform.l.vz.y = 
qd->skew_yz; + transform.l.vx.y = qd->translation_x; + transform.l.vx.z = qd->translation_y; + transform.l.vy.z = qd->translation_z; + transform.p.x = qd->shift_x; + transform.p.y = qd->shift_y; + transform.p.z = qd->shift_z; + + // normalize quaternion + Quaternion3f q(qd->quaternion_r, qd->quaternion_i, qd->quaternion_j, qd->quaternion_k); + q = normalize(q); + transform.l.vx.w = q.i; + transform.l.vy.w = q.j; + transform.l.vz.w = q.k; + transform.p.w = q.r; + + geometry->setQuaternionDecomposition(transform, timeStep); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcGetGeometryTransform(RTCGeometry hgeometry, float time, RTCFormat format, void* xfm) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryTransform); + const AffineSpace3fa transform = geometry->getTransform(time); + storeTransform(transform, format, (float*)xfm); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcFilterIntersection(const struct RTCIntersectFunctionNArguments* const args_i, const struct RTCFilterFunctionNArguments* filter_args) + { + IntersectFunctionNArguments* args = (IntersectFunctionNArguments*) args_i; + args->report(args,filter_args); + } + + RTC_API void rtcFilterOcclusion(const struct RTCOccludedFunctionNArguments* const args_i, const struct RTCFilterFunctionNArguments* filter_args) + { + OccludedFunctionNArguments* args = (OccludedFunctionNArguments*) args_i; + args->report(args,filter_args); + } + + RTC_API RTCGeometry rtcNewGeometry (RTCDevice hdevice, RTCGeometryType type) + { + Device* device = (Device*) hdevice; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcNewGeometry); + RTC_VERIFY_HANDLE(hdevice); + + switch (type) + { + case RTC_GEOMETRY_TYPE_TRIANGLE: + { +#if defined(EMBREE_GEOMETRY_TRIANGLE) + createTriangleMeshTy createTriangleMesh = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createTriangleMesh); + Geometry* geom = createTriangleMesh(device); + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_TRIANGLE is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_QUAD: + { +#if defined(EMBREE_GEOMETRY_QUAD) + createQuadMeshTy createQuadMesh = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createQuadMesh); + Geometry* geom = createQuadMesh(device); + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_QUAD is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_SPHERE_POINT: + case RTC_GEOMETRY_TYPE_DISC_POINT: + case RTC_GEOMETRY_TYPE_ORIENTED_DISC_POINT: + { +#if defined(EMBREE_GEOMETRY_POINT) + createPointsTy createPoints = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_builder_cpu_features, createPoints); + + Geometry *geom; + switch(type) { + case RTC_GEOMETRY_TYPE_SPHERE_POINT: + geom = createPoints(device, Geometry::GTY_SPHERE_POINT); + break; + case RTC_GEOMETRY_TYPE_DISC_POINT: + geom = createPoints(device, Geometry::GTY_DISC_POINT); + break; + case RTC_GEOMETRY_TYPE_ORIENTED_DISC_POINT: + geom = createPoints(device, Geometry::GTY_ORIENTED_DISC_POINT); + break; + default: + geom = nullptr; + break; + } + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_POINT is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_CONE_LINEAR_CURVE: + case RTC_GEOMETRY_TYPE_ROUND_LINEAR_CURVE: + case RTC_GEOMETRY_TYPE_FLAT_LINEAR_CURVE: + + case RTC_GEOMETRY_TYPE_ROUND_BEZIER_CURVE: + case 
RTC_GEOMETRY_TYPE_FLAT_BEZIER_CURVE: + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_BEZIER_CURVE: + + case RTC_GEOMETRY_TYPE_ROUND_BSPLINE_CURVE: + case RTC_GEOMETRY_TYPE_FLAT_BSPLINE_CURVE: + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_BSPLINE_CURVE: + + case RTC_GEOMETRY_TYPE_ROUND_HERMITE_CURVE: + case RTC_GEOMETRY_TYPE_FLAT_HERMITE_CURVE: + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_HERMITE_CURVE: + + case RTC_GEOMETRY_TYPE_ROUND_CATMULL_ROM_CURVE: + case RTC_GEOMETRY_TYPE_FLAT_CATMULL_ROM_CURVE: + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_CATMULL_ROM_CURVE: + { +#if defined(EMBREE_GEOMETRY_CURVE) + createLineSegmentsTy createLineSegments = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createLineSegments); + createCurvesTy createCurves = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createCurves); + + Geometry* geom; + switch (type) { + case RTC_GEOMETRY_TYPE_CONE_LINEAR_CURVE : geom = createLineSegments (device,Geometry::GTY_CONE_LINEAR_CURVE); break; + case RTC_GEOMETRY_TYPE_ROUND_LINEAR_CURVE : geom = createLineSegments (device,Geometry::GTY_ROUND_LINEAR_CURVE); break; + case RTC_GEOMETRY_TYPE_FLAT_LINEAR_CURVE : geom = createLineSegments (device,Geometry::GTY_FLAT_LINEAR_CURVE); break; + //case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_LINEAR_CURVE : geom = createLineSegments (device,Geometry::GTY_ORIENTED_LINEAR_CURVE); break; + + case RTC_GEOMETRY_TYPE_ROUND_BEZIER_CURVE : geom = createCurves(device,Geometry::GTY_ROUND_BEZIER_CURVE); break; + case RTC_GEOMETRY_TYPE_FLAT_BEZIER_CURVE : geom = createCurves(device,Geometry::GTY_FLAT_BEZIER_CURVE); break; + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_BEZIER_CURVE : geom = createCurves(device,Geometry::GTY_ORIENTED_BEZIER_CURVE); break; + + case RTC_GEOMETRY_TYPE_ROUND_BSPLINE_CURVE : geom = createCurves(device,Geometry::GTY_ROUND_BSPLINE_CURVE); break; + case RTC_GEOMETRY_TYPE_FLAT_BSPLINE_CURVE : geom = createCurves(device,Geometry::GTY_FLAT_BSPLINE_CURVE); break; + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_BSPLINE_CURVE : geom = createCurves(device,Geometry::GTY_ORIENTED_BSPLINE_CURVE); break; + + case RTC_GEOMETRY_TYPE_ROUND_HERMITE_CURVE : geom = createCurves(device,Geometry::GTY_ROUND_HERMITE_CURVE); break; + case RTC_GEOMETRY_TYPE_FLAT_HERMITE_CURVE : geom = createCurves(device,Geometry::GTY_FLAT_HERMITE_CURVE); break; + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_HERMITE_CURVE : geom = createCurves(device,Geometry::GTY_ORIENTED_HERMITE_CURVE); break; + + case RTC_GEOMETRY_TYPE_ROUND_CATMULL_ROM_CURVE : geom = createCurves(device,Geometry::GTY_ROUND_CATMULL_ROM_CURVE); break; + case RTC_GEOMETRY_TYPE_FLAT_CATMULL_ROM_CURVE : geom = createCurves(device,Geometry::GTY_FLAT_CATMULL_ROM_CURVE); break; + case RTC_GEOMETRY_TYPE_NORMAL_ORIENTED_CATMULL_ROM_CURVE : geom = createCurves(device,Geometry::GTY_ORIENTED_CATMULL_ROM_CURVE); break; + default: geom = nullptr; break; + } + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_CURVE is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_SUBDIVISION: + { +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + createSubdivMeshTy createSubdivMesh = nullptr; + SELECT_SYMBOL_DEFAULT_AVX(device->enabled_cpu_features,createSubdivMesh); + //SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createSubdivMesh); // FIXME: this does not work for some reason? 
+ Geometry* geom = createSubdivMesh(device); + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_SUBDIVISION is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_USER: + { +#if defined(EMBREE_GEOMETRY_USER) + createUserGeometryTy createUserGeometry = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createUserGeometry); + Geometry* geom = createUserGeometry(device); + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_USER is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_INSTANCE: + { +#if defined(EMBREE_GEOMETRY_INSTANCE) + createInstanceTy createInstance = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createInstance); + Geometry* geom = createInstance(device); + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_INSTANCE is not supported"); +#endif + } + + case RTC_GEOMETRY_TYPE_GRID: + { +#if defined(EMBREE_GEOMETRY_GRID) + createGridMeshTy createGridMesh = nullptr; + SELECT_SYMBOL_DEFAULT_AVX_AVX2_AVX512KNL_AVX512SKX(device->enabled_cpu_features,createGridMesh); + Geometry* geom = createGridMesh(device); + return (RTCGeometry) geom->refInc(); +#else + throw_RTCError(RTC_ERROR_UNKNOWN,"RTC_GEOMETRY_TYPE_GRID is not supported"); +#endif + } + + default: + throw_RTCError(RTC_ERROR_UNKNOWN,"invalid geometry type"); + } + + RTC_CATCH_END(device); + return nullptr; + } + + RTC_API void rtcSetGeometryUserPrimitiveCount(RTCGeometry hgeometry, unsigned int userPrimitiveCount) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryUserPrimitiveCount); + RTC_VERIFY_HANDLE(hgeometry); + + if (unlikely(geometry->getType() != Geometry::GTY_USER_GEOMETRY)) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"operation only allowed for user geometries"); + + geometry->setNumPrimitives(userPrimitiveCount); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryTimeStepCount(RTCGeometry hgeometry, unsigned int timeStepCount) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryTimeStepCount); + RTC_VERIFY_HANDLE(hgeometry); + + if (timeStepCount > RTC_MAX_TIME_STEP_COUNT) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"number of time steps is out of range"); + + geometry->setNumTimeSteps(timeStepCount); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryTimeRange(RTCGeometry hgeometry, float startTime, float endTime) + { + Ref geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryTimeRange); + RTC_VERIFY_HANDLE(hgeometry); + + if (startTime > endTime) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"startTime has to be smaller or equal to the endTime"); + + geometry->setTimeRange(BBox1f(startTime,endTime)); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryVertexAttributeCount(RTCGeometry hgeometry, unsigned int N) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryVertexAttributeCount); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setVertexAttributeCount(N); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryTopologyCount(RTCGeometry hgeometry, unsigned int N) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryTopologyCount); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setTopologyCount(N); + RTC_CATCH_END2(geometry); + } + + RTC_API void 
rtcSetGeometryBuildQuality (RTCGeometry hgeometry, RTCBuildQuality quality) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryBuildQuality); + RTC_VERIFY_HANDLE(hgeometry); + if (quality != RTC_BUILD_QUALITY_LOW && + quality != RTC_BUILD_QUALITY_MEDIUM && + quality != RTC_BUILD_QUALITY_HIGH && + quality != RTC_BUILD_QUALITY_REFIT) + throw std::runtime_error("invalid build quality"); + geometry->setBuildQuality(quality); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryMaxRadiusScale(RTCGeometry hgeometry, float maxRadiusScale) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryMaxRadiusScale); + RTC_VERIFY_HANDLE(hgeometry); +#if RTC_MIN_WIDTH + if (maxRadiusScale < 1.0f) throw_RTCError(RTC_ERROR_INVALID_OPERATION,"maximal radius scale has to be larger or equal to 1"); + geometry->setMaxRadiusScale(maxRadiusScale); +#else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"min-width feature is not enabled"); +#endif + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryMask (RTCGeometry hgeometry, unsigned int mask) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryMask); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setMask(mask); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometrySubdivisionMode (RTCGeometry hgeometry, unsigned topologyID, RTCSubdivisionMode mode) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometrySubdivisionMode); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setSubdivisionMode(topologyID,mode); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryVertexAttributeTopology(RTCGeometry hgeometry, unsigned int vertexAttributeID, unsigned int topologyID) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryVertexAttributeTopology); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setVertexAttributeTopology(vertexAttributeID, topologyID); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryBuffer(RTCGeometry hgeometry, RTCBufferType type, unsigned int slot, RTCFormat format, RTCBuffer hbuffer, size_t byteOffset, size_t byteStride, size_t itemCount) + { + Geometry* geometry = (Geometry*) hgeometry; + Ref buffer = (Buffer*)hbuffer; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryBuffer); + RTC_VERIFY_HANDLE(hgeometry); + RTC_VERIFY_HANDLE(hbuffer); + + if (geometry->device != buffer->device) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"inputs are from different devices"); + + if (itemCount > 0xFFFFFFFFu) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"buffer too large"); + + geometry->setBuffer(type, slot, format, buffer, byteOffset, byteStride, (unsigned int)itemCount); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetSharedGeometryBuffer(RTCGeometry hgeometry, RTCBufferType type, unsigned int slot, RTCFormat format, const void* ptr, size_t byteOffset, size_t byteStride, size_t itemCount) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetSharedGeometryBuffer); + RTC_VERIFY_HANDLE(hgeometry); + + if (itemCount > 0xFFFFFFFFu) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"buffer too large"); + + Ref buffer = new Buffer(geometry->device, itemCount*byteStride, (char*)ptr + byteOffset); + geometry->setBuffer(type, slot, format, buffer, 0, byteStride, (unsigned int)itemCount); + RTC_CATCH_END2(geometry); + } + + RTC_API void* rtcSetNewGeometryBuffer(RTCGeometry hgeometry, RTCBufferType type, 
unsigned int slot, RTCFormat format, size_t byteStride, size_t itemCount) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetNewGeometryBuffer); + RTC_VERIFY_HANDLE(hgeometry); + + if (itemCount > 0xFFFFFFFFu) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"buffer too large"); + + /* vertex buffers need to get overallocated slightly as elements are accessed using SSE loads */ + size_t bytes = itemCount*byteStride; + if (type == RTC_BUFFER_TYPE_VERTEX || type == RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE) + bytes += (16 - (byteStride%16))%16; + + Ref buffer = new Buffer(geometry->device, bytes); + geometry->setBuffer(type, slot, format, buffer, 0, byteStride, (unsigned int)itemCount); + return buffer->data(); + RTC_CATCH_END2(geometry); + return nullptr; + } + + RTC_API void* rtcGetGeometryBufferData(RTCGeometry hgeometry, RTCBufferType type, unsigned int slot) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryBufferData); + RTC_VERIFY_HANDLE(hgeometry); + return geometry->getBuffer(type, slot); + RTC_CATCH_END2(geometry); + return nullptr; + } + + RTC_API void rtcEnableGeometry (RTCGeometry hgeometry) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcEnableGeometry); + RTC_VERIFY_HANDLE(hgeometry); + geometry->enable(); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcUpdateGeometryBuffer (RTCGeometry hgeometry, RTCBufferType type, unsigned int slot) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcUpdateGeometryBuffer); + RTC_VERIFY_HANDLE(hgeometry); + geometry->updateBuffer(type, slot); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcDisableGeometry (RTCGeometry hgeometry) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcDisableGeometry); + RTC_VERIFY_HANDLE(hgeometry); + geometry->disable(); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryTessellationRate (RTCGeometry hgeometry, float tessellationRate) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryTessellationRate); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setTessellationRate(tessellationRate); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryUserData (RTCGeometry hgeometry, void* ptr) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryUserData); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setUserData(ptr); + RTC_CATCH_END2(geometry); + } + + RTC_API void* rtcGetGeometryUserData (RTCGeometry hgeometry) + { + Geometry* geometry = (Geometry*) hgeometry; // no ref counting here! 
+ RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryUserData); + RTC_VERIFY_HANDLE(hgeometry); + return geometry->getUserData(); + RTC_CATCH_END2(geometry); + return nullptr; + } + + RTC_API void rtcSetGeometryBoundsFunction (RTCGeometry hgeometry, RTCBoundsFunction bounds, void* userPtr) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryBoundsFunction); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setBoundsFunction(bounds,userPtr); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryDisplacementFunction (RTCGeometry hgeometry, RTCDisplacementFunctionN displacement) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryDisplacementFunction); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setDisplacementFunction(displacement); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryIntersectFunction (RTCGeometry hgeometry, RTCIntersectFunctionN intersect) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryIntersectFunction); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setIntersectFunctionN(intersect); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryPointQueryFunction(RTCGeometry hgeometry, RTCPointQueryFunction pointQuery) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryPointQueryFunction); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setPointQueryFunction(pointQuery); + RTC_CATCH_END2(geometry); + } + + RTC_API unsigned int rtcGetGeometryFirstHalfEdge(RTCGeometry hgeometry, unsigned int faceID) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryFirstHalfEdge); + return geometry->getFirstHalfEdge(faceID); + RTC_CATCH_END2(geometry); + return -1; + } + + RTC_API unsigned int rtcGetGeometryFace(RTCGeometry hgeometry, unsigned int edgeID) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryFace); + return geometry->getFace(edgeID); + RTC_CATCH_END2(geometry); + return -1; + } + + RTC_API unsigned int rtcGetGeometryNextHalfEdge(RTCGeometry hgeometry, unsigned int edgeID) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryNextHalfEdge); + return geometry->getNextHalfEdge(edgeID); + RTC_CATCH_END2(geometry); + return -1; + } + + RTC_API unsigned int rtcGetGeometryPreviousHalfEdge(RTCGeometry hgeometry, unsigned int edgeID) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryPreviousHalfEdge); + return geometry->getPreviousHalfEdge(edgeID); + RTC_CATCH_END2(geometry); + return -1; + } + + RTC_API unsigned int rtcGetGeometryOppositeHalfEdge(RTCGeometry hgeometry, unsigned int topologyID, unsigned int edgeID) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometryOppositeHalfEdge); + return geometry->getOppositeHalfEdge(topologyID,edgeID); + RTC_CATCH_END2(geometry); + return -1; + } + + RTC_API void rtcSetGeometryOccludedFunction (RTCGeometry hgeometry, RTCOccludedFunctionN occluded) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetOccludedFunctionN); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setOccludedFunctionN(occluded); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryIntersectFilterFunction (RTCGeometry hgeometry, RTCFilterFunctionN filter) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + 
RTC_TRACE(rtcSetGeometryIntersectFilterFunction); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setIntersectionFilterFunctionN(filter); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcSetGeometryOccludedFilterFunction (RTCGeometry hgeometry, RTCFilterFunctionN filter) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcSetGeometryOccludedFilterFunction); + RTC_VERIFY_HANDLE(hgeometry); + geometry->setOcclusionFilterFunctionN(filter); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcInterpolate(const RTCInterpolateArguments* const args) + { + Geometry* geometry = (Geometry*) args->geometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcInterpolate); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(args->geometry); +#endif + geometry->interpolate(args); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcInterpolateN(const RTCInterpolateNArguments* const args) + { + Geometry* geometry = (Geometry*) args->geometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcInterpolateN); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(args->geometry); +#endif + geometry->interpolateN(args); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcCommitGeometry (RTCGeometry hgeometry) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcCommitGeometry); + RTC_VERIFY_HANDLE(hgeometry); + return geometry->commit(); + RTC_CATCH_END2(geometry); + } + + RTC_API unsigned int rtcAttachGeometry (RTCScene hscene, RTCGeometry hgeometry) + { + Scene* scene = (Scene*) hscene; + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcAttachGeometry); + RTC_VERIFY_HANDLE(hscene); + RTC_VERIFY_HANDLE(hgeometry); + if (scene->device != geometry->device) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"inputs are from different devices"); + return scene->bind(RTC_INVALID_GEOMETRY_ID,geometry); + RTC_CATCH_END2(scene); + return -1; + } + + RTC_API void rtcAttachGeometryByID (RTCScene hscene, RTCGeometry hgeometry, unsigned int geomID) + { + Scene* scene = (Scene*) hscene; + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcAttachGeometryByID); + RTC_VERIFY_HANDLE(hscene); + RTC_VERIFY_HANDLE(hgeometry); + RTC_VERIFY_GEOMID(geomID); + if (scene->device != geometry->device) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"inputs are from different devices"); + scene->bind(geomID,geometry); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcDetachGeometry (RTCScene hscene, unsigned int geomID) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcDetachGeometry); + RTC_VERIFY_HANDLE(hscene); + RTC_VERIFY_GEOMID(geomID); + scene->detachGeometry(geomID); + RTC_CATCH_END2(scene); + } + + RTC_API void rtcRetainGeometry (RTCGeometry hgeometry) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcRetainGeometry); + RTC_VERIFY_HANDLE(hgeometry); + geometry->refInc(); + RTC_CATCH_END2(geometry); + } + + RTC_API void rtcReleaseGeometry (RTCGeometry hgeometry) + { + Geometry* geometry = (Geometry*) hgeometry; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcReleaseGeometry); + RTC_VERIFY_HANDLE(hgeometry); + geometry->refDec(); + RTC_CATCH_END2(geometry); + } + + RTC_API RTCGeometry rtcGetGeometry (RTCScene hscene, unsigned int geomID) + { + Scene* scene = (Scene*) hscene; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcGetGeometry); +#if defined(DEBUG) + RTC_VERIFY_HANDLE(hscene); + RTC_VERIFY_GEOMID(geomID); +#endif + return (RTCGeometry) scene->get(geomID); + RTC_CATCH_END2(scene); + return nullptr; + } + +RTC_NAMESPACE_END diff --git 
a/thirdparty/embree/kernels/common/rtcore.h b/thirdparty/embree/kernels/common/rtcore.h new file mode 100644 index 000000000000..6583d12d57cc --- /dev/null +++ b/thirdparty/embree/kernels/common/rtcore.h @@ -0,0 +1,126 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../../include/embree3/rtcore.h" +RTC_NAMESPACE_USE + +namespace embree +{ + /*! decoding of intersection flags */ + __forceinline bool isCoherent (RTCIntersectContextFlags flags) { return (flags & RTC_INTERSECT_CONTEXT_FLAG_COHERENT) == RTC_INTERSECT_CONTEXT_FLAG_COHERENT; } + __forceinline bool isIncoherent(RTCIntersectContextFlags flags) { return (flags & RTC_INTERSECT_CONTEXT_FLAG_COHERENT) == RTC_INTERSECT_CONTEXT_FLAG_INCOHERENT; } + +#if defined(TASKING_TBB) && (TBB_INTERFACE_VERSION_MAJOR >= 8) +# define USE_TASK_ARENA 1 +#else +# define USE_TASK_ARENA 0 +#endif + +#if defined(TASKING_TBB) && (TBB_INTERFACE_VERSION >= 11009) // TBB 2019 Update 9 +# define TASKING_TBB_USE_TASK_ISOLATION 1 +#else +# define TASKING_TBB_USE_TASK_ISOLATION 0 +#endif + +/*! Macros used in the rtcore API implementation */ +#define RTC_CATCH_BEGIN try { + +#define RTC_CATCH_END(device) \ + } catch (std::bad_alloc&) { \ + Device::process_error(device,RTC_ERROR_OUT_OF_MEMORY,"out of memory"); \ + } catch (rtcore_error& e) { \ + Device::process_error(device,e.error,e.what()); \ + } catch (std::exception& e) { \ + Device::process_error(device,RTC_ERROR_UNKNOWN,e.what()); \ + } catch (...) { \ + Device::process_error(device,RTC_ERROR_UNKNOWN,"unknown exception caught"); \ + } + +#define RTC_CATCH_END2(scene) \ + } catch (std::bad_alloc&) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,RTC_ERROR_OUT_OF_MEMORY,"out of memory"); \ + } catch (rtcore_error& e) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,e.error,e.what()); \ + } catch (std::exception& e) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,RTC_ERROR_UNKNOWN,e.what()); \ + } catch (...) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,RTC_ERROR_UNKNOWN,"unknown exception caught"); \ + } + +#define RTC_CATCH_END2_FALSE(scene) \ + } catch (std::bad_alloc&) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,RTC_ERROR_OUT_OF_MEMORY,"out of memory"); \ + return false; \ + } catch (rtcore_error& e) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,e.error,e.what()); \ + return false; \ + } catch (std::exception& e) { \ + Device* device = scene ? scene->device : nullptr; \ + Device::process_error(device,RTC_ERROR_UNKNOWN,e.what()); \ + return false; \ + } catch (...) { \ + Device* device = scene ? 
scene->device : nullptr; \ + Device::process_error(device,RTC_ERROR_UNKNOWN,"unknown exception caught"); \ + return false; \ + } + +#define RTC_VERIFY_HANDLE(handle) \ + if (handle == nullptr) { \ + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"invalid argument"); \ + } + +#define RTC_VERIFY_GEOMID(id) \ + if (id == RTC_INVALID_GEOMETRY_ID) { \ + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"invalid argument"); \ + } + +#define RTC_VERIFY_UPPER(id,upper) \ + if (id > upper) { \ + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"invalid argument"); \ + } + +#define RTC_VERIFY_RANGE(id,lower,upper) \ + if (id < lower || id > upper) \ + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"argument out of bounds"); + +#if 0 // enable to debug print all API calls +#define RTC_TRACE(x) std::cout << #x << std::endl; +#else +#define RTC_TRACE(x) +#endif + + /*! used to throw embree API errors */ + struct rtcore_error : public std::exception + { + __forceinline rtcore_error(RTCError error, const std::string& str) + : error(error), str(str) {} + + ~rtcore_error() throw() {} + + const char* what () const throw () { + return str.c_str(); + } + + RTCError error; + std::string str; + }; + +#if defined(DEBUG) // only report file and line in debug mode + #define throw_RTCError(error,str) \ + throw rtcore_error(error,std::string(__FILE__) + " (" + toString(__LINE__) + "): " + std::string(str)); +#else + #define throw_RTCError(error,str) \ + throw rtcore_error(error,str); +#endif + +#define RTC_BUILD_ARGUMENTS_HAS(settings,member) \ + (settings.byteSize > (offsetof(RTCBuildArguments,member)+sizeof(settings.member))) +} diff --git a/thirdparty/embree/kernels/common/rtcore_builder.cpp b/thirdparty/embree/kernels/common/rtcore_builder.cpp new file mode 100644 index 000000000000..6bb96bba077f --- /dev/null +++ b/thirdparty/embree/kernels/common/rtcore_builder.cpp @@ -0,0 +1,442 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#define RTC_EXPORT_API + +#include "default.h" +#include "device.h" +#include "scene.h" +#include "context.h" +#include "alloc.h" + +#include "../builders/bvh_builder_sah.h" +#include "../builders/bvh_builder_morton.h" + +namespace embree +{ + namespace isa // FIXME: support more ISAs for builders + { + struct BVH : public RefCount + { + BVH (Device* device) + : device(device), allocator(device,true), morton_src(device,0), morton_tmp(device,0) + { + device->refInc(); + } + + ~BVH() { + device->refDec(); + } + + public: + Device* device; + FastAllocator allocator; + mvector morton_src; + mvector morton_tmp; + }; + + void* rtcBuildBVHMorton(const RTCBuildArguments* arguments) + { + BVH* bvh = (BVH*) arguments->bvh; + RTCBuildPrimitive* prims_i = arguments->primitives; + size_t primitiveCount = arguments->primitiveCount; + RTCCreateNodeFunction createNode = arguments->createNode; + RTCSetNodeChildrenFunction setNodeChildren = arguments->setNodeChildren; + RTCSetNodeBoundsFunction setNodeBounds = arguments->setNodeBounds; + RTCCreateLeafFunction createLeaf = arguments->createLeaf; + RTCProgressMonitorFunction buildProgress = arguments->buildProgress; + void* userPtr = arguments->userPtr; + + std::atomic progress(0); + + /* initialize temporary arrays for morton builder */ + PrimRef* prims = (PrimRef*) prims_i; + mvector& morton_src = bvh->morton_src; + mvector& morton_tmp = bvh->morton_tmp; + morton_src.resize(primitiveCount); + morton_tmp.resize(primitiveCount); + + /* compute centroid bounds */ + const BBox3fa centBounds = parallel_reduce ( size_t(0), primitiveCount, 
BBox3fa(empty), [&](const range& r) -> BBox3fa { + + BBox3fa bounds(empty); + for (size_t i=r.begin(); i& r) { + BVHBuilderMorton::MortonCodeGenerator generator(mapping,&morton_src[r.begin()]); + for (size_t i=r.begin(); i root = BVHBuilderMorton::build>( + + /* thread local allocator for fast allocations */ + [&] () -> FastAllocator::CachedAllocator { + return bvh->allocator.getCachedAllocator(); + }, + + /* lambda function that allocates BVH nodes */ + [&] ( const FastAllocator::CachedAllocator& alloc, size_t N ) -> void* { + return createNode((RTCThreadLocalAllocator)&alloc, (unsigned int)N,userPtr); + }, + + /* lambda function that sets bounds */ + [&] (void* node, const std::pair* children, size_t N) -> std::pair + { + BBox3fa bounds = empty; + void* childptrs[BVHBuilderMorton::MAX_BRANCHING_FACTOR]; + const RTCBounds* cbounds[BVHBuilderMorton::MAX_BRANCHING_FACTOR]; + for (size_t i=0; i& current, const FastAllocator::CachedAllocator& alloc) -> std::pair + { + RTCBuildPrimitive localBuildPrims[RTC_BUILD_MAX_PRIMITIVES_PER_LEAF]; + BBox3fa bounds = empty; + for (size_t i=0;i BBox3fa { + return prims[morton.index].bounds(); + }, + + /* progress monitor function */ + [&] (size_t dn) { + if (!buildProgress) return true; + const size_t n = progress.fetch_add(dn)+dn; + const double f = std::min(1.0,double(n)/double(primitiveCount)); + return buildProgress(userPtr,f); + }, + + morton_src.data(),morton_tmp.data(),primitiveCount, + *arguments); + + bvh->allocator.cleanup(); + return root.first; + } + + void* rtcBuildBVHBinnedSAH(const RTCBuildArguments* arguments) + { + BVH* bvh = (BVH*) arguments->bvh; + RTCBuildPrimitive* prims = arguments->primitives; + size_t primitiveCount = arguments->primitiveCount; + RTCCreateNodeFunction createNode = arguments->createNode; + RTCSetNodeChildrenFunction setNodeChildren = arguments->setNodeChildren; + RTCSetNodeBoundsFunction setNodeBounds = arguments->setNodeBounds; + RTCCreateLeafFunction createLeaf = arguments->createLeaf; + RTCProgressMonitorFunction buildProgress = arguments->buildProgress; + void* userPtr = arguments->userPtr; + + std::atomic progress(0); + + /* calculate priminfo */ + auto computeBounds = [&](const range& r) -> CentGeomBBox3fa + { + CentGeomBBox3fa bounds(empty); + for (size_t j=r.begin(); j( + + /* thread local allocator for fast allocations */ + [&] () -> FastAllocator::CachedAllocator { + return bvh->allocator.getCachedAllocator(); + }, + + /* lambda function that creates BVH nodes */ + [&](BVHBuilderBinnedSAH::BuildRecord* children, const size_t N, const FastAllocator::CachedAllocator& alloc) -> void* + { + void* node = createNode((RTCThreadLocalAllocator)&alloc, (unsigned int)N,userPtr); + const RTCBounds* cbounds[GeneralBVHBuilder::MAX_BRANCHING_FACTOR]; + for (size_t i=0; i void* { + setNodeChildren(node,children, (unsigned int)N,userPtr); + return node; + }, + + /* lambda function that creates BVH leaves */ + [&](const PrimRef* prims, const range& range, const FastAllocator::CachedAllocator& alloc) -> void* { + return createLeaf((RTCThreadLocalAllocator)&alloc,(RTCBuildPrimitive*)(prims+range.begin()),range.size(),userPtr); + }, + + /* progress monitor function */ + [&] (size_t dn) { + if (!buildProgress) return true; + const size_t n = progress.fetch_add(dn)+dn; + const double f = std::min(1.0,double(n)/double(primitiveCount)); + return buildProgress(userPtr,f); + }, + + (PrimRef*)prims,pinfo,*arguments); + + bvh->allocator.cleanup(); + return root; + } + + static __forceinline const std::pair mergePair(const std::pair& 
a, const std::pair& b) { + CentGeomBBox3fa centBounds = CentGeomBBox3fa::merge2(a.first,b.first); + unsigned int maxGeomID = max(a.second,b.second); + return std::pair(centBounds,maxGeomID); + } + + void* rtcBuildBVHSpatialSAH(const RTCBuildArguments* arguments) + { + BVH* bvh = (BVH*) arguments->bvh; + RTCBuildPrimitive* prims = arguments->primitives; + size_t primitiveCount = arguments->primitiveCount; + RTCCreateNodeFunction createNode = arguments->createNode; + RTCSetNodeChildrenFunction setNodeChildren = arguments->setNodeChildren; + RTCSetNodeBoundsFunction setNodeBounds = arguments->setNodeBounds; + RTCCreateLeafFunction createLeaf = arguments->createLeaf; + RTCSplitPrimitiveFunction splitPrimitive = arguments->splitPrimitive; + RTCProgressMonitorFunction buildProgress = arguments->buildProgress; + void* userPtr = arguments->userPtr; + + std::atomic progress(0); + + /* calculate priminfo */ + + auto computeBounds = [&](const range& r) -> std::pair + { + CentGeomBBox3fa bounds(empty); + unsigned maxGeomID = 0; + for (size_t j=r.begin(); j(bounds,maxGeomID); + }; + + + const std::pair pair = + parallel_reduce(size_t(0),primitiveCount,size_t(1024),size_t(1024),std::pair(CentGeomBBox3fa(empty),0), computeBounds, mergePair); + + CentGeomBBox3fa bounds = pair.first; + const unsigned int maxGeomID = pair.second; + + if (unlikely(maxGeomID >= ((unsigned int)1 << (32-RESERVED_NUM_SPATIAL_SPLITS_GEOMID_BITS)))) + { + /* fallback code for max geomID larger than threshold */ + return rtcBuildBVHBinnedSAH(arguments); + } + + const PrimInfo pinfo(0,primitiveCount,bounds); + + /* function that splits a build primitive */ + struct Splitter + { + Splitter (RTCSplitPrimitiveFunction splitPrimitive, unsigned geomID, unsigned primID, void* userPtr) + : splitPrimitive(splitPrimitive), geomID(geomID), primID(primID), userPtr(userPtr) {} + + __forceinline void operator() (PrimRef& prim, const size_t dim, const float pos, PrimRef& left_o, PrimRef& right_o) const + { + prim.geomIDref() &= BVHBuilderBinnedFastSpatialSAH::GEOMID_MASK; + splitPrimitive((RTCBuildPrimitive*)&prim,(unsigned)dim,pos,(RTCBounds*)&left_o,(RTCBounds*)&right_o,userPtr); + left_o.geomIDref() = geomID; left_o.primIDref() = primID; + right_o.geomIDref() = geomID; right_o.primIDref() = primID; + } + + __forceinline void operator() (const BBox3fa& box, const size_t dim, const float pos, BBox3fa& left_o, BBox3fa& right_o) const + { + PrimRef prim(box,geomID & BVHBuilderBinnedFastSpatialSAH::GEOMID_MASK,primID); + splitPrimitive((RTCBuildPrimitive*)&prim,(unsigned)dim,pos,(RTCBounds*)&left_o,(RTCBounds*)&right_o,userPtr); + } + + RTCSplitPrimitiveFunction splitPrimitive; + unsigned geomID; + unsigned primID; + void* userPtr; + }; + + /* build BVH */ + void* root = BVHBuilderBinnedFastSpatialSAH::build( + + /* thread local allocator for fast allocations */ + [&] () -> FastAllocator::CachedAllocator { + return bvh->allocator.getCachedAllocator(); + }, + + /* lambda function that creates BVH nodes */ + [&] (BVHBuilderBinnedFastSpatialSAH::BuildRecord* children, const size_t N, const FastAllocator::CachedAllocator& alloc) -> void* + { + void* node = createNode((RTCThreadLocalAllocator)&alloc, (unsigned int)N,userPtr); + const RTCBounds* cbounds[GeneralBVHBuilder::MAX_BRANCHING_FACTOR]; + for (size_t i=0; i void* { + setNodeChildren(node,children, (unsigned int)N,userPtr); + return node; + }, + + /* lambda function that creates BVH leaves */ + [&] (const PrimRef* prims, const range& range, const FastAllocator::CachedAllocator& alloc) -> void* { + 
return createLeaf((RTCThreadLocalAllocator)&alloc,(RTCBuildPrimitive*)(prims+range.begin()),range.size(),userPtr); + }, + + /* returns the splitter */ + [&] ( const PrimRef& prim ) -> Splitter { + return Splitter(splitPrimitive,prim.geomID(),prim.primID(),userPtr); + }, + + /* progress monitor function */ + [&] (size_t dn) { + if (!buildProgress) return true; + const size_t n = progress.fetch_add(dn)+dn; + const double f = std::min(1.0,double(n)/double(primitiveCount)); + return buildProgress(userPtr,f); + }, + + (PrimRef*)prims, + arguments->primitiveArrayCapacity, + pinfo,*arguments); + + bvh->allocator.cleanup(); + return root; + } + } +} + +using namespace embree; +using namespace embree::isa; + +RTC_NAMESPACE_BEGIN + + RTC_API RTCBVH rtcNewBVH(RTCDevice device) + { + RTC_CATCH_BEGIN; + RTC_TRACE(rtcNewAllocator); + RTC_VERIFY_HANDLE(device); + BVH* bvh = new BVH((Device*)device); + return (RTCBVH) bvh->refInc(); + RTC_CATCH_END((Device*)device); + return nullptr; + } + + RTC_API void* rtcBuildBVH(const RTCBuildArguments* arguments) + { + BVH* bvh = (BVH*) arguments->bvh; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcBuildBVH); + RTC_VERIFY_HANDLE(bvh); + RTC_VERIFY_HANDLE(arguments); + RTC_VERIFY_HANDLE(arguments->createNode); + RTC_VERIFY_HANDLE(arguments->setNodeChildren); + RTC_VERIFY_HANDLE(arguments->setNodeBounds); + RTC_VERIFY_HANDLE(arguments->createLeaf); + + if (arguments->primitiveArrayCapacity < arguments->primitiveCount) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"primitiveArrayCapacity must be greater or equal to primitiveCount") + + /* initialize the allocator */ + bvh->allocator.init_estimate(arguments->primitiveCount*sizeof(BBox3fa)); + bvh->allocator.reset(); + + /* switch between differnet builders based on quality level */ + if (arguments->buildQuality == RTC_BUILD_QUALITY_LOW) + return rtcBuildBVHMorton(arguments); + else if (arguments->buildQuality == RTC_BUILD_QUALITY_MEDIUM) + return rtcBuildBVHBinnedSAH(arguments); + else if (arguments->buildQuality == RTC_BUILD_QUALITY_HIGH) { + if (arguments->splitPrimitive == nullptr || arguments->primitiveArrayCapacity <= arguments->primitiveCount) + return rtcBuildBVHBinnedSAH(arguments); + else + return rtcBuildBVHSpatialSAH(arguments); + } + else + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"invalid build quality"); + + /* if we are in dynamic mode, then do not clear temporary data */ + if (!(arguments->buildFlags & RTC_BUILD_FLAG_DYNAMIC)) + { + bvh->morton_src.clear(); + bvh->morton_tmp.clear(); + } + + RTC_CATCH_END(bvh->device); + return nullptr; + } + + RTC_API void* rtcThreadLocalAlloc(RTCThreadLocalAllocator localAllocator, size_t bytes, size_t align) + { + FastAllocator::CachedAllocator* alloc = (FastAllocator::CachedAllocator*) localAllocator; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcThreadLocalAlloc); + return alloc->malloc0(bytes,align); + RTC_CATCH_END(alloc->alloc->getDevice()); + return nullptr; + } + + RTC_API void rtcMakeStaticBVH(RTCBVH hbvh) + { + BVH* bvh = (BVH*) hbvh; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcStaticBVH); + RTC_VERIFY_HANDLE(hbvh); + bvh->morton_src.clear(); + bvh->morton_tmp.clear(); + RTC_CATCH_END(bvh->device); + } + + RTC_API void rtcRetainBVH(RTCBVH hbvh) + { + BVH* bvh = (BVH*) hbvh; + Device* device = bvh ? bvh->device : nullptr; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcRetainBVH); + RTC_VERIFY_HANDLE(hbvh); + bvh->refInc(); + RTC_CATCH_END(device); + } + + RTC_API void rtcReleaseBVH(RTCBVH hbvh) + { + BVH* bvh = (BVH*) hbvh; + Device* device = bvh ? 
bvh->device : nullptr; + RTC_CATCH_BEGIN; + RTC_TRACE(rtcReleaseBVH); + RTC_VERIFY_HANDLE(hbvh); + bvh->refDec(); + RTC_CATCH_END(device); + } + +RTC_NAMESPACE_END diff --git a/thirdparty/embree/kernels/common/scene.cpp b/thirdparty/embree/kernels/common/scene.cpp new file mode 100644 index 000000000000..b0d1a723d12e --- /dev/null +++ b/thirdparty/embree/kernels/common/scene.cpp @@ -0,0 +1,953 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "scene.h" + +#include "../bvh/bvh4_factory.h" +#include "../bvh/bvh8_factory.h" +#include "../../common/algorithms/parallel_reduce.h" + +namespace embree +{ + /* error raising rtcIntersect and rtcOccluded functions */ + void missing_rtcCommit() { throw_RTCError(RTC_ERROR_INVALID_OPERATION,"scene not committed"); } + void invalid_rtcIntersect1() { throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersect and rtcOccluded not enabled"); } + void invalid_rtcIntersect4() { throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersect4 and rtcOccluded4 not enabled"); } + void invalid_rtcIntersect8() { throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersect8 and rtcOccluded8 not enabled"); } + void invalid_rtcIntersect16() { throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersect16 and rtcOccluded16 not enabled"); } + void invalid_rtcIntersectN() { throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcIntersectN and rtcOccludedN not enabled"); } + + Scene::Scene (Device* device) + : device(device), + flags_modified(true), enabled_geometry_types(0), + scene_flags(RTC_SCENE_FLAG_NONE), + quality_flags(RTC_BUILD_QUALITY_MEDIUM), + is_build(false), modified(true), + progressInterface(this), progress_monitor_function(nullptr), progress_monitor_ptr(nullptr), progress_monitor_counter(0) + { + device->refInc(); + + intersectors = Accel::Intersectors(missing_rtcCommit); + + /* one can overwrite flags through device for debugging */ + if (device->quality_flags != -1) + quality_flags = (RTCBuildQuality) device->quality_flags; + if (device->scene_flags != -1) + scene_flags = (RTCSceneFlags) device->scene_flags; + } + + Scene::~Scene() noexcept + { + device->refDec(); + } + + void Scene::printStatistics() + { + /* calculate maximum number of time segments */ + unsigned max_time_steps = 0; + for (size_t i=0; inumTimeSteps); + } + + /* initialize vectors*/ + std::vector statistics[Geometry::GTY_END]; + for (size_t i=0; igetType(); + assert(tynumTimeSegments(); + assert((unsigned int)timesegments < max_time_steps); + statistics[ty][timesegments] += get(i)->size(); + } + + /* print statistics */ + std::cout << std::setw(23) << "segments" << ": "; + for (size_t t=0; ttri_accel == "default") + { + if (quality_flags != RTC_BUILD_QUALITY_LOW) + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + { + if (quality_flags == RTC_BUILD_QUALITY_HIGH) + accels_add(device->bvh8_factory->BVH8Triangle4(this,BVHFactory::BuildVariant::HIGH_QUALITY,BVHFactory::IntersectVariant::FAST)); + else + accels_add(device->bvh8_factory->BVH8Triangle4(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); + } + else +#endif + { + if (quality_flags == RTC_BUILD_QUALITY_HIGH) + accels_add(device->bvh4_factory->BVH4Triangle4(this,BVHFactory::BuildVariant::HIGH_QUALITY,BVHFactory::IntersectVariant::FAST)); + else + accels_add(device->bvh4_factory->BVH4Triangle4(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); + } 
+ break; + + case /*0b01*/ 1: +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + accels_add(device->bvh8_factory->BVH8Triangle4v(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); + else +#endif + accels_add(device->bvh4_factory->BVH4Triangle4v(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); + + break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Triangle4i(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Triangle4i(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else /* dynamic */ + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh8_factory->BVH8Triangle4 (this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b01*/ 1: accels_add(device->bvh8_factory->BVH8Triangle4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Triangle4i(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Triangle4i(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else +#endif + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh4_factory->BVH4Triangle4 (this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b01*/ 1: accels_add(device->bvh4_factory->BVH4Triangle4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Triangle4i(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Triangle4i(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + } + } + else if (device->tri_accel == "bvh4.triangle4") accels_add(device->bvh4_factory->BVH4Triangle4 (this)); + else if (device->tri_accel == "bvh4.triangle4v") accels_add(device->bvh4_factory->BVH4Triangle4v(this)); + else if (device->tri_accel == "bvh4.triangle4i") accels_add(device->bvh4_factory->BVH4Triangle4i(this)); + else if (device->tri_accel == "qbvh4.triangle4i") accels_add(device->bvh4_factory->BVH4QuantizedTriangle4i(this)); + +#if defined (EMBREE_TARGET_SIMD8) + else if (device->tri_accel == "bvh8.triangle4") accels_add(device->bvh8_factory->BVH8Triangle4 (this)); + else if (device->tri_accel == "bvh8.triangle4v") accels_add(device->bvh8_factory->BVH8Triangle4v(this)); + else if (device->tri_accel == "bvh8.triangle4i") accels_add(device->bvh8_factory->BVH8Triangle4i(this)); + else if (device->tri_accel == "qbvh8.triangle4i") accels_add(device->bvh8_factory->BVH8QuantizedTriangle4i(this)); + else if (device->tri_accel == "qbvh8.triangle4") accels_add(device->bvh8_factory->BVH8QuantizedTriangle4(this)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown triangle acceleration structure "+device->tri_accel); +#endif + } + + void Scene::createTriangleMBAccel() + { +#if defined(EMBREE_GEOMETRY_TRIANGLE) + if (device->tri_accel_mb == "default") + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + 
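For context (an illustrative sketch, not part of the vendored sources): the mode value computed above packs the scene flags into two bits (bit 1 = compact, bit 0 = robust), so the switch statements that follow cover all four combinations. Below is a minimal sketch of how an application reaches one of these branches through the standard embree3 public API; the entry points used (rtcNewScene, rtcSetSceneFlags, rtcSetSceneBuildQuality, rtcCommitScene, rtcReleaseScene) are assumed from the public header rather than taken from this excerpt.

    #include <embree3/rtcore.h>

    void build_compact_robust_scene(RTCDevice device)
    {
      RTCScene scene = rtcNewScene(device);
      // COMPACT | ROBUST yields mode == 3 in the selectors above, i.e. the *Triangle4i
      // accels with the ROBUST intersect variant; MEDIUM quality picks the static builders.
      rtcSetSceneFlags(scene, RTC_SCENE_FLAG_COMPACT | RTC_SCENE_FLAG_ROBUST);
      rtcSetSceneBuildQuality(scene, RTC_BUILD_QUALITY_MEDIUM);
      // ... attach triangle / quad geometries here ...
      rtcCommitScene(scene);   // eventually runs Scene::commit_task(), which calls these create*Accel() selectors
      rtcReleaseScene(scene);
    }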
+#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX2()) // BVH8 reduces performance on AVX only-machines + { + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh8_factory->BVH8Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b01*/ 1: accels_add(device->bvh8_factory->BVH8Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else +#endif + { + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh4_factory->BVH4Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b01*/ 1: accels_add(device->bvh4_factory->BVH4Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Triangle4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + } + else if (device->tri_accel_mb == "bvh4.triangle4imb") accels_add(device->bvh4_factory->BVH4Triangle4iMB(this)); + else if (device->tri_accel_mb == "bvh4.triangle4vmb") accels_add(device->bvh4_factory->BVH4Triangle4vMB(this)); +#if defined (EMBREE_TARGET_SIMD8) + else if (device->tri_accel_mb == "bvh8.triangle4imb") accels_add(device->bvh8_factory->BVH8Triangle4iMB(this)); + else if (device->tri_accel_mb == "bvh8.triangle4vmb") accels_add(device->bvh8_factory->BVH8Triangle4vMB(this)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown motion blur triangle acceleration structure "+device->tri_accel_mb); +#endif + } + + void Scene::createQuadAccel() + { +#if defined(EMBREE_GEOMETRY_QUAD) + if (device->quad_accel == "default") + { + if (quality_flags != RTC_BUILD_QUALITY_LOW) + { + /* static */ + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + { + if (quality_flags == RTC_BUILD_QUALITY_HIGH) + accels_add(device->bvh8_factory->BVH8Quad4v(this,BVHFactory::BuildVariant::HIGH_QUALITY,BVHFactory::IntersectVariant::FAST)); + else + accels_add(device->bvh8_factory->BVH8Quad4v(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); + } + else +#endif + { + if (quality_flags == RTC_BUILD_QUALITY_HIGH) + accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::HIGH_QUALITY,BVHFactory::IntersectVariant::FAST)); + else + accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); + } + break; + + case /*0b01*/ 1: +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + accels_add(device->bvh8_factory->BVH8Quad4v(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); + else +#endif + accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); + break; + + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Quad4i(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); break; + 
case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Quad4i(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else /* dynamic */ + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh8_factory->BVH8Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST)); break; + case /*0b01*/ 1: accels_add(device->bvh8_factory->BVH8Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST)); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else +#endif + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST)); break; + case /*0b01*/ 1: accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::FAST)); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4Quad4v(this,BVHFactory::BuildVariant::DYNAMIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + } + } + else if (device->quad_accel == "bvh4.quad4v") accels_add(device->bvh4_factory->BVH4Quad4v(this)); + else if (device->quad_accel == "bvh4.quad4i") accels_add(device->bvh4_factory->BVH4Quad4i(this)); + else if (device->quad_accel == "qbvh4.quad4i") accels_add(device->bvh4_factory->BVH4QuantizedQuad4i(this)); + +#if defined (EMBREE_TARGET_SIMD8) + else if (device->quad_accel == "bvh8.quad4v") accels_add(device->bvh8_factory->BVH8Quad4v(this)); + else if (device->quad_accel == "bvh8.quad4i") accels_add(device->bvh8_factory->BVH8Quad4i(this)); + else if (device->quad_accel == "qbvh8.quad4i") accels_add(device->bvh8_factory->BVH8QuantizedQuad4i(this)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown quad acceleration structure "+device->quad_accel); +#endif + } + + void Scene::createQuadMBAccel() + { +#if defined(EMBREE_GEOMETRY_QUAD) + if (device->quad_accel_mb == "default") + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); + switch (mode) { + case /*0b00*/ 0: +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + accels_add(device->bvh8_factory->BVH8Quad4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); + else +#endif + accels_add(device->bvh4_factory->BVH4Quad4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST)); + break; + + case /*0b01*/ 1: +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX()) + accels_add(device->bvh8_factory->BVH8Quad4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); + else +#endif + accels_add(device->bvh4_factory->BVH4Quad4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); + break; + + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4Quad4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::FAST )); break; + case /*0b11*/ 3: 
accels_add(device->bvh4_factory->BVH4Quad4iMB(this,BVHFactory::BuildVariant::STATIC,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else if (device->quad_accel_mb == "bvh4.quad4imb") accels_add(device->bvh4_factory->BVH4Quad4iMB(this)); +#if defined (EMBREE_TARGET_SIMD8) + else if (device->quad_accel_mb == "bvh8.quad4imb") accels_add(device->bvh8_factory->BVH8Quad4iMB(this)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown quad motion blur acceleration structure "+device->quad_accel_mb); +#endif + } + + void Scene::createHairAccel() + { +#if defined(EMBREE_GEOMETRY_CURVE) || defined(EMBREE_GEOMETRY_POINT) + if (device->hair_accel == "default") + { + int mode = 2*(int)isCompactAccel() + 1*(int)isRobustAccel(); +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX2()) // only enable on HSW machines, for SNB this codepath is slower + { + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh8_factory->BVH8OBBVirtualCurve8v(this,BVHFactory::IntersectVariant::FAST)); break; + case /*0b01*/ 1: accels_add(device->bvh8_factory->BVH8OBBVirtualCurve8v(this,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4OBBVirtualCurve8i(this,BVHFactory::IntersectVariant::FAST)); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4OBBVirtualCurve8i(this,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + else +#endif + { + switch (mode) { + case /*0b00*/ 0: accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4v(this,BVHFactory::IntersectVariant::FAST)); break; + case /*0b01*/ 1: accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4v(this,BVHFactory::IntersectVariant::ROBUST)); break; + case /*0b10*/ 2: accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4i(this,BVHFactory::IntersectVariant::FAST)); break; + case /*0b11*/ 3: accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4i(this,BVHFactory::IntersectVariant::ROBUST)); break; + } + } + } + else if (device->hair_accel == "bvh4obb.virtualcurve4v" ) accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4v(this,BVHFactory::IntersectVariant::FAST)); + else if (device->hair_accel == "bvh4obb.virtualcurve4i" ) accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4i(this,BVHFactory::IntersectVariant::FAST)); +#if defined (EMBREE_TARGET_SIMD8) + else if (device->hair_accel == "bvh8obb.virtualcurve8v" ) accels_add(device->bvh8_factory->BVH8OBBVirtualCurve8v(this,BVHFactory::IntersectVariant::FAST)); + else if (device->hair_accel == "bvh4obb.virtualcurve8i" ) accels_add(device->bvh4_factory->BVH4OBBVirtualCurve8i(this,BVHFactory::IntersectVariant::FAST)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown hair acceleration structure "+device->hair_accel); +#endif + } + + void Scene::createHairMBAccel() + { +#if defined(EMBREE_GEOMETRY_CURVE) || defined(EMBREE_GEOMETRY_POINT) + if (device->hair_accel_mb == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX2()) // only enable on HSW machines, on SNB this codepath is slower + { + if (isRobustAccel()) accels_add(device->bvh8_factory->BVH8OBBVirtualCurve8iMB(this,BVHFactory::IntersectVariant::ROBUST)); + else accels_add(device->bvh8_factory->BVH8OBBVirtualCurve8iMB(this,BVHFactory::IntersectVariant::FAST)); + } + else +#endif + { + if (isRobustAccel()) accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4iMB(this,BVHFactory::IntersectVariant::ROBUST)); + else accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4iMB(this,BVHFactory::IntersectVariant::FAST)); + } + } + else if 
(device->hair_accel_mb == "bvh4.virtualcurve4imb") accels_add(device->bvh4_factory->BVH4OBBVirtualCurve4iMB(this,BVHFactory::IntersectVariant::FAST)); + +#if defined (EMBREE_TARGET_SIMD8) + else if (device->hair_accel_mb == "bvh4.virtualcurve8imb") accels_add(device->bvh4_factory->BVH4OBBVirtualCurve8iMB(this,BVHFactory::IntersectVariant::FAST)); + else if (device->hair_accel_mb == "bvh8.virtualcurve8imb") accels_add(device->bvh8_factory->BVH8OBBVirtualCurve8iMB(this,BVHFactory::IntersectVariant::FAST)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown motion blur hair acceleration structure "+device->hair_accel_mb); +#endif + } + + void Scene::createSubdivAccel() + { +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + if (device->subdiv_accel == "default") { + accels_add(device->bvh4_factory->BVH4SubdivPatch1(this)); + } + else if (device->subdiv_accel == "bvh4.grid.eager" ) accels_add(device->bvh4_factory->BVH4SubdivPatch1(this)); + else if (device->subdiv_accel == "bvh4.subdivpatch1eager" ) accels_add(device->bvh4_factory->BVH4SubdivPatch1(this)); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown subdiv accel "+device->subdiv_accel); +#endif + } + + void Scene::createSubdivMBAccel() + { +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + if (device->subdiv_accel_mb == "default") { + accels_add(device->bvh4_factory->BVH4SubdivPatch1MB(this)); + } + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown subdiv mblur accel "+device->subdiv_accel_mb); +#endif + } + + void Scene::createUserGeometryAccel() + { +#if defined(EMBREE_GEOMETRY_USER) + if (device->object_accel == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() && !isCompactAccel()) + { + if (quality_flags != RTC_BUILD_QUALITY_LOW) { + accels_add(device->bvh8_factory->BVH8UserGeometry(this,BVHFactory::BuildVariant::STATIC)); + } else { + accels_add(device->bvh8_factory->BVH8UserGeometry(this,BVHFactory::BuildVariant::DYNAMIC)); + } + } + else +#endif + { + if (quality_flags != RTC_BUILD_QUALITY_LOW) { + accels_add(device->bvh4_factory->BVH4UserGeometry(this,BVHFactory::BuildVariant::STATIC)); + } else { + accels_add(device->bvh4_factory->BVH4UserGeometry(this,BVHFactory::BuildVariant::DYNAMIC)); + } + } + } + else if (device->object_accel == "bvh4.object") accels_add(device->bvh4_factory->BVH4UserGeometry(this)); +#if defined (EMBREE_TARGET_SIMD8) + else if (device->object_accel == "bvh8.object") accels_add(device->bvh8_factory->BVH8UserGeometry(this)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown user geometry accel "+device->object_accel); +#endif + } + + void Scene::createUserGeometryMBAccel() + { +#if defined(EMBREE_GEOMETRY_USER) + if (device->object_accel_mb == "default" ) { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() && !isCompactAccel()) + accels_add(device->bvh8_factory->BVH8UserGeometryMB(this)); + else +#endif + accels_add(device->bvh4_factory->BVH4UserGeometryMB(this)); + } + else if (device->object_accel_mb == "bvh4.object") accels_add(device->bvh4_factory->BVH4UserGeometryMB(this)); +#if defined (EMBREE_TARGET_SIMD8) + else if (device->object_accel_mb == "bvh8.object") accels_add(device->bvh8_factory->BVH8UserGeometryMB(this)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown user geometry mblur accel "+device->object_accel_mb); +#endif + } + + void Scene::createInstanceAccel() + { +#if defined(EMBREE_GEOMETRY_INSTANCE) + // if (device->object_accel == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() 
&& !isCompactAccel()) { + if (quality_flags != RTC_BUILD_QUALITY_LOW) { + accels_add(device->bvh8_factory->BVH8Instance(this, false, BVHFactory::BuildVariant::STATIC)); + } else { + accels_add(device->bvh8_factory->BVH8Instance(this, false, BVHFactory::BuildVariant::DYNAMIC)); + } + } + else +#endif + { + if (quality_flags != RTC_BUILD_QUALITY_LOW) { + accels_add(device->bvh4_factory->BVH4Instance(this, false, BVHFactory::BuildVariant::STATIC)); + } else { + accels_add(device->bvh4_factory->BVH4Instance(this, false, BVHFactory::BuildVariant::DYNAMIC)); + } + } + } + // else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown instance accel "+device->instance_accel); +#endif + } + + void Scene::createInstanceMBAccel() + { +#if defined(EMBREE_GEOMETRY_INSTANCE) + //if (device->instance_accel_mb == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() && !isCompactAccel()) + accels_add(device->bvh8_factory->BVH8InstanceMB(this, false)); + else +#endif + accels_add(device->bvh4_factory->BVH4InstanceMB(this, false)); + } + //else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown instance mblur accel "+device->instance_accel_mb); +#endif + } + + void Scene::createInstanceExpensiveAccel() + { +#if defined(EMBREE_GEOMETRY_INSTANCE) + // if (device->object_accel == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() && !isCompactAccel()) { + if (quality_flags != RTC_BUILD_QUALITY_LOW) { + accels_add(device->bvh8_factory->BVH8Instance(this, true, BVHFactory::BuildVariant::STATIC)); + } else { + accels_add(device->bvh8_factory->BVH8Instance(this, true, BVHFactory::BuildVariant::DYNAMIC)); + } + } + else +#endif + { + if (quality_flags != RTC_BUILD_QUALITY_LOW) { + accels_add(device->bvh4_factory->BVH4Instance(this, true, BVHFactory::BuildVariant::STATIC)); + } else { + accels_add(device->bvh4_factory->BVH4Instance(this, true, BVHFactory::BuildVariant::DYNAMIC)); + } + } + } + // else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown instance accel "+device->instance_accel); +#endif + } + + void Scene::createInstanceExpensiveMBAccel() + { +#if defined(EMBREE_GEOMETRY_INSTANCE) + //if (device->instance_accel_mb == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() && !isCompactAccel()) + accels_add(device->bvh8_factory->BVH8InstanceMB(this, true)); + else +#endif + accels_add(device->bvh4_factory->BVH4InstanceMB(this, true)); + } + //else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown instance mblur accel "+device->instance_accel_mb); +#endif + } + + void Scene::createGridAccel() + { + BVHFactory::IntersectVariant ivariant = isRobustAccel() ? 
BVHFactory::IntersectVariant::ROBUST : BVHFactory::IntersectVariant::FAST; +#if defined(EMBREE_GEOMETRY_GRID) + if (device->grid_accel == "default") + { +#if defined (EMBREE_TARGET_SIMD8) + if (device->canUseAVX() && !isCompactAccel()) + { + accels_add(device->bvh8_factory->BVH8Grid(this,BVHFactory::BuildVariant::STATIC,ivariant)); + } + else +#endif + { + accels_add(device->bvh4_factory->BVH4Grid(this,BVHFactory::BuildVariant::STATIC,ivariant)); + } + } + else if (device->grid_accel == "bvh4.grid") accels_add(device->bvh4_factory->BVH4Grid(this,BVHFactory::BuildVariant::STATIC,ivariant)); +#if defined (EMBREE_TARGET_SIMD8) + else if (device->grid_accel == "bvh8.grid") accels_add(device->bvh8_factory->BVH8Grid(this,BVHFactory::BuildVariant::STATIC,ivariant)); +#endif + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown grid accel "+device->grid_accel); +#endif + + } + + void Scene::createGridMBAccel() + { +#if defined(EMBREE_GEOMETRY_GRID) + if (device->grid_accel_mb == "default") + { + accels_add(device->bvh4_factory->BVH4GridMB(this,BVHFactory::BuildVariant::STATIC)); + } + else if (device->grid_accel_mb == "bvh4mb.grid") accels_add(device->bvh4_factory->BVH4GridMB(this)); + else throw_RTCError(RTC_ERROR_INVALID_ARGUMENT,"unknown grid mb accel "+device->grid_accel); +#endif + + } + + void Scene::clear() { + } + + unsigned Scene::bind(unsigned geomID, Ref geometry) + { + Lock lock(geometriesMutex); + if (geomID == RTC_INVALID_GEOMETRY_ID) { + geomID = id_pool.allocate(); + if (geomID == RTC_INVALID_GEOMETRY_ID) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"too many geometries inside scene"); + } + else + { + if (!id_pool.add(geomID)) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"invalid geometry ID provided"); + } + if (geomID >= geometries.size()) { + geometries.resize(geomID+1); + vertices.resize(geomID+1); + geometryModCounters_.resize(geomID+1); + } + geometries[geomID] = geometry; + geometryModCounters_[geomID] = 0; + if (geometry->isEnabled()) { + setModified (); + } + return geomID; + } + + void Scene::detachGeometry(size_t geomID) + { + Lock lock(geometriesMutex); + + if (geomID >= geometries.size()) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"invalid geometry ID"); + + Ref& geometry = geometries[geomID]; + if (geometry == null) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"invalid geometry"); + + if (geometry->isEnabled()) { + setModified (); + } + accels_deleteGeometry(unsigned(geomID)); + id_pool.deallocate((unsigned)geomID); + geometries[geomID] = null; + vertices[geomID] = nullptr; + geometryModCounters_[geomID] = 0; + } + + void Scene::updateInterface() + { + is_build = true; + } + + void Scene::commit_task () + { + checkIfModifiedAndSet (); + if (!isModified()) { + return; + } + + /* print scene statistics */ + if (device->verbosity(2)) + printStatistics(); + + progress_monitor_counter = 0; + + /* gather scene stats and call preCommit function of each geometry */ + this->world = parallel_reduce (size_t(0), geometries.size(), GeometryCounts (), + [this](const range& r)->GeometryCounts + { + GeometryCounts c; + for (auto i=r.begin(); iisEnabled()) + { + geometries[i]->preCommit(); + geometries[i]->addElementsToCount (c); + c.numFilterFunctions += (int) geometries[i]->hasFilterFunctions(); + } + } + return c; + }, + std::plus() + ); + + /* select acceleration structures to build */ + unsigned int new_enabled_geometry_types = world.enabledGeometryTypesMask(); + if (flags_modified || new_enabled_geometry_types != enabled_geometry_types) + { + accels_init(); + + /* we need 
to make all geometries modified, otherwise two level builder will + not rebuild currently not modified geometries */ + parallel_for(geometryModCounters_.size(), [&] ( const size_t i ) { + geometryModCounters_[i] = 0; + }); + + if (getNumPrimitives(TriangleMesh::geom_type,false)) createTriangleAccel(); + if (getNumPrimitives(TriangleMesh::geom_type,true)) createTriangleMBAccel(); + if (getNumPrimitives(QuadMesh::geom_type,false)) createQuadAccel(); + if (getNumPrimitives(QuadMesh::geom_type,true)) createQuadMBAccel(); + if (getNumPrimitives(GridMesh::geom_type,false)) createGridAccel(); + if (getNumPrimitives(GridMesh::geom_type,true)) createGridMBAccel(); + if (getNumPrimitives(SubdivMesh::geom_type,false)) createSubdivAccel(); + if (getNumPrimitives(SubdivMesh::geom_type,true)) createSubdivMBAccel(); + if (getNumPrimitives(Geometry::MTY_CURVES,false)) createHairAccel(); + if (getNumPrimitives(Geometry::MTY_CURVES,true)) createHairMBAccel(); + if (getNumPrimitives(UserGeometry::geom_type,false)) createUserGeometryAccel(); + if (getNumPrimitives(UserGeometry::geom_type,true)) createUserGeometryMBAccel(); + if (getNumPrimitives(Geometry::MTY_INSTANCE_CHEAP,false)) createInstanceAccel(); + if (getNumPrimitives(Geometry::MTY_INSTANCE_CHEAP,true)) createInstanceMBAccel(); + if (getNumPrimitives(Geometry::MTY_INSTANCE_EXPENSIVE,false)) createInstanceExpensiveAccel(); + if (getNumPrimitives(Geometry::MTY_INSTANCE_EXPENSIVE,true)) createInstanceExpensiveMBAccel(); + + flags_modified = false; + enabled_geometry_types = new_enabled_geometry_types; + } + + /* select fast code path if no filter function is present */ + accels_select(hasFilterFunction()); + + /* build all hierarchies of this scene */ + accels_build(); + + /* make static geometry immutable */ + if (!isDynamicAccel()) { + accels_immutable(); + flags_modified = true; // in non-dynamic mode we have to re-create accels + } + + /* call postCommit function of each geometry */ + parallel_for(geometries.size(), [&] ( const size_t i ) { + if (geometries[i] && geometries[i]->isEnabled()) { + geometries[i]->postCommit(); + vertices[i] = geometries[i]->getCompactVertexArray(); + geometryModCounters_[i] = geometries[i]->getModCounter(); + } + }); + + updateInterface(); + + if (device->verbosity(2)) { + std::cout << "created scene intersector" << std::endl; + accels_print(2); + std::cout << "selected scene intersector" << std::endl; + intersectors.print(2); + } + + setModified(false); + } + + void Scene::setBuildQuality(RTCBuildQuality quality_flags_i) + { + if (quality_flags == quality_flags_i) return; + quality_flags = quality_flags_i; + flags_modified = true; + } + + RTCBuildQuality Scene::getBuildQuality() const { + return quality_flags; + } + + void Scene::setSceneFlags(RTCSceneFlags scene_flags_i) + { + if (scene_flags == scene_flags_i) return; + scene_flags = scene_flags_i; + flags_modified = true; + } + + RTCSceneFlags Scene::getSceneFlags() const { + return scene_flags; + } + +#if defined(TASKING_INTERNAL) + + void Scene::commit (bool join) + { + Lock buildLock(buildMutex,false); + + /* allocates own taskscheduler for each build */ + Ref scheduler = nullptr; + { + Lock lock(schedulerMutex); + scheduler = this->scheduler; + if (scheduler == null) { + buildLock.lock(); + this->scheduler = scheduler = new TaskScheduler; + } + } + + /* worker threads join build */ + if (!buildLock.isLocked()) + { + if (!join) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"use rtcJoinCommitScene to join a build operation"); + + scheduler->join(); + return; + } + + 
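For context (an illustrative sketch, not part of the vendored sources): the lock/scheduler handshake above is what backs cooperative scene commits. Every participating thread calls the public rtcJoinCommitScene entry point (assumed here to forward to Scene::commit(true)); the first caller to arrive acquires the build lock and spawns the root task below, while later callers take the scheduler->join() path above and help execute the build tasks.

    #include <embree3/rtcore.h>
    #include <thread>
    #include <vector>

    // All threads call rtcJoinCommitScene; exactly one drives the build, the rest join it.
    void commit_cooperatively(RTCScene scene, int num_threads)
    {
      std::vector<std::thread> workers;
      for (int i = 0; i < num_threads; i++)
        workers.emplace_back([scene] { rtcJoinCommitScene(scene); });
      for (std::thread& t : workers)
        t.join();
    }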
/* initiate build */ + try { + scheduler->spawn_root([&]() { commit_task(); Lock lock(schedulerMutex); this->scheduler = nullptr; }, 1, !join); + } + catch (...) { + accels_clear(); + updateInterface(); + Lock lock(schedulerMutex); + this->scheduler = nullptr; + throw; + } + } + +#endif + +#if defined(TASKING_TBB) + + void Scene::commit (bool join) + { +#if defined(TASKING_TBB) && (TBB_INTERFACE_VERSION_MAJOR < 8) + if (join) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcJoinCommitScene not supported with this TBB version"); +#endif + + /* try to obtain build lock */ + Lock lock(buildMutex,buildMutex.try_lock()); + + /* join hierarchy build */ + if (!lock.isLocked()) + { +#if !TASKING_TBB_USE_TASK_ISOLATION + if (!join) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"invoking rtcCommitScene from multiple threads is not supported with this TBB version"); +#endif + + do { + +#if USE_TASK_ARENA + if (join) { + device->arena->execute([&]{ group.wait(); }); + } + else +#endif + { + group.wait(); + } + + pause_cpu(); + yield(); + } while (!buildMutex.try_lock()); + + buildMutex.unlock(); + return; + } + + /* for best performance set FTZ and DAZ flags in the MXCSR control and status register */ + const unsigned int mxcsr = _mm_getcsr(); + _mm_setcsr(mxcsr | /* FTZ */ (1<<15) | /* DAZ */ (1<<6)); + + try { +#if TBB_INTERFACE_VERSION_MAJOR < 8 + tbb::task_group_context ctx( tbb::task_group_context::isolated, tbb::task_group_context::default_traits); +#else + tbb::task_group_context ctx( tbb::task_group_context::isolated, tbb::task_group_context::default_traits | tbb::task_group_context::fp_settings ); +#endif + //ctx.set_priority(tbb::priority_high); + +#if USE_TASK_ARENA + if (join) + { + device->arena->execute([&]{ + group.run([&]{ + tbb::parallel_for (size_t(0), size_t(1), size_t(1), [&] (size_t) { commit_task(); }, ctx); + }); + group.wait(); + }); + } + else +#endif + { + group.run([&]{ + tbb::parallel_for (size_t(0), size_t(1), size_t(1), [&] (size_t) { commit_task(); }, ctx); + }); + group.wait(); + } + + /* reset MXCSR register again */ + _mm_setcsr(mxcsr); + } + catch (...) + { + /* reset MXCSR register again */ + _mm_setcsr(mxcsr); + + accels_clear(); + updateInterface(); + throw; + } + } +#endif + +#if defined(TASKING_PPL) + + void Scene::commit (bool join) + { +#if defined(TASKING_PPL) + if (join) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"rtcJoinCommitScene not supported with PPL"); +#endif + + /* try to obtain build lock */ + Lock lock(buildMutex); + + checkIfModifiedAndSet (); + if (!isModified()) { + return; + } + + /* for best performance set FTZ and DAZ flags in the MXCSR control and status register */ + const unsigned int mxcsr = _mm_getcsr(); + _mm_setcsr(mxcsr | /* FTZ */ (1<<15) | /* DAZ */ (1<<6)); + + try { + + group.run([&]{ + concurrency::parallel_for(size_t(0), size_t(1), size_t(1), [&](size_t) { commit_task(); }); + }); + group.wait(); + + /* reset MXCSR register again */ + _mm_setcsr(mxcsr); + } + catch (...) 
+ { + /* reset MXCSR register again */ + _mm_setcsr(mxcsr); + + accels_clear(); + updateInterface(); + throw; + } + } +#endif + + void Scene::setProgressMonitorFunction(RTCProgressMonitorFunction func, void* ptr) + { + progress_monitor_function = func; + progress_monitor_ptr = ptr; + } + + void Scene::progressMonitor(double dn) + { + if (progress_monitor_function) { + size_t n = size_t(dn) + progress_monitor_counter.fetch_add(size_t(dn)); + if (!progress_monitor_function(progress_monitor_ptr, n / (double(numPrimitives())))) { + throw_RTCError(RTC_ERROR_CANCELLED,"progress monitor forced termination"); + } + } + } +} diff --git a/thirdparty/embree/kernels/common/scene.h b/thirdparty/embree/kernels/common/scene.h new file mode 100644 index 000000000000..b41c6cde9134 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene.h @@ -0,0 +1,390 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "device.h" +#include "builder.h" +#include "../../common/algorithms/parallel_any_of.h" +#include "scene_triangle_mesh.h" +#include "scene_quad_mesh.h" +#include "scene_user_geometry.h" +#include "scene_instance.h" +#include "scene_curves.h" +#include "scene_line_segments.h" +#include "scene_subdiv_mesh.h" +#include "scene_grid_mesh.h" +#include "scene_points.h" +#include "../subdiv/tessellation_cache.h" + +#include "acceln.h" +#include "geometry.h" + +namespace embree +{ + /*! Base class all scenes are derived from */ + class Scene : public AccelN + { + ALIGNED_CLASS_(std::alignment_of::value); + + public: + template + class Iterator + { + public: + Iterator () {} + + Iterator (Scene* scene, bool all = false) + : scene(scene), all(all) {} + + __forceinline Ty* at(const size_t i) + { + Geometry* geom = scene->geometries[i].ptr; + if (geom == nullptr) return nullptr; + if (!all && !geom->isEnabled()) return nullptr; + const size_t mask = geom->getTypeMask() & Ty::geom_type; + if (!(mask)) return nullptr; + if ((geom->numTimeSteps != 1) != mblur) return nullptr; + return (Ty*) geom; + } + + __forceinline Ty* operator[] (const size_t i) { + return at(i); + } + + __forceinline size_t size() const { + return scene->size(); + } + + __forceinline size_t numPrimitives() const { + return scene->getNumPrimitives(Ty::geom_type,mblur); + } + + __forceinline size_t maxPrimitivesPerGeometry() + { + size_t ret = 0; + for (size_t i=0; isize(); i++) { + Ty* mesh = at(i); + if (mesh == nullptr) continue; + ret = max(ret,mesh->size()); + } + return ret; + } + + __forceinline unsigned int maxGeomID() + { + unsigned int ret = 0; + for (size_t i=0; isize(); i++) { + Ty* mesh = at(i); + if (mesh == nullptr) continue; + ret = max(ret,(unsigned int)i); + } + return ret; + } + + __forceinline unsigned maxTimeStepsPerGeometry() + { + unsigned ret = 0; + for (size_t i=0; isize(); i++) { + Ty* mesh = at(i); + if (mesh == nullptr) continue; + ret = max(ret,mesh->numTimeSteps); + } + return ret; + } + + private: + Scene* scene; + bool all; + }; + + class Iterator2 + { + public: + Iterator2 () {} + + Iterator2 (Scene* scene, Geometry::GTypeMask typemask, bool mblur) + : scene(scene), typemask(typemask), mblur(mblur) {} + + __forceinline Geometry* at(const size_t i) + { + Geometry* geom = scene->geometries[i].ptr; + if (geom == nullptr) return nullptr; + if (!geom->isEnabled()) return nullptr; + if (!(geom->getTypeMask() & typemask)) return nullptr; + if ((geom->numTimeSteps != 1) != mblur) return nullptr; + return geom; + } + + __forceinline Geometry* 
operator[] (const size_t i) { + return at(i); + } + + __forceinline size_t size() const { + return scene->size(); + } + + private: + Scene* scene; + Geometry::GTypeMask typemask; + bool mblur; + }; + + public: + + /*! Scene construction */ + Scene (Device* device); + + /*! Scene destruction */ + ~Scene () noexcept; + + private: + /*! class is non-copyable */ + Scene (const Scene& other) DELETED; // do not implement + Scene& operator= (const Scene& other) DELETED; // do not implement + + public: + void createTriangleAccel(); + void createTriangleMBAccel(); + void createQuadAccel(); + void createQuadMBAccel(); + void createHairAccel(); + void createHairMBAccel(); + void createSubdivAccel(); + void createSubdivMBAccel(); + void createUserGeometryAccel(); + void createUserGeometryMBAccel(); + void createInstanceAccel(); + void createInstanceMBAccel(); + void createInstanceExpensiveAccel(); + void createInstanceExpensiveMBAccel(); + void createGridAccel(); + void createGridMBAccel(); + + /*! prints statistics about the scene */ + void printStatistics(); + + /*! clears the scene */ + void clear(); + + /*! detaches some geometry */ + void detachGeometry(size_t geomID); + + void setBuildQuality(RTCBuildQuality quality_flags); + RTCBuildQuality getBuildQuality() const; + + void setSceneFlags(RTCSceneFlags scene_flags); + RTCSceneFlags getSceneFlags() const; + + void commit (bool join); + void commit_task (); + void build () {} + + void updateInterface(); + + /* return number of geometries */ + __forceinline size_t size() const { return geometries.size(); } + + /* bind geometry to the scene */ + unsigned int bind (unsigned geomID, Ref geometry); + + /* determines if scene is modified */ + __forceinline bool isModified() const { return modified; } + + /* sets modified flag */ + __forceinline void setModified(bool f = true) { + modified = f; + } + + __forceinline bool isGeometryModified(size_t geomID) + { + Ref& g = geometries[geomID]; + if (!g) return false; + return g->getModCounter() > geometryModCounters_[geomID]; + } + + protected: + + __forceinline void checkIfModifiedAndSet () + { + if (isModified ()) return; + + auto geometryIsModified = [this](size_t geomID)->bool { + return isGeometryModified(geomID); + }; + + if (parallel_any_of (size_t(0), geometries.size (), geometryIsModified)) { + setModified (); + } + } + + public: + + /* get mesh by ID */ + __forceinline Geometry* get(size_t i) { assert(i < geometries.size()); return geometries[i].ptr; } + __forceinline const Geometry* get(size_t i) const { assert(i < geometries.size()); return geometries[i].ptr; } + + template + __forceinline Mesh* get(size_t i) { + assert(i < geometries.size()); + assert(geometries[i]->getTypeMask() & Mesh::geom_type); + return (Mesh*)geometries[i].ptr; + } + template + __forceinline const Mesh* get(size_t i) const { + assert(i < geometries.size()); + assert(geometries[i]->getTypeMask() & Mesh::geom_type); + return (Mesh*)geometries[i].ptr; + } + + template + __forceinline Mesh* getSafe(size_t i) { + assert(i < geometries.size()); + if (geometries[i] == null) return nullptr; + if (!(geometries[i]->getTypeMask() & Mesh::geom_type)) return nullptr; + else return (Mesh*) geometries[i].ptr; + } + + __forceinline Ref get_locked(size_t i) { + Lock lock(geometriesMutex); + assert(i < geometries.size()); + return geometries[i]; + } + + /* flag decoding */ + __forceinline bool isFastAccel() const { return !isCompactAccel() && !isRobustAccel(); } + __forceinline bool isCompactAccel() const { return scene_flags & 
RTC_SCENE_FLAG_COMPACT; } + __forceinline bool isRobustAccel() const { return scene_flags & RTC_SCENE_FLAG_ROBUST; } + __forceinline bool isStaticAccel() const { return !(scene_flags & RTC_SCENE_FLAG_DYNAMIC); } + __forceinline bool isDynamicAccel() const { return scene_flags & RTC_SCENE_FLAG_DYNAMIC; } + + __forceinline bool hasContextFilterFunction() const { + return scene_flags & RTC_SCENE_FLAG_CONTEXT_FILTER_FUNCTION; + } + + __forceinline bool hasGeometryFilterFunction() { + return world.numFilterFunctions != 0; + } + + __forceinline bool hasFilterFunction() { + return hasContextFilterFunction() || hasGeometryFilterFunction(); + } + + /* test if scene got already build */ + __forceinline bool isBuild() const { return is_build; } + + public: + IDPool id_pool; + vector> geometries; //!< list of all user geometries + vector geometryModCounters_; + vector vertices; + + public: + Device* device; + + /* these are to detect if we need to recreate the acceleration structures */ + bool flags_modified; + unsigned int enabled_geometry_types; + + RTCSceneFlags scene_flags; + RTCBuildQuality quality_flags; + MutexSys buildMutex; + SpinLock geometriesMutex; + bool is_build; + private: + bool modified; //!< true if scene got modified + + public: + + /*! global lock step task scheduler */ +#if defined(TASKING_INTERNAL) + MutexSys schedulerMutex; + Ref scheduler; +#elif defined(TASKING_TBB) && TASKING_TBB_USE_TASK_ISOLATION + tbb::isolated_task_group group; +#elif defined(TASKING_TBB) + tbb::task_group group; +#elif defined(TASKING_PPL) + concurrency::task_group group; +#endif + + public: + struct BuildProgressMonitorInterface : public BuildProgressMonitor { + BuildProgressMonitorInterface(Scene* scene) + : scene(scene) {} + void operator() (size_t dn) const { scene->progressMonitor(double(dn)); } + private: + Scene* scene; + }; + BuildProgressMonitorInterface progressInterface; + RTCProgressMonitorFunction progress_monitor_function; + void* progress_monitor_ptr; + std::atomic progress_monitor_counter; + void progressMonitor(double nprims); + void setProgressMonitorFunction(RTCProgressMonitorFunction func, void* ptr); + + private: + GeometryCounts world; //!< counts for geometry + + public: + + __forceinline size_t numPrimitives() const { + return world.size(); + } + + __forceinline size_t getNumPrimitives(Geometry::GTypeMask mask, bool mblur) const + { + size_t count = 0; + + if (mask & Geometry::MTY_TRIANGLE_MESH) + count += mblur ? world.numMBTriangles : world.numTriangles; + + if (mask & Geometry::MTY_QUAD_MESH) + count += mblur ? world.numMBQuads : world.numQuads; + + if (mask & Geometry::MTY_CURVE2) + count += mblur ? world.numMBLineSegments : world.numLineSegments; + + if (mask & Geometry::MTY_CURVE4) + count += mblur ? world.numMBBezierCurves : world.numBezierCurves; + + if (mask & Geometry::MTY_POINTS) + count += mblur ? world.numMBPoints : world.numPoints; + + if (mask & Geometry::MTY_SUBDIV_MESH) + count += mblur ? world.numMBSubdivPatches : world.numSubdivPatches; + + if (mask & Geometry::MTY_USER_GEOMETRY) + count += mblur ? world.numMBUserGeometries : world.numUserGeometries; + + if (mask & Geometry::MTY_INSTANCE_CHEAP) + count += mblur ? world.numMBInstancesCheap : world.numInstancesCheap; + + if (mask & Geometry::MTY_INSTANCE_EXPENSIVE) + count += mblur ? world.numMBInstancesExpensive : world.numInstancesExpensive; + + if (mask & Geometry::MTY_GRID_MESH) + count += mblur ? 
world.numMBGrids : world.numGrids; + + return count; + } + + template + __forceinline unsigned getNumTimeSteps() + { + if (!mblur) + return 1; + + Scene::Iterator iter(this); + return iter.maxTimeStepsPerGeometry(); + } + + template + __forceinline unsigned int getMaxGeomID() + { + Scene::Iterator iter(this); + return iter.maxGeomID(); + } + }; +} diff --git a/thirdparty/embree/kernels/common/scene_curves.h b/thirdparty/embree/kernels/common/scene_curves.h new file mode 100644 index 000000000000..2649ab0e3e87 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_curves.h @@ -0,0 +1,341 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "geometry.h" +#include "buffer.h" + +namespace embree +{ + /*! represents an array of bicubic bezier curves */ + struct CurveGeometry : public Geometry + { + /*! type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_CURVE4; + + public: + + /*! bezier curve construction */ + CurveGeometry (Device* device, Geometry::GType gtype); + + public: + void setMask(unsigned mask); + void setNumTimeSteps (unsigned int numTimeSteps); + void setVertexAttributeCount (unsigned int N); + void setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void commit(); + bool verify(); + void setTessellationRate(float N); + void setMaxRadiusScale(float s); + void addElementsToCount (GeometryCounts & counts) const; + + public: + + /*! returns the number of vertices */ + __forceinline size_t numVertices() const { + return vertices[0].size(); + } + + /*! returns the i'th curve */ + __forceinline const unsigned int& curve(size_t i) const { + return curves[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline Vec3ff vertex(size_t i) const { + return vertices0[i]; + } + + /*! returns i'th normal of the first time step */ + __forceinline Vec3fa normal(size_t i) const { + return normals0[i]; + } + + /*! returns i'th tangent of the first time step */ + __forceinline Vec3ff tangent(size_t i) const { + return tangents0[i]; + } + + /*! returns i'th normal derivative of the first time step */ + __forceinline Vec3fa dnormal(size_t i) const { + return dnormals0[i]; + } + + /*! returns i'th radius of the first time step */ + __forceinline float radius(size_t i) const { + return vertices0[i].w; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline Vec3ff vertex(size_t i, size_t itime) const { + return vertices[itime][i]; + } + + /*! returns i'th normal of itime'th timestep */ + __forceinline Vec3fa normal(size_t i, size_t itime) const { + return normals[itime][i]; + } + + /*! returns i'th tangent of itime'th timestep */ + __forceinline Vec3ff tangent(size_t i, size_t itime) const { + return tangents[itime][i]; + } + + /*! returns i'th normal derivative of itime'th timestep */ + __forceinline Vec3fa dnormal(size_t i, size_t itime) const { + return dnormals[itime][i]; + } + + /*! returns i'th radius of itime'th timestep */ + __forceinline float radius(size_t i, size_t itime) const { + return vertices[itime][i].w; + } + + /*! 
gathers the curve starting with i'th vertex */ + __forceinline void gather(Vec3ff& p0, Vec3ff& p1, Vec3ff& p2, Vec3ff& p3, size_t i) const + { + p0 = vertex(i+0); + p1 = vertex(i+1); + p2 = vertex(i+2); + p3 = vertex(i+3); + } + + /*! gathers the curve starting with i'th vertex of itime'th timestep */ + __forceinline void gather(Vec3ff& p0, Vec3ff& p1, Vec3ff& p2, Vec3ff& p3, size_t i, size_t itime) const + { + p0 = vertex(i+0,itime); + p1 = vertex(i+1,itime); + p2 = vertex(i+2,itime); + p3 = vertex(i+3,itime); + } + + /*! gathers the curve starting with i'th vertex */ + __forceinline void gather(Vec3ff& p0, Vec3ff& p1, Vec3ff& p2, Vec3ff& p3, Vec3fa& n0, Vec3fa& n1, Vec3fa& n2, Vec3fa& n3, size_t i) const + { + p0 = vertex(i+0); + p1 = vertex(i+1); + p2 = vertex(i+2); + p3 = vertex(i+3); + n0 = normal(i+0); + n1 = normal(i+1); + n2 = normal(i+2); + n3 = normal(i+3); + } + + /*! gathers the curve starting with i'th vertex of itime'th timestep */ + __forceinline void gather(Vec3ff& p0, Vec3ff& p1, Vec3ff& p2, Vec3ff& p3, Vec3fa& n0, Vec3fa& n1, Vec3fa& n2, Vec3fa& n3, size_t i, size_t itime) const + { + p0 = vertex(i+0,itime); + p1 = vertex(i+1,itime); + p2 = vertex(i+2,itime); + p3 = vertex(i+3,itime); + n0 = normal(i+0,itime); + n1 = normal(i+1,itime); + n2 = normal(i+2,itime); + n3 = normal(i+3,itime); + } + + /*! prefetches the curve starting with i'th vertex of itime'th timestep */ + __forceinline void prefetchL1_vertices(size_t i) const + { + prefetchL1(vertices0.getPtr(i)+0); + prefetchL1(vertices0.getPtr(i)+64); + } + + /*! prefetches the curve starting with i'th vertex of itime'th timestep */ + __forceinline void prefetchL2_vertices(size_t i) const + { + prefetchL2(vertices0.getPtr(i)+0); + prefetchL2(vertices0.getPtr(i)+64); + } + + /*! loads curve vertices for specified time */ + __forceinline void gather(Vec3ff& p0, Vec3ff& p1, Vec3ff& p2, Vec3ff& p3, size_t i, float time) const + { + float ftime; + const size_t itime = timeSegment(time, ftime); + + const float t0 = 1.0f - ftime; + const float t1 = ftime; + Vec3ff a0,a1,a2,a3; + gather(a0,a1,a2,a3,i,itime); + Vec3ff b0,b1,b2,b3; + gather(b0,b1,b2,b3,i,itime+1); + p0 = madd(Vec3ff(t0),a0,t1*b0); + p1 = madd(Vec3ff(t0),a1,t1*b1); + p2 = madd(Vec3ff(t0),a2,t1*b2); + p3 = madd(Vec3ff(t0),a3,t1*b3); + } + + /*! 
loads curve vertices for specified time */ + __forceinline void gather(Vec3ff& p0, Vec3ff& p1, Vec3ff& p2, Vec3ff& p3, Vec3fa& n0, Vec3fa& n1, Vec3fa& n2, Vec3fa& n3, size_t i, float time) const + { + float ftime; + const size_t itime = timeSegment(time, ftime); + + const float t0 = 1.0f - ftime; + const float t1 = ftime; + Vec3ff a0,a1,a2,a3; Vec3fa an0,an1,an2,an3; + gather(a0,a1,a2,a3,an0,an1,an2,an3,i,itime); + Vec3ff b0,b1,b2,b3; Vec3fa bn0,bn1,bn2,bn3; + gather(b0,b1,b2,b3,bn0,bn1,bn2,bn3,i,itime+1); + p0 = madd(Vec3ff(t0),a0,t1*b0); + p1 = madd(Vec3ff(t0),a1,t1*b1); + p2 = madd(Vec3ff(t0),a2,t1*b2); + p3 = madd(Vec3ff(t0),a3,t1*b3); + n0 = madd(Vec3ff(t0),an0,t1*bn0); + n1 = madd(Vec3ff(t0),an1,t1*bn1); + n2 = madd(Vec3ff(t0),an2,t1*bn2); + n3 = madd(Vec3ff(t0),an3,t1*bn3); + } + + template + __forceinline TensorLinearCubicBezierSurface3fa getNormalOrientedCurve(IntersectContext* context, const Vec3fa& ray_org, const unsigned int primID, const size_t itime) const + { + Vec3ff v0,v1,v2,v3; Vec3fa n0,n1,n2,n3; + unsigned int vertexID = curve(primID); + gather(v0,v1,v2,v3,n0,n1,n2,n3,vertexID,itime); + SourceCurve3ff ccurve(v0,v1,v2,v3); + SourceCurve3fa ncurve(n0,n1,n2,n3); + ccurve = enlargeRadiusToMinWidth(context,this,ray_org,ccurve); + return TensorLinearCubicBezierSurface3fa::fromCenterAndNormalCurve(ccurve,ncurve); + } + + template + __forceinline TensorLinearCubicBezierSurface3fa getNormalOrientedCurve(IntersectContext* context, const Vec3fa& ray_org, const unsigned int primID, const float time) const + { + float ftime; + const size_t itime = timeSegment(time, ftime); + const TensorLinearCubicBezierSurface3fa curve0 = getNormalOrientedCurve(context,ray_org,primID,itime+0); + const TensorLinearCubicBezierSurface3fa curve1 = getNormalOrientedCurve(context,ray_org,primID,itime+1); + return clerp(curve0,curve1,ftime); + } + + /*! gathers the hermite curve starting with i'th vertex */ + __forceinline void gather_hermite(Vec3ff& p0, Vec3ff& t0, Vec3ff& p1, Vec3ff& t1, size_t i) const + { + p0 = vertex (i+0); + p1 = vertex (i+1); + t0 = tangent(i+0); + t1 = tangent(i+1); + } + + /*! gathers the hermite curve starting with i'th vertex of itime'th timestep */ + __forceinline void gather_hermite(Vec3ff& p0, Vec3ff& t0, Vec3ff& p1, Vec3ff& t1, size_t i, size_t itime) const + { + p0 = vertex (i+0,itime); + p1 = vertex (i+1,itime); + t0 = tangent(i+0,itime); + t1 = tangent(i+1,itime); + } + + /*! loads curve vertices for specified time */ + __forceinline void gather_hermite(Vec3ff& p0, Vec3ff& t0, Vec3ff& p1, Vec3ff& t1, size_t i, float time) const + { + float ftime; + const size_t itime = timeSegment(time, ftime); + const float f0 = 1.0f - ftime, f1 = ftime; + Vec3ff ap0,at0,ap1,at1; + gather_hermite(ap0,at0,ap1,at1,i,itime); + Vec3ff bp0,bt0,bp1,bt1; + gather_hermite(bp0,bt0,bp1,bt1,i,itime+1); + p0 = madd(Vec3ff(f0),ap0,f1*bp0); + t0 = madd(Vec3ff(f0),at0,f1*bt0); + p1 = madd(Vec3ff(f0),ap1,f1*bp1); + t1 = madd(Vec3ff(f0),at1,f1*bt1); + } + + /*! gathers the hermite curve starting with i'th vertex */ + __forceinline void gather_hermite(Vec3ff& p0, Vec3ff& t0, Vec3fa& n0, Vec3fa& dn0, Vec3ff& p1, Vec3ff& t1, Vec3fa& n1, Vec3fa& dn1, size_t i) const + { + p0 = vertex (i+0); + p1 = vertex (i+1); + t0 = tangent(i+0); + t1 = tangent(i+1); + n0 = normal(i+0); + n1 = normal(i+1); + dn0 = dnormal(i+0); + dn1 = dnormal(i+1); + } + + /*! 
gathers the hermite curve starting with i'th vertex of itime'th timestep */ + __forceinline void gather_hermite(Vec3ff& p0, Vec3ff& t0, Vec3fa& n0, Vec3fa& dn0, Vec3ff& p1, Vec3ff& t1, Vec3fa& n1, Vec3fa& dn1, size_t i, size_t itime) const + { + p0 = vertex (i+0,itime); + p1 = vertex (i+1,itime); + t0 = tangent(i+0,itime); + t1 = tangent(i+1,itime); + n0 = normal(i+0,itime); + n1 = normal(i+1,itime); + dn0 = dnormal(i+0,itime); + dn1 = dnormal(i+1,itime); + } + + /*! loads curve vertices for specified time */ + __forceinline void gather_hermite(Vec3ff& p0, Vec3fa& t0, Vec3fa& n0, Vec3fa& dn0, Vec3ff& p1, Vec3fa& t1, Vec3fa& n1, Vec3fa& dn1, size_t i, float time) const + { + float ftime; + const size_t itime = timeSegment(time, ftime); + const float f0 = 1.0f - ftime, f1 = ftime; + Vec3ff ap0,at0,ap1,at1; Vec3fa an0,adn0,an1,adn1; + gather_hermite(ap0,at0,an0,adn0,ap1,at1,an1,adn1,i,itime); + Vec3ff bp0,bt0,bp1,bt1; Vec3fa bn0,bdn0,bn1,bdn1; + gather_hermite(bp0,bt0,bn0,bdn0,bp1,bt1,bn1,bdn1,i,itime+1); + p0 = madd(Vec3ff(f0),ap0,f1*bp0); + t0 = madd(Vec3ff(f0),at0,f1*bt0); + n0 = madd(Vec3ff(f0),an0,f1*bn0); + dn0= madd(Vec3ff(f0),adn0,f1*bdn0); + p1 = madd(Vec3ff(f0),ap1,f1*bp1); + t1 = madd(Vec3ff(f0),at1,f1*bt1); + n1 = madd(Vec3ff(f0),an1,f1*bn1); + dn1= madd(Vec3ff(f0),adn1,f1*bdn1); + } + + template + __forceinline TensorLinearCubicBezierSurface3fa getNormalOrientedHermiteCurve(IntersectContext* context, const Vec3fa& ray_org, const unsigned int primID, const size_t itime) const + { + Vec3ff v0,t0,v1,t1; Vec3fa n0,dn0,n1,dn1; + unsigned int vertexID = curve(primID); + gather_hermite(v0,t0,n0,dn0,v1,t1,n1,dn1,vertexID,itime); + + SourceCurve3ff ccurve(v0,t0,v1,t1); + SourceCurve3fa ncurve(n0,dn0,n1,dn1); + ccurve = enlargeRadiusToMinWidth(context,this,ray_org,ccurve); + return TensorLinearCubicBezierSurface3fa::fromCenterAndNormalCurve(ccurve,ncurve); + } + + template + __forceinline TensorLinearCubicBezierSurface3fa getNormalOrientedHermiteCurve(IntersectContext* context, const Vec3fa& ray_org, const unsigned int primID, const float time) const + { + float ftime; + const size_t itime = timeSegment(time, ftime); + const TensorLinearCubicBezierSurface3fa curve0 = getNormalOrientedHermiteCurve(context, ray_org, primID,itime+0); + const TensorLinearCubicBezierSurface3fa curve1 = getNormalOrientedHermiteCurve(context, ray_org, primID,itime+1); + return clerp(curve0,curve1,ftime); + } + + private: + void resizeBuffers(unsigned int numSteps); + + public: + BufferView curves; //!< array of curve indices + BufferView vertices0; //!< fast access to first vertex buffer + BufferView normals0; //!< fast access to first normal buffer + BufferView tangents0; //!< fast access to first tangent buffer + BufferView dnormals0; //!< fast access to first normal derivative buffer + vector> vertices; //!< vertex array for each timestep + vector> normals; //!< normal array for each timestep + vector> tangents; //!< tangent array for each timestep + vector> dnormals; //!< normal derivative array for each timestep + BufferView flags; //!< start, end flag per segment + vector> vertexAttribs; //!< user buffers + int tessellationRate; //!< tessellation rate for flat curve + float maxRadiusScale = 1.0; //!< maximal min-width scaling of curve radii + }; + + DECLARE_ISA_FUNCTION(CurveGeometry*, createCurves, Device* COMMA Geometry::GType); +} diff --git a/thirdparty/embree/kernels/common/scene_grid_mesh.h b/thirdparty/embree/kernels/common/scene_grid_mesh.h new file mode 100644 index 000000000000..c08658466a26 --- 
/dev/null +++ b/thirdparty/embree/kernels/common/scene_grid_mesh.h @@ -0,0 +1,215 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "geometry.h" +#include "buffer.h" + +namespace embree +{ + /*! Grid Mesh */ + struct GridMesh : public Geometry + { + /*! type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_GRID_MESH; + + /*! grid */ + struct Grid + { + unsigned int startVtxID; + unsigned int lineVtxOffset; + unsigned short resX,resY; + + /* border flags due to 3x3 vertex pattern */ + __forceinline unsigned int get3x3FlagsX(const unsigned int x) const + { + return (x + 2 >= (unsigned int)resX) ? (1<<15) : 0; + } + + /* border flags due to 3x3 vertex pattern */ + __forceinline unsigned int get3x3FlagsY(const unsigned int y) const + { + return (y + 2 >= (unsigned int)resY) ? (1<<15) : 0; + } + + /*! outputs grid structure */ + __forceinline friend embree_ostream operator<<(embree_ostream cout, const Grid& t) { + return cout << "Grid { startVtxID " << t.startVtxID << ", lineVtxOffset " << t.lineVtxOffset << ", resX " << t.resX << ", resY " << t.resY << " }"; + } + }; + + public: + + /*! grid mesh construction */ + GridMesh (Device* device); + + /* geometry interface */ + public: + void setMask(unsigned mask); + void setNumTimeSteps (unsigned int numTimeSteps); + void setVertexAttributeCount (unsigned int N); + void setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void commit(); + bool verify(); + void interpolate(const RTCInterpolateArguments* const args); + void addElementsToCount (GeometryCounts & counts) const; + + __forceinline unsigned int getNumSubGrids(const size_t gridID) + { + const Grid &g = grid(gridID); + return max((unsigned int)1,((unsigned int)g.resX >> 1) * ((unsigned int)g.resY >> 1)); + } + + /*! get fast access to first vertex buffer */ + __forceinline float * getCompactVertexArray () const { + return (float*) vertices0.getPtr(); + } + + public: + + /*! returns number of vertices */ + __forceinline size_t numVertices() const { + return vertices[0].size(); + } + + /*! returns i'th grid*/ + __forceinline const Grid& grid(size_t i) const { + return grids[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const Vec3fa vertex(size_t i) const { // FIXME: check if this does a unaligned load + return vertices0[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const char* vertexPtr(size_t i) const { + return vertices0.getPtr(i); + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const Vec3fa vertex(size_t i, size_t itime) const { + return vertices[itime][i]; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const char* vertexPtr(size_t i, size_t itime) const { + return vertices[itime].getPtr(i); + } + + /*! returns i'th vertex of the first timestep */ + __forceinline size_t grid_vertex_index(const Grid& g, size_t x, size_t y) const { + assert(x < (size_t)g.resX); + assert(y < (size_t)g.resY); + return g.startVtxID + x + y * g.lineVtxOffset; + } + + /*! returns i'th vertex of the first timestep */ + __forceinline const Vec3fa grid_vertex(const Grid& g, size_t x, size_t y) const { + const size_t index = grid_vertex_index(g,x,y); + return vertex(index); + } + + /*! 
returns i'th vertex of the itime'th timestep */ + __forceinline const Vec3fa grid_vertex(const Grid& g, size_t x, size_t y, size_t itime) const { + const size_t index = grid_vertex_index(g,x,y); + return vertex(index,itime); + } + + /*! calculates the build bounds of the i'th primitive, if it's valid */ + __forceinline bool buildBounds(const Grid& g, size_t sx, size_t sy, BBox3fa& bbox) const + { + BBox3fa b(empty); + for (size_t t=0; t& itime_range) const + { + if (unlikely(gridID >= grids.size())) return false; + const Grid &g = grid(gridID); + if (unlikely(g.startVtxID + 0 >= vertices0.size())) return false; + if (unlikely(g.startVtxID + (g.resY-1)*g.lineVtxOffset + g.resX-1 >= vertices0.size())) return false; + + for (size_t y=0;y grids; //!< array of triangles + BufferView vertices0; //!< fast access to first vertex buffer + vector> vertices; //!< vertex array for each timestep + vector vertexAttribs; //!< vertex attributes + }; + + namespace isa + { + struct GridMeshISA : public GridMesh + { + GridMeshISA (Device* device) + : GridMesh(device) {} + }; + } + + DECLARE_ISA_FUNCTION(GridMesh*, createGridMesh, Device*); +} diff --git a/thirdparty/embree/kernels/common/scene_instance.h b/thirdparty/embree/kernels/common/scene_instance.h new file mode 100644 index 000000000000..7ff82a4fb8dc --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_instance.h @@ -0,0 +1,272 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "geometry.h" +#include "accel.h" + +namespace embree +{ + struct MotionDerivativeCoefficients; + + /*! Instanced acceleration structure */ + struct Instance : public Geometry + { + ALIGNED_STRUCT_(16); + static const Geometry::GTypeMask geom_type = Geometry::MTY_INSTANCE; + + public: + Instance (Device* device, Accel* object = nullptr, unsigned int numTimeSteps = 1); + ~Instance(); + + private: + Instance (const Instance& other) DELETED; // do not implement + Instance& operator= (const Instance& other) DELETED; // do not implement + + private: + LBBox3fa nonlinearBounds(const BBox1f& time_range_in, + const BBox1f& geom_time_range, + float geom_time_segments) const; + + BBox3fa boundSegment(size_t itime, + BBox3fa const& obbox0, BBox3fa const& obbox1, + BBox3fa const& bbox0, BBox3fa const& bbox1, + float t_min, float t_max) const; + + /* calculates the (correct) interpolated bounds */ + __forceinline BBox3fa bounds(size_t itime0, size_t itime1, float f) const + { + if (unlikely(gsubtype == GTY_SUBTYPE_INSTANCE_QUATERNION)) + return xfmBounds(slerp(local2world[itime0], local2world[itime1], f), + lerp(getObjectBounds(itime0), getObjectBounds(itime1), f)); + return xfmBounds(lerp(local2world[itime0], local2world[itime1], f), + lerp(getObjectBounds(itime0), getObjectBounds(itime1), f)); + } + + public: + virtual void setNumTimeSteps (unsigned int numTimeSteps) override; + virtual void setInstancedScene(const Ref& scene) override; + virtual void setTransform(const AffineSpace3fa& local2world, unsigned int timeStep) override; + virtual void setQuaternionDecomposition(const AffineSpace3ff& qd, unsigned int timeStep) override; + virtual AffineSpace3fa getTransform(float time) override; + virtual void setMask (unsigned mask) override; + virtual void build() {} + virtual void addElementsToCount (GeometryCounts & counts) const override; + virtual void commit() override; + + public: + + /*! 
calculates the bounds of instance */ + __forceinline BBox3fa bounds(size_t i) const { + assert(i == 0); + if (unlikely(gsubtype == GTY_SUBTYPE_INSTANCE_QUATERNION)) + return xfmBounds(quaternionDecompositionToAffineSpace(local2world[0]),object->bounds.bounds()); + return xfmBounds(local2world[0],object->bounds.bounds()); + } + + /*! gets the bounds of the instanced scene */ + __forceinline BBox3fa getObjectBounds(size_t itime) const { + return object->getBounds(timeStep(itime)); + } + + /*! calculates the bounds of instance */ + __forceinline BBox3fa bounds(size_t i, size_t itime) const { + assert(i == 0); + if (unlikely(gsubtype == GTY_SUBTYPE_INSTANCE_QUATERNION)) + return xfmBounds(quaternionDecompositionToAffineSpace(local2world[itime]),getObjectBounds(itime)); + return xfmBounds(local2world[itime],getObjectBounds(itime)); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(size_t i, const BBox1f& dt) const { + assert(i == 0); + LBBox3fa lbbox = nonlinearBounds(dt, time_range, fnumTimeSegments); + return lbbox; + } + + /*! calculates the build bounds of the i'th item, if it's valid */ + __forceinline bool buildBounds(size_t i, BBox3fa* bbox = nullptr) const + { + assert(i==0); + const BBox3fa b = bounds(i); + if (bbox) *bbox = b; + return isvalid(b); + } + + /*! calculates the build bounds of the i'th item at the itime'th time segment, if it's valid */ + __forceinline bool buildBounds(size_t i, size_t itime, BBox3fa& bbox) const + { + assert(i==0); + const LBBox3fa bounds = linearBounds(i,itime); + bbox = bounds.bounds (); + return isvalid(bounds); + } + + /* gets version info of topology */ + unsigned int getTopologyVersion() const { + return numPrimitives; + } + + /* returns true if topology changed */ + bool topologyChanged(unsigned int otherVersion) const { + return numPrimitives != otherVersion; + } + + /*! 
check if the i'th primitive is valid between the specified time range */ + __forceinline bool valid(size_t i, const range& itime_range) const + { + assert(i == 0); + for (size_t itime = itime_range.begin(); itime <= itime_range.end(); itime++) + if (!isvalid(bounds(i,itime))) return false; + + return true; + } + + __forceinline AffineSpace3fa getLocal2World() const + { + if (unlikely(gsubtype == GTY_SUBTYPE_INSTANCE_QUATERNION)) + return quaternionDecompositionToAffineSpace(local2world[0]); + return local2world[0]; + } + + __forceinline AffineSpace3fa getLocal2World(float t) const + { + float ftime; const unsigned int itime = timeSegment(t, ftime); + if (unlikely(gsubtype == GTY_SUBTYPE_INSTANCE_QUATERNION)) + return slerp(local2world[itime+0],local2world[itime+1],ftime); + return lerp(local2world[itime+0],local2world[itime+1],ftime); + } + + __forceinline AffineSpace3fa getWorld2Local() const { + return world2local0; + } + + __forceinline AffineSpace3fa getWorld2Local(float t) const { + return rcp(getLocal2World(t)); + } + + template + __forceinline AffineSpace3vf getWorld2Local(const vbool& valid, const vfloat& t) const + { + if (unlikely(gsubtype == GTY_SUBTYPE_INSTANCE_QUATERNION)) + return getWorld2LocalSlerp(valid, t); + return getWorld2LocalLerp(valid, t); + } + + private: + + template + __forceinline AffineSpace3vf getWorld2LocalSlerp(const vbool& valid, const vfloat& t) const + { + vfloat ftime; + const vint itime_k = timeSegment(t, ftime); + assert(any(valid)); + const size_t index = bsf(movemask(valid)); + const int itime = itime_k[index]; + if (likely(all(valid, itime_k == vint(itime)))) { + return rcp(slerp(AffineSpace3vff(local2world[itime+0]), + AffineSpace3vff(local2world[itime+1]), + ftime)); + } + else { + AffineSpace3vff space0,space1; + vbool valid1 = valid; + while (any(valid1)) { + vbool valid2; + const int itime = next_unique(valid1, itime_k, valid2); + space0 = select(valid2, AffineSpace3vff(local2world[itime+0]), space0); + space1 = select(valid2, AffineSpace3vff(local2world[itime+1]), space1); + } + return rcp(slerp(space0, space1, ftime)); + } + } + + template + __forceinline AffineSpace3vf getWorld2LocalLerp(const vbool& valid, const vfloat& t) const + { + vfloat ftime; + const vint itime_k = timeSegment(t, ftime); + assert(any(valid)); + const size_t index = bsf(movemask(valid)); + const int itime = itime_k[index]; + if (likely(all(valid, itime_k == vint(itime)))) { + return rcp(lerp(AffineSpace3vf((AffineSpace3fa)local2world[itime+0]), + AffineSpace3vf((AffineSpace3fa)local2world[itime+1]), + ftime)); + } else { + AffineSpace3vf space0,space1; + vbool valid1 = valid; + while (any(valid1)) { + vbool valid2; + const int itime = next_unique(valid1, itime_k, valid2); + space0 = select(valid2, AffineSpace3vf((AffineSpace3fa)local2world[itime+0]), space0); + space1 = select(valid2, AffineSpace3vf((AffineSpace3fa)local2world[itime+1]), space1); + } + return rcp(lerp(space0, space1, ftime)); + } + } + + public: + Accel* object; //!< pointer to instanced acceleration structure + AffineSpace3ff* local2world; //!< transformation from local space to world space for each timestep (either normal matrix or quaternion decomposition) + AffineSpace3fa world2local0; //!< transformation from world space to local space for timestep 0 + }; + + namespace isa + { + struct InstanceISA : public Instance + { + InstanceISA (Device* device) + : Instance(device) {} + + PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const + { + assert(r.begin() == 
0); + assert(r.end() == 1); + + PrimInfo pinfo(empty); + BBox3fa b = empty; + if (!buildBounds(0,&b)) return pinfo; + // const BBox3fa b = bounds(0); + // if (!isvalid(b)) return pinfo; + + const PrimRef prim(b,geomID,unsigned(0)); + pinfo.add_center2(prim); + prims[k++] = prim; + return pinfo; + } + + PrimInfo createPrimRefArrayMB(mvector& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const + { + assert(r.begin() == 0); + assert(r.end() == 1); + + PrimInfo pinfo(empty); + BBox3fa b = empty; + if (!buildBounds(0,&b)) return pinfo; + // if (!valid(0,range(itime))) return pinfo; + // const PrimRef prim(linearBounds(0,itime).bounds(),geomID,unsigned(0)); + const PrimRef prim(b,geomID,unsigned(0)); + pinfo.add_center2(prim); + prims[k++] = prim; + return pinfo; + } + + PrimInfoMB createPrimRefMBArray(mvector& prims, const BBox1f& t0t1, const range& r, size_t k, unsigned int geomID) const + { + assert(r.begin() == 0); + assert(r.end() == 1); + + PrimInfoMB pinfo(empty); + if (!valid(0, timeSegmentRange(t0t1))) return pinfo; + const PrimRefMB prim(linearBounds(0,t0t1),this->numTimeSegments(),this->time_range,this->numTimeSegments(),geomID,unsigned(0)); + pinfo.add_primref(prim); + prims[k++] = prim; + return pinfo; + } + }; + } + + DECLARE_ISA_FUNCTION(Instance*, createInstance, Device*); +} diff --git a/thirdparty/embree/kernels/common/scene_line_segments.h b/thirdparty/embree/kernels/common/scene_line_segments.h new file mode 100644 index 000000000000..c0f9ee8f7796 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_line_segments.h @@ -0,0 +1,307 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" +#include "geometry.h" +#include "buffer.h" + +namespace embree +{ + /*! represents an array of line segments */ + struct LineSegments : public Geometry + { + /*! type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_CURVE2; + + public: + + /*! line segments construction */ + LineSegments (Device* device, Geometry::GType gtype); + + public: + void setMask (unsigned mask); + void setNumTimeSteps (unsigned int numTimeSteps); + void setVertexAttributeCount (unsigned int N); + void setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void commit(); + bool verify (); + void interpolate(const RTCInterpolateArguments* const args); + void setTessellationRate(float N); + void setMaxRadiusScale(float s); + void addElementsToCount (GeometryCounts & counts) const; + + public: + + /*! returns the number of vertices */ + __forceinline size_t numVertices() const { + return vertices[0].size(); + } + + /*! returns the i'th segment */ + __forceinline const unsigned int& segment(size_t i) const { + return segments[i]; + } + + /*! returns the segment to the left of the i'th segment */ + __forceinline bool segmentLeftExists(size_t i) const { + assert (flags); + return (flags[i] & RTC_CURVE_FLAG_NEIGHBOR_LEFT) != 0; + } + + /*! returns the segment to the right of the i'th segment */ + __forceinline bool segmentRightExists(size_t i) const { + assert (flags); + return (flags[i] & RTC_CURVE_FLAG_NEIGHBOR_RIGHT) != 0; + } + + /*! returns i'th vertex of the first time step */ + __forceinline Vec3ff vertex(size_t i) const { + return vertices0[i]; + } + + /*! 
returns i'th vertex of the first time step */ + __forceinline const char* vertexPtr(size_t i) const { + return vertices0.getPtr(i); + } + + /*! returns i'th normal of the first time step */ + __forceinline Vec3fa normal(size_t i) const { + return normals0[i]; + } + + /*! returns i'th radius of the first time step */ + __forceinline float radius(size_t i) const { + return vertices0[i].w; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline Vec3ff vertex(size_t i, size_t itime) const { + return vertices[itime][i]; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const char* vertexPtr(size_t i, size_t itime) const { + return vertices[itime].getPtr(i); + } + + /*! returns i'th normal of itime'th timestep */ + __forceinline Vec3fa normal(size_t i, size_t itime) const { + return normals[itime][i]; + } + + /*! returns i'th radius of itime'th timestep */ + __forceinline float radius(size_t i, size_t itime) const { + return vertices[itime][i].w; + } + + /*! calculates bounding box of i'th line segment */ + __forceinline BBox3fa bounds(const Vec3ff& v0, const Vec3ff& v1) const + { + const BBox3ff b = merge(BBox3ff(v0),BBox3ff(v1)); + return enlarge((BBox3fa)b,maxRadiusScale*Vec3fa(max(v0.w,v1.w))); + } + + /*! calculates bounding box of i'th line segment */ + __forceinline BBox3fa bounds(size_t i) const + { + const unsigned int index = segment(i); + const Vec3ff v0 = vertex(index+0); + const Vec3ff v1 = vertex(index+1); + return bounds(v0,v1); + } + + /*! calculates bounding box of i'th line segment for the itime'th time step */ + __forceinline BBox3fa bounds(size_t i, size_t itime) const + { + const unsigned int index = segment(i); + const Vec3ff v0 = vertex(index+0,itime); + const Vec3ff v1 = vertex(index+1,itime); + return bounds(v0,v1); + } + + /*! calculates bounding box of i'th line segment */ + __forceinline BBox3fa bounds(const LinearSpace3fa& space, size_t i) const + { + const unsigned int index = segment(i); + const Vec3ff v0 = vertex(index+0); + const Vec3ff v1 = vertex(index+1); + const Vec3ff w0(xfmVector(space,(Vec3fa)v0),v0.w); + const Vec3ff w1(xfmVector(space,(Vec3fa)v1),v1.w); + return bounds(w0,w1); + } + + /*! calculates bounding box of i'th line segment for the itime'th time step */ + __forceinline BBox3fa bounds(const LinearSpace3fa& space, size_t i, size_t itime) const + { + const unsigned int index = segment(i); + const Vec3ff v0 = vertex(index+0,itime); + const Vec3ff v1 = vertex(index+1,itime); + const Vec3ff w0(xfmVector(space,(Vec3fa)v0),v0.w); + const Vec3ff w1(xfmVector(space,(Vec3fa)v1),v1.w); + return bounds(w0,w1); + } + + /*! check if the i'th primitive is valid at the itime'th timestep */ + __forceinline bool valid(size_t i, size_t itime) const { + return valid(i, make_range(itime, itime)); + } + + /*! check if the i'th primitive is valid between the specified time range */ + __forceinline bool valid(size_t i, const range& itime_range) const + { + const unsigned int index = segment(i); + if (index+1 >= numVertices()) return false; + + for (size_t itime = itime_range.begin(); itime <= itime_range.end(); itime++) + { + const Vec3ff v0 = vertex(index+0,itime); if (unlikely(!isvalid4(v0))) return false; + const Vec3ff v1 = vertex(index+1,itime); if (unlikely(!isvalid4(v1))) return false; + if (min(v0.w,v1.w) < 0.0f) return false; + } + return true; + } + + /*! 
calculates the linear bounds of the i'th primitive at the itimeGlobal'th time segment */ + __forceinline LBBox3fa linearBounds(size_t i, size_t itime) const { + return LBBox3fa(bounds(i,itime+0),bounds(i,itime+1)); + } + + /*! calculates the build bounds of the i'th primitive, if it's valid */ + __forceinline bool buildBounds(size_t i, BBox3fa* bbox) const + { + if (!valid(i,0)) return false; + *bbox = bounds(i); + return true; + } + + /*! calculates the build bounds of the i'th primitive at the itime'th time segment, if it's valid */ + __forceinline bool buildBounds(size_t i, size_t itime, BBox3fa& bbox) const + { + if (!valid(i,itime+0) || !valid(i,itime+1)) return false; + bbox = bounds(i,itime); // use bounds of first time step in builder + return true; + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(size_t primID, const BBox1f& dt) const { + return LBBox3fa([&] (size_t itime) { return bounds(primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(const LinearSpace3fa& space, size_t primID, const BBox1f& dt) const { + return LBBox3fa([&] (size_t itime) { return bounds(space, primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline bool linearBounds(size_t i, const BBox1f& time_range, LBBox3fa& bbox) const + { + if (!valid(i, timeSegmentRange(time_range))) return false; + bbox = linearBounds(i, time_range); + return true; + } + + /*! get fast access to first vertex buffer */ + __forceinline float * getCompactVertexArray () const { + return (float*) vertices0.getPtr(); + } + + public: + BufferView segments; //!< array of line segment indices + BufferView vertices0; //!< fast access to first vertex buffer + BufferView normals0; //!< fast access to first normal buffer + BufferView flags; //!< start, end flag per segment + vector> vertices; //!< vertex array for each timestep + vector> normals; //!< normal array for each timestep + vector> vertexAttribs; //!< user buffers + int tessellationRate; //!< tessellation rate for bezier curve + float maxRadiusScale = 1.0; //!< maximal min-width scaling of curve radii + }; + + namespace isa + { + struct LineSegmentsISA : public LineSegments + { + LineSegmentsISA (Device* device, Geometry::GType gtype) + : LineSegments(device,gtype) {} + + Vec3fa computeDirection(unsigned int primID) const + { + const unsigned vtxID = segment(primID); + const Vec3fa v0 = vertex(vtxID+0); + const Vec3fa v1 = vertex(vtxID+1); + return v1-v0; + } + + Vec3fa computeDirection(unsigned int primID, size_t time) const + { + const unsigned vtxID = segment(primID); + const Vec3fa v0 = vertex(vtxID+0,time); + const Vec3fa v1 = vertex(vtxID+1,time); + return v1-v0; + } + + PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, const BBox1f& t0t1, const range& r, size_t k, unsigned int geomID) const + { + PrimInfoMB pinfo(empty); + for (size_t j=r.begin(); jnumTimeSegments(),this->time_range,this->numTimeSegments(),geomID,unsigned(j)); + pinfo.add_primref(prim); + prims[k++] = prim; + } + return pinfo; + } + + BBox3fa 
vbounds(size_t i) const { + return bounds(i); + } + + BBox3fa vbounds(const LinearSpace3fa& space, size_t i) const { + return bounds(space,i); + } + + LBBox3fa vlinearBounds(size_t primID, const BBox1f& time_range) const { + return linearBounds(primID,time_range); + } + + LBBox3fa vlinearBounds(const LinearSpace3fa& space, size_t primID, const BBox1f& time_range) const { + return linearBounds(space,primID,time_range); + } + }; + } + + DECLARE_ISA_FUNCTION(LineSegments*, createLineSegments, Device* COMMA Geometry::GType); +} diff --git a/thirdparty/embree/kernels/common/scene_points.h b/thirdparty/embree/kernels/common/scene_points.h new file mode 100644 index 000000000000..1d39ed07baa2 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_points.h @@ -0,0 +1,282 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "buffer.h" +#include "default.h" +#include "geometry.h" + +namespace embree +{ + /*! represents an array of points */ + struct Points : public Geometry + { + /*! type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_POINTS; + + public: + /*! line segments construction */ + Points(Device* device, Geometry::GType gtype); + + public: + void setMask(unsigned mask); + void setNumTimeSteps(unsigned int numTimeSteps); + void setVertexAttributeCount(unsigned int N); + void setBuffer(RTCBufferType type, + unsigned int slot, + RTCFormat format, + const Ref& buffer, + size_t offset, + size_t stride, + unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void commit(); + bool verify(); + void setMaxRadiusScale(float s); + void addElementsToCount (GeometryCounts & counts) const; + + public: + /*! returns the number of vertices */ + __forceinline size_t numVertices() const { + return vertices[0].size(); + } + + /*! returns i'th vertex of the first time step */ + __forceinline Vec3ff vertex(size_t i) const { + return vertices0[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const char* vertexPtr(size_t i) const { + return vertices0.getPtr(i); + } + + /*! returns i'th normal of the first time step */ + __forceinline Vec3fa normal(size_t i) const { + return normals0[i]; + } + + /*! returns i'th radius of the first time step */ + __forceinline float radius(size_t i) const { + return vertices0[i].w; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline Vec3ff vertex(size_t i, size_t itime) const { + return vertices[itime][i]; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const char* vertexPtr(size_t i, size_t itime) const { + return vertices[itime].getPtr(i); + } + + /*! returns i'th normal of itime'th timestep */ + __forceinline Vec3fa normal(size_t i, size_t itime) const { + return normals[itime][i]; + } + + /*! returns i'th radius of itime'th timestep */ + __forceinline float radius(size_t i, size_t itime) const { + return vertices[itime][i].w; + } + + /*! calculates bounding box of i'th line segment */ + __forceinline BBox3fa bounds(const Vec3ff& v0) const { + return enlarge(BBox3fa(v0), maxRadiusScale*Vec3fa(v0.w)); + } + + /*! calculates bounding box of i'th line segment */ + __forceinline BBox3fa bounds(size_t i) const + { + const Vec3ff v0 = vertex(i); + return bounds(v0); + } + + /*! 
calculates bounding box of i'th point for the itime'th time step */ + __forceinline BBox3fa bounds(size_t i, size_t itime) const + { + const Vec3ff v0 = vertex(i, itime); + return bounds(v0); + } + + /*! calculates bounding box of i'th point */ + __forceinline BBox3fa bounds(const LinearSpace3fa& space, size_t i) const + { + const Vec3ff v0 = vertex(i); + const Vec3ff w0(xfmVector(space, (Vec3fa)v0), v0.w); + return bounds(w0); + } + + /*! calculates bounding box of i'th point for the itime'th time step */ + __forceinline BBox3fa bounds(const LinearSpace3fa& space, size_t i, size_t itime) const + { + const Vec3ff v0 = vertex(i, itime); + const Vec3ff w0(xfmVector(space, (Vec3fa)v0), v0.w); + return bounds(w0); + } + + /*! check if the i'th primitive is valid at the itime'th timestep */ + __forceinline bool valid(size_t i, size_t itime) const { + return valid(i, make_range(itime, itime)); + } + + /*! check if the i'th primitive is valid between the specified time range */ + __forceinline bool valid(size_t i, const range& itime_range) const + { + const unsigned int index = (unsigned int)i; + if (index >= numVertices()) + return false; + + for (size_t itime = itime_range.begin(); itime <= itime_range.end(); itime++) { + const Vec3ff v0 = vertex(index + 0, itime); + if (unlikely(!isvalid4(v0))) + return false; + if (v0.w < 0.0f) + return false; + } + return true; + } + + /*! calculates the linear bounds of the i'th primitive at the itimeGlobal'th time segment */ + __forceinline LBBox3fa linearBounds(size_t i, size_t itime) const { + return LBBox3fa(bounds(i, itime + 0), bounds(i, itime + 1)); + } + + /*! calculates the build bounds of the i'th primitive, if it's valid */ + __forceinline bool buildBounds(size_t i, BBox3fa* bbox) const + { + if (!valid(i, 0)) + return false; + *bbox = bounds(i); + return true; + } + + /*! calculates the build bounds of the i'th primitive at the itime'th time segment, if it's valid */ + __forceinline bool buildBounds(size_t i, size_t itime, BBox3fa& bbox) const + { + if (!valid(i, itime + 0) || !valid(i, itime + 1)) + return false; + bbox = bounds(i, itime); // use bounds of first time step in builder + return true; + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(size_t primID, const BBox1f& dt) const { + return LBBox3fa([&](size_t itime) { return bounds(primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(const LinearSpace3fa& space, size_t primID, const BBox1f& dt) const { + return LBBox3fa([&](size_t itime) { return bounds(space, primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline bool linearBounds(size_t i, const BBox1f& time_range, LBBox3fa& bbox) const + { + if (!valid(i, timeSegmentRange(time_range))) return false; + bbox = linearBounds(i, time_range); + return true; + } + + /*!
get fast access to first vertex buffer */ + __forceinline float * getCompactVertexArray () const { + return (float*) vertices0.getPtr(); + } + + public: + BufferView vertices0; //!< fast access to first vertex buffer + BufferView normals0; //!< fast access to first normal buffer + vector> vertices; //!< vertex array for each timestep + vector> normals; //!< normal array for each timestep + vector> vertexAttribs; //!< user buffers + float maxRadiusScale = 1.0; //!< maximal min-width scaling of curve radii + }; + + namespace isa + { + struct PointsISA : public Points + { + PointsISA(Device* device, Geometry::GType gtype) : Points(device, gtype) {} + + Vec3fa computeDirection(unsigned int primID) const + { + return Vec3fa(1, 0, 0); + } + + Vec3fa computeDirection(unsigned int primID, size_t time) const + { + return Vec3fa(1, 0, 0); + } + + PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j = r.begin(); j < r.end(); j++) { + BBox3fa bounds = empty; + if (!buildBounds(j, &bounds)) + continue; + const PrimRef prim(bounds, geomID, unsigned(j)); + pinfo.add_center2(prim); + prims[k++] = prim; + } + return pinfo; + } + + PrimInfo createPrimRefArrayMB(mvector& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j = r.begin(); j < r.end(); j++) { + BBox3fa bounds = empty; + if (!buildBounds(j, itime, bounds)) + continue; + const PrimRef prim(bounds, geomID, unsigned(j)); + pinfo.add_center2(prim); + prims[k++] = prim; + } + return pinfo; + } + + PrimInfoMB createPrimRefMBArray(mvector& prims, + const BBox1f& t0t1, + const range& r, + size_t k, + unsigned int geomID) const + { + PrimInfoMB pinfo(empty); + for (size_t j = r.begin(); j < r.end(); j++) { + if (!valid(j, timeSegmentRange(t0t1))) + continue; + const PrimRefMB prim(linearBounds(j, t0t1), this->numTimeSegments(), this->time_range, this->numTimeSegments(), geomID, unsigned(j)); + pinfo.add_primref(prim); + prims[k++] = prim; + } + return pinfo; + } + + BBox3fa vbounds(size_t i) const + { + return bounds(i); + } + + BBox3fa vbounds(const LinearSpace3fa& space, size_t i) const + { + return bounds(space, i); + } + + LBBox3fa vlinearBounds(size_t primID, const BBox1f& time_range) const + { + return linearBounds(primID, time_range); + } + + LBBox3fa vlinearBounds(const LinearSpace3fa& space, size_t primID, const BBox1f& time_range) const + { + return linearBounds(space, primID, time_range); + } + }; + } // namespace isa + + DECLARE_ISA_FUNCTION(Points*, createPoints, Device* COMMA Geometry::GType); +} // namespace embree diff --git a/thirdparty/embree/kernels/common/scene_quad_mesh.h b/thirdparty/embree/kernels/common/scene_quad_mesh.h new file mode 100644 index 000000000000..d5bb054b14a1 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_quad_mesh.h @@ -0,0 +1,277 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "geometry.h" +#include "buffer.h" + +namespace embree +{ + /*! Quad Mesh */ + struct QuadMesh : public Geometry + { + /*! type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_QUAD_MESH; + + /*! triangle indices */ + struct Quad + { + uint32_t v[4]; + + /*! 
outputs quad indices */ + __forceinline friend embree_ostream operator<<(embree_ostream cout, const Quad& q) { + return cout << "Quad {" << q.v[0] << ", " << q.v[1] << ", " << q.v[2] << ", " << q.v[3] << " }"; + } + }; + + public: + + /*! quad mesh construction */ + QuadMesh (Device* device); + + /* geometry interface */ + public: + void setMask(unsigned mask); + void setNumTimeSteps (unsigned int numTimeSteps); + void setVertexAttributeCount (unsigned int N); + void setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void commit(); + bool verify(); + void interpolate(const RTCInterpolateArguments* const args); + void addElementsToCount (GeometryCounts & counts) const; + + public: + + /*! returns number of vertices */ + __forceinline size_t numVertices() const { + return vertices[0].size(); + } + + /*! returns i'th quad */ + __forceinline const Quad& quad(size_t i) const { + return quads[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const Vec3fa vertex(size_t i) const { + return vertices0[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const char* vertexPtr(size_t i) const { + return vertices0.getPtr(i); + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const Vec3fa vertex(size_t i, size_t itime) const { + return vertices[itime][i]; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const char* vertexPtr(size_t i, size_t itime) const { + return vertices[itime].getPtr(i); + } + + /*! calculates the bounds of the i'th quad */ + __forceinline BBox3fa bounds(size_t i) const + { + const Quad& q = quad(i); + const Vec3fa v0 = vertex(q.v[0]); + const Vec3fa v1 = vertex(q.v[1]); + const Vec3fa v2 = vertex(q.v[2]); + const Vec3fa v3 = vertex(q.v[3]); + return BBox3fa(min(v0,v1,v2,v3),max(v0,v1,v2,v3)); + } + + /*! calculates the bounds of the i'th quad at the itime'th timestep */ + __forceinline BBox3fa bounds(size_t i, size_t itime) const + { + const Quad& q = quad(i); + const Vec3fa v0 = vertex(q.v[0],itime); + const Vec3fa v1 = vertex(q.v[1],itime); + const Vec3fa v2 = vertex(q.v[2],itime); + const Vec3fa v3 = vertex(q.v[3],itime); + return BBox3fa(min(v0,v1,v2,v3),max(v0,v1,v2,v3)); + } + + /*! check if the i'th primitive is valid at the itime'th timestep */ + __forceinline bool valid(size_t i, size_t itime) const { + return valid(i, make_range(itime, itime)); + } + + /*! check if the i'th primitive is valid between the specified time range */ + __forceinline bool valid(size_t i, const range& itime_range) const + { + const Quad& q = quad(i); + if (unlikely(q.v[0] >= numVertices())) return false; + if (unlikely(q.v[1] >= numVertices())) return false; + if (unlikely(q.v[2] >= numVertices())) return false; + if (unlikely(q.v[3] >= numVertices())) return false; + + for (size_t itime = itime_range.begin(); itime <= itime_range.end(); itime++) + { + if (!isvalid(vertex(q.v[0],itime))) return false; + if (!isvalid(vertex(q.v[1],itime))) return false; + if (!isvalid(vertex(q.v[2],itime))) return false; + if (!isvalid(vertex(q.v[3],itime))) return false; + } + + return true; + } + + /*!
calculates the linear bounds of the i'th quad at the itimeGlobal'th time segment */ + __forceinline LBBox3fa linearBounds(size_t i, size_t itime) const { + return LBBox3fa(bounds(i,itime+0),bounds(i,itime+1)); + } + + /*! calculates the build bounds of the i'th primitive, if it's valid */ + __forceinline bool buildBounds(size_t i, BBox3fa* bbox = nullptr) const + { + const Quad& q = quad(i); + if (q.v[0] >= numVertices()) return false; + if (q.v[1] >= numVertices()) return false; + if (q.v[2] >= numVertices()) return false; + if (q.v[3] >= numVertices()) return false; + + for (unsigned int t=0; t= numVertices())) return false; + if (unlikely(q.v[1] >= numVertices())) return false; + if (unlikely(q.v[2] >= numVertices())) return false; + if (unlikely(q.v[3] >= numVertices())) return false; + + assert(itime+1 < numTimeSteps); + const Vec3fa a0 = vertex(q.v[0],itime+0); if (unlikely(!isvalid(a0))) return false; + const Vec3fa a1 = vertex(q.v[1],itime+0); if (unlikely(!isvalid(a1))) return false; + const Vec3fa a2 = vertex(q.v[2],itime+0); if (unlikely(!isvalid(a2))) return false; + const Vec3fa a3 = vertex(q.v[3],itime+0); if (unlikely(!isvalid(a3))) return false; + const Vec3fa b0 = vertex(q.v[0],itime+1); if (unlikely(!isvalid(b0))) return false; + const Vec3fa b1 = vertex(q.v[1],itime+1); if (unlikely(!isvalid(b1))) return false; + const Vec3fa b2 = vertex(q.v[2],itime+1); if (unlikely(!isvalid(b2))) return false; + const Vec3fa b3 = vertex(q.v[3],itime+1); if (unlikely(!isvalid(b3))) return false; + + /* use bounds of first time step in builder */ + bbox = BBox3fa(min(a0,a1,a2,a3),max(a0,a1,a2,a3)); + return true; + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(size_t primID, const BBox1f& dt) const { + return LBBox3fa([&] (size_t itime) { return bounds(primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline bool linearBounds(size_t i, const BBox1f& dt, LBBox3fa& bbox) const + { + if (!valid(i, timeSegmentRange(dt))) return false; + bbox = linearBounds(i, dt); + return true; + } + + /*! 
get fast access to first vertex buffer */ + __forceinline float * getCompactVertexArray () const { + return (float*) vertices0.getPtr(); + } + + /* gets version info of topology */ + unsigned int getTopologyVersion() const { + return quads.modCounter; + } + + /* returns true if topology changed */ + bool topologyChanged(unsigned int otherVersion) const { + return quads.isModified(otherVersion); // || numPrimitivesChanged; + } + + /* returns the projected area */ + __forceinline float projectedPrimitiveArea(const size_t i) const { + const Quad& q = quad(i); + const Vec3fa v0 = vertex(q.v[0]); + const Vec3fa v1 = vertex(q.v[1]); + const Vec3fa v2 = vertex(q.v[2]); + const Vec3fa v3 = vertex(q.v[3]); + return areaProjectedTriangle(v0,v1,v3) + + areaProjectedTriangle(v1,v2,v3); + } + + public: + BufferView quads; //!< array of quads + BufferView vertices0; //!< fast access to first vertex buffer + vector> vertices; //!< vertex array for each timestep + vector> vertexAttribs; //!< vertex attribute buffers + }; + + namespace isa + { + struct QuadMeshISA : public QuadMesh + { + QuadMeshISA (Device* device) + : QuadMesh(device) {} + + PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, const BBox1f& t0t1, const range& r, size_t k, unsigned int geomID) const + { + PrimInfoMB pinfo(empty); + for (size_t j=r.begin(); jnumTimeSegments(),this->time_range,this->numTimeSegments(),geomID,unsigned(j)); + pinfo.add_primref(prim); + prims[k++] = prim; + } + return pinfo; + } + }; + } + + DECLARE_ISA_FUNCTION(QuadMesh*, createQuadMesh, Device*); +} diff --git a/thirdparty/embree/kernels/common/scene_subdiv_mesh.h b/thirdparty/embree/kernels/common/scene_subdiv_mesh.h new file mode 100644 index 000000000000..25ee8e8efade --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_subdiv_mesh.h @@ -0,0 +1,326 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "geometry.h" +#include "buffer.h" +#include "../subdiv/half_edge.h" +#include "../subdiv/tessellation_cache.h" +#include "../subdiv/catmullclark_coefficients.h" +#include "../subdiv/patch.h" +#include "../../common/algorithms/parallel_map.h" +#include "../../common/algorithms/parallel_set.h" + +namespace embree +{ + class SubdivMesh : public Geometry + { + ALIGNED_CLASS_(16); + public: + + typedef HalfEdge::Edge Edge; + + /*! type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_SUBDIV_MESH; + + /*! structure used to sort half edges using radix sort by their key */ + struct KeyHalfEdge + { + KeyHalfEdge() {} + + KeyHalfEdge (uint64_t key, HalfEdge* edge) + : key(key), edge(edge) {} + + __forceinline operator uint64_t() const { + return key; + } + + friend __forceinline bool operator<(const KeyHalfEdge& e0, const KeyHalfEdge& e1) { + return e0.key < e1.key; + } + + public: + uint64_t key; + HalfEdge* edge; + }; + + public: + + /*! 
subdiv mesh construction */ + SubdivMesh(Device* device); + + public: + void setMask (unsigned mask); + void setSubdivisionMode (unsigned int topologyID, RTCSubdivisionMode mode); + void setVertexAttributeTopology(unsigned int vertexAttribID, unsigned int topologyID); + void setNumTimeSteps (unsigned int numTimeSteps); + void setVertexAttributeCount (unsigned int N); + void setTopologyCount (unsigned int N); + void setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void setTessellationRate(float N); + bool verify(); + void commit(); + void addElementsToCount (GeometryCounts & counts) const; + void setDisplacementFunction (RTCDisplacementFunctionN func); + unsigned int getFirstHalfEdge(unsigned int faceID); + unsigned int getFace(unsigned int edgeID); + unsigned int getNextHalfEdge(unsigned int edgeID); + unsigned int getPreviousHalfEdge(unsigned int edgeID); + unsigned int getOppositeHalfEdge(unsigned int topologyID, unsigned int edgeID); + + public: + + /*! return the number of faces */ + size_t numFaces() const { + return faceVertices.size(); + } + + /*! return the number of edges */ + size_t numEdges() const { + return topology[0].vertexIndices.size(); + } + + /*! return the number of vertices */ + size_t numVertices() const { + return vertices[0].size(); + } + + /*! calculates the bounds of the i'th subdivision patch at the j'th timestep */ + __forceinline BBox3fa bounds(size_t i, size_t j = 0) const { + return topology[0].getHalfEdge(i)->bounds(vertices[j]); + } + + /*! check if the i'th primitive is valid */ + __forceinline bool valid(size_t i) const { + return topology[0].valid(i) && !invalidFace(i); + } + + /*! check if the i'th primitive is valid for the j'th time range */ + __forceinline bool valid(size_t i, size_t j) const { + return topology[0].valid(i) && !invalidFace(i,j); + } + + /*! prints some statistics */ + void printStatistics(); + + /*! initializes the half edge data structure */ + void initializeHalfEdgeStructures (); + + public: + + /*! returns the vertex buffer for some time step */ + __forceinline const BufferView& getVertexBuffer( const size_t t = 0 ) const { + return vertices[t]; + } + + /* returns tessellation level of edge */ + __forceinline float getEdgeLevel(const size_t i) const + { + if (levels) return clamp(levels[i],1.0f,4096.0f); // FIXME: do we want to limit edge level? + else return clamp(tessellationRate,1.0f,4096.0f); // FIXME: do we want to limit edge level? + } + + public: + RTCDisplacementFunctionN displFunc; //!< displacement function + + /*! all buffers in this section are provided by the application */ + public: + + /*! the topology contains all data that may differ when + * interpolating different user data buffers */ + struct Topology + { + public: + + /*! Default topology construction */ + Topology () : halfEdges(nullptr,0) {} + + /*! Topology initialization */ + Topology (SubdivMesh* mesh); + + /*! 
make the class movable */ + public: + Topology (Topology&& other) // FIXME: this is only required to workaround compilation issues under Windows + : mesh(std::move(other.mesh)), + vertexIndices(std::move(other.vertexIndices)), + subdiv_mode(std::move(other.subdiv_mode)), + halfEdges(std::move(other.halfEdges)), + halfEdges0(std::move(other.halfEdges0)), + halfEdges1(std::move(other.halfEdges1)) {} + + Topology& operator= (Topology&& other) // FIXME: this is only required to workaround compilation issues under Windows + { + mesh = std::move(other.mesh); + vertexIndices = std::move(other.vertexIndices); + subdiv_mode = std::move(other.subdiv_mode); + halfEdges = std::move(other.halfEdges); + halfEdges0 = std::move(other.halfEdges0); + halfEdges1 = std::move(other.halfEdges1); + return *this; + } + + public: + /*! check if the i'th primitive is valid in this topology */ + __forceinline bool valid(size_t i) const + { + if (unlikely(subdiv_mode == RTC_SUBDIVISION_MODE_NO_BOUNDARY)) { + if (getHalfEdge(i)->faceHasBorder()) return false; + } + return true; + } + + /*! updates the interpolation mode for the topology */ + void setSubdivisionMode (RTCSubdivisionMode mode); + + /*! marks all buffers as modified */ + void update (); + + /*! verifies index array */ + bool verify (size_t numVertices); + + /*! initializes the half edge data structure */ + void initializeHalfEdgeStructures (); + + private: + + /*! recalculates the half edges */ + void calculateHalfEdges(); + + /*! updates half edges when recalculation is not necessary */ + void updateHalfEdges(); + + /*! user input data */ + public: + + SubdivMesh* mesh; + + /*! indices of the vertices composing each face */ + BufferView vertexIndices; + + /*! subdiv interpolation mode */ + RTCSubdivisionMode subdiv_mode; + + /*! generated data */ + public: + + /*! returns the start half edge for face f */ + __forceinline const HalfEdge* getHalfEdge ( const size_t f ) const { + return &halfEdges[mesh->faceStartEdge[f]]; + } + + /*! Half edge structure, generated by initHalfEdgeStructures */ + mvector halfEdges; + + /*! the following data is only required during construction of the + * half edge structure and can be cleared for static scenes */ + private: + + /*! two arrays used to sort the half edges */ + std::vector halfEdges0; + std::vector halfEdges1; + }; + + /*! returns the start half edge for topology t and face f */ + __forceinline const HalfEdge* getHalfEdge ( const size_t t , const size_t f ) const { + return topology[t].getHalfEdge(f); + } + + /*! buffer containing the number of vertices for each face */ + BufferView faceVertices; + + /*! array of topologies */ + vector topology; + + /*! vertex buffer (one buffer for each time step) */ + vector> vertices; + + /*! user data buffers */ + vector vertexAttribs; + + /*! edge crease buffer containing edges (pairs of vertices) that carry edge crease weights */ + BufferView edge_creases; + + /*! edge crease weights for each edge of the edge_creases buffer */ + BufferView edge_crease_weights; + + /*! vertex crease buffer containing all vertices that carry vertex crease weights */ + BufferView vertex_creases; + + /*! vertex crease weights for each vertex of the vertex_creases buffer */ + BufferView vertex_crease_weights; + + /*! subdivision level for each half edge of the vertexIndices buffer */ + BufferView levels; + float tessellationRate; // constant rate that is used when levels is not set + + /*! buffer that marks specific faces as holes */ + BufferView holes; + + /*! 
all data in this section is generated by initializeHalfEdgeStructures function */ + private: + + /*! number of half edges used by faces */ + size_t numHalfEdges; + + /*! fast lookup table to find the first half edge for some face */ + mvector faceStartEdge; + + /*! fast lookup table to find the face for some half edge */ + mvector halfEdgeFace; + + /*! set with all holes */ + parallel_set holeSet; + + /*! fast lookup table to detect invalid faces */ + mvector invalid_face; + + /*! test if face i is invalid in timestep j */ + __forceinline char& invalidFace(size_t i, size_t j = 0) { return invalid_face[i*numTimeSteps+j]; } + __forceinline const char& invalidFace(size_t i, size_t j = 0) const { return invalid_face[i*numTimeSteps+j]; } + + /*! interpolation cache */ + public: + static __forceinline size_t numInterpolationSlots4(size_t stride) { return (stride+15)/16; } + static __forceinline size_t numInterpolationSlots8(size_t stride) { return (stride+31)/32; } + static __forceinline size_t interpolationSlot(size_t prim, size_t slot, size_t stride) { + const size_t slots = numInterpolationSlots4(stride); + assert(slot < slots); + return slots*prim+slot; + } + std::vector> vertex_buffer_tags; + std::vector> vertex_attrib_buffer_tags; + std::vector patch_eval_trees; + + /*! the following data is only required during construction of the + * half edge structure and can be cleared for static scenes */ + private: + + /*! map with all vertex creases */ + parallel_map vertexCreaseMap; + + /*! map with all edge creases */ + parallel_map edgeCreaseMap; + + protected: + + /*! counts number of geometry commits */ + size_t commitCounter; + }; + + namespace isa + { + struct SubdivMeshISA : public SubdivMesh + { + SubdivMeshISA (Device* device) + : SubdivMesh(device) {} + + void interpolate(const RTCInterpolateArguments* const args); + void interpolateN(const RTCInterpolateNArguments* const args); + }; + } + + DECLARE_ISA_FUNCTION(SubdivMesh*, createSubdivMesh, Device*); +}; diff --git a/thirdparty/embree/kernels/common/scene_triangle_mesh.cpp b/thirdparty/embree/kernels/common/scene_triangle_mesh.cpp new file mode 100644 index 000000000000..d1c2750f14a1 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_triangle_mesh.cpp @@ -0,0 +1,243 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "scene_triangle_mesh.h" +#include "scene.h" + +namespace embree +{ +#if defined(EMBREE_LOWEST_ISA) + + TriangleMesh::TriangleMesh (Device* device) + : Geometry(device,GTY_TRIANGLE_MESH,0,1) + { + vertices.resize(numTimeSteps); + } + + void TriangleMesh::setMask (unsigned mask) + { + this->mask = mask; + Geometry::update(); + } + + void TriangleMesh::setNumTimeSteps (unsigned int numTimeSteps) + { + vertices.resize(numTimeSteps); + Geometry::setNumTimeSteps(numTimeSteps); + } + + void TriangleMesh::setVertexAttributeCount (unsigned int N) + { + vertexAttribs.resize(N); + Geometry::update(); + } + + void TriangleMesh::setBuffer(RTCBufferType type, unsigned int slot, RTCFormat format, const Ref& buffer, size_t offset, size_t stride, unsigned int num) + { + /* verify that all accesses are 4 bytes aligned */ + if (((size_t(buffer->getPtr()) + offset) & 0x3) || (stride & 0x3)) + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "data must be 4 bytes aligned"); + + if (type == RTC_BUFFER_TYPE_VERTEX) + { + if (format != RTC_FORMAT_FLOAT3) + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "invalid vertex buffer format"); + + /* if buffer is larger than 16GB the premultiplied index 
optimization does not work */ + if (stride*num > 16ll*1024ll*1024ll*1024ll) + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "vertex buffer can be at most 16GB large"); + + if (slot >= vertices.size()) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid vertex buffer slot"); + + vertices[slot].set(buffer, offset, stride, num, format); + vertices[slot].checkPadding16(); + vertices0 = vertices[0]; + } + else if (type == RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE) + { + if (format < RTC_FORMAT_FLOAT || format > RTC_FORMAT_FLOAT16) + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "invalid vertex attribute buffer format"); + + if (slot >= vertexAttribs.size()) + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "invalid vertex attribute buffer slot"); + + vertexAttribs[slot].set(buffer, offset, stride, num, format); + vertexAttribs[slot].checkPadding16(); + } + else if (type == RTC_BUFFER_TYPE_INDEX) + { + if (slot != 0) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + if (format != RTC_FORMAT_UINT3) + throw_RTCError(RTC_ERROR_INVALID_OPERATION, "invalid index buffer format"); + + triangles.set(buffer, offset, stride, num, format); + setNumPrimitives(num); + } + else + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "unknown buffer type"); + } + + void* TriangleMesh::getBuffer(RTCBufferType type, unsigned int slot) + { + if (type == RTC_BUFFER_TYPE_INDEX) + { + if (slot != 0) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + return triangles.getPtr(); + } + else if (type == RTC_BUFFER_TYPE_VERTEX) + { + if (slot >= vertices.size()) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + return vertices[slot].getPtr(); + } + else if (type == RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE) + { + if (slot >= vertexAttribs.size()) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + return vertexAttribs[slot].getPtr(); + } + else + { + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "unknown buffer type"); + return nullptr; + } + } + + void TriangleMesh::updateBuffer(RTCBufferType type, unsigned int slot) + { + if (type == RTC_BUFFER_TYPE_INDEX) + { + if (slot != 0) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + triangles.setModified(); + } + else if (type == RTC_BUFFER_TYPE_VERTEX) + { + if (slot >= vertices.size()) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + vertices[slot].setModified(); + } + else if (type == RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE) + { + if (slot >= vertexAttribs.size()) + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "invalid buffer slot"); + vertexAttribs[slot].setModified(); + } + else + { + throw_RTCError(RTC_ERROR_INVALID_ARGUMENT, "unknown buffer type"); + } + + Geometry::update(); + } + + void TriangleMesh::commit() + { + /* verify that stride of all time steps are identical */ + for (unsigned int t=0; t= numVertices()) return false; + if (triangles[i].v[1] >= numVertices()) return false; + if (triangles[i].v[2] >= numVertices()) return false; + } + + /*! 
verify vertices */ + for (const auto& buffer : vertices) + for (size_t i=0; iprimID; + float u = args->u; + float v = args->v; + RTCBufferType bufferType = args->bufferType; + unsigned int bufferSlot = args->bufferSlot; + float* P = args->P; + float* dPdu = args->dPdu; + float* dPdv = args->dPdv; + float* ddPdudu = args->ddPdudu; + float* ddPdvdv = args->ddPdvdv; + float* ddPdudv = args->ddPdudv; + unsigned int valueCount = args->valueCount; + + /* calculate base pointer and stride */ + assert((bufferType == RTC_BUFFER_TYPE_VERTEX && bufferSlot < numTimeSteps) || + (bufferType == RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE && bufferSlot <= vertexAttribs.size())); + const char* src = nullptr; + size_t stride = 0; + if (bufferType == RTC_BUFFER_TYPE_VERTEX_ATTRIBUTE) { + src = vertexAttribs[bufferSlot].getPtr(); + stride = vertexAttribs[bufferSlot].getStride(); + } else { + src = vertices[bufferSlot].getPtr(); + stride = vertices[bufferSlot].getStride(); + } + + for (unsigned int i=0; i& buffer, size_t offset, size_t stride, unsigned int num); + void* getBuffer(RTCBufferType type, unsigned int slot); + void updateBuffer(RTCBufferType type, unsigned int slot); + void commit(); + bool verify(); + void interpolate(const RTCInterpolateArguments* const args); + void addElementsToCount (GeometryCounts & counts) const; + + public: + + /*! returns number of vertices */ + __forceinline size_t numVertices() const { + return vertices[0].size(); + } + + /*! returns i'th triangle*/ + __forceinline const Triangle& triangle(size_t i) const { + return triangles[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const Vec3fa vertex(size_t i) const { + return vertices0[i]; + } + + /*! returns i'th vertex of the first time step */ + __forceinline const char* vertexPtr(size_t i) const { + return vertices0.getPtr(i); + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const Vec3fa vertex(size_t i, size_t itime) const { + return vertices[itime][i]; + } + + /*! returns i'th vertex of itime'th timestep */ + __forceinline const char* vertexPtr(size_t i, size_t itime) const { + return vertices[itime].getPtr(i); + } + + /*! calculates the bounds of the i'th triangle */ + __forceinline BBox3fa bounds(size_t i) const + { + const Triangle& tri = triangle(i); + const Vec3fa v0 = vertex(tri.v[0]); + const Vec3fa v1 = vertex(tri.v[1]); + const Vec3fa v2 = vertex(tri.v[2]); + return BBox3fa(min(v0,v1,v2),max(v0,v1,v2)); + } + + /*! calculates the bounds of the i'th triangle at the itime'th timestep */ + __forceinline BBox3fa bounds(size_t i, size_t itime) const + { + const Triangle& tri = triangle(i); + const Vec3fa v0 = vertex(tri.v[0],itime); + const Vec3fa v1 = vertex(tri.v[1],itime); + const Vec3fa v2 = vertex(tri.v[2],itime); + return BBox3fa(min(v0,v1,v2),max(v0,v1,v2)); + } + + /*! check if the i'th primitive is valid at the itime'th timestep */ + __forceinline bool valid(size_t i, size_t itime) const { + return valid(i, make_range(itime, itime)); + } + + /*! 
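TriangleMesh::setBuffer() above is the backend that receives the vertex/index buffers an application creates through the public Embree 3 API; the format and slot checks it performs (RTC_FORMAT_FLOAT3 vertices, RTC_FORMAT_UINT3 indices, slot 0) correspond to the arguments used below. A minimal usage sketch, not part of this patch; the helper name and the vertex values are made up for illustration:

#include <embree3/rtcore.h>
#include <cstring>

// Illustrative only: build a one-triangle scene whose buffers end up in
// TriangleMesh::setBuffer() above.
static RTCScene makeSingleTriangleScene(RTCDevice device)
{
  RTCGeometry geom = rtcNewGeometry(device, RTC_GEOMETRY_TYPE_TRIANGLE);

  // Vertex buffer: slot 0, RTC_FORMAT_FLOAT3, as required by setBuffer() above.
  float* vb = (float*) rtcSetNewGeometryBuffer(
      geom, RTC_BUFFER_TYPE_VERTEX, 0, RTC_FORMAT_FLOAT3, 3 * sizeof(float), 3);
  const float verts[9] = { 0,0,0,  1,0,0,  0,1,0 };
  std::memcpy(vb, verts, sizeof(verts));

  // Index buffer: slot 0, RTC_FORMAT_UINT3.
  unsigned* ib = (unsigned*) rtcSetNewGeometryBuffer(
      geom, RTC_BUFFER_TYPE_INDEX, 0, RTC_FORMAT_UINT3, 3 * sizeof(unsigned), 1);
  const unsigned idx[3] = { 0, 1, 2 };
  std::memcpy(ib, idx, sizeof(idx));

  rtcCommitGeometry(geom);
  RTCScene scene = rtcNewScene(device);
  rtcAttachGeometry(scene, geom);
  rtcReleaseGeometry(geom);
  rtcCommitScene(scene);
  return scene;
}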
check if the i'th primitive is valid between the specified time range */ + __forceinline bool valid(size_t i, const range& itime_range) const + { + const Triangle& tri = triangle(i); + if (unlikely(tri.v[0] >= numVertices())) return false; + if (unlikely(tri.v[1] >= numVertices())) return false; + if (unlikely(tri.v[2] >= numVertices())) return false; + + for (size_t itime = itime_range.begin(); itime <= itime_range.end(); itime++) + { + if (!isvalid(vertex(tri.v[0],itime))) return false; + if (!isvalid(vertex(tri.v[1],itime))) return false; + if (!isvalid(vertex(tri.v[2],itime))) return false; + } + + return true; + } + + /*! calculates the linear bounds of the i'th primitive at the itimeGlobal'th time segment */ + __forceinline LBBox3fa linearBounds(size_t i, size_t itime) const { + return LBBox3fa(bounds(i,itime+0),bounds(i,itime+1)); + } + + /*! calculates the build bounds of the i'th primitive, if it's valid */ + __forceinline bool buildBounds(size_t i, BBox3fa* bbox = nullptr) const + { + const Triangle& tri = triangle(i); + if (unlikely(tri.v[0] >= numVertices())) return false; + if (unlikely(tri.v[1] >= numVertices())) return false; + if (unlikely(tri.v[2] >= numVertices())) return false; + + for (size_t t=0; t= numVertices())) return false; + if (unlikely(tri.v[1] >= numVertices())) return false; + if (unlikely(tri.v[2] >= numVertices())) return false; + + assert(itime+1 < numTimeSteps); + const Vec3fa a0 = vertex(tri.v[0],itime+0); if (unlikely(!isvalid(a0))) return false; + const Vec3fa a1 = vertex(tri.v[1],itime+0); if (unlikely(!isvalid(a1))) return false; + const Vec3fa a2 = vertex(tri.v[2],itime+0); if (unlikely(!isvalid(a2))) return false; + const Vec3fa b0 = vertex(tri.v[0],itime+1); if (unlikely(!isvalid(b0))) return false; + const Vec3fa b1 = vertex(tri.v[1],itime+1); if (unlikely(!isvalid(b1))) return false; + const Vec3fa b2 = vertex(tri.v[2],itime+1); if (unlikely(!isvalid(b2))) return false; + + /* use bounds of first time step in builder */ + bbox = BBox3fa(min(a0,a1,a2),max(a0,a1,a2)); + return true; + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline LBBox3fa linearBounds(size_t primID, const BBox1f& dt) const { + return LBBox3fa([&] (size_t itime) { return bounds(primID, itime); }, dt, time_range, fnumTimeSegments); + } + + /*! calculates the linear bounds of the i'th primitive for the specified time range */ + __forceinline bool linearBounds(size_t i, const BBox1f& dt, LBBox3fa& bbox) const { + if (!valid(i, timeSegmentRange(dt))) return false; + bbox = linearBounds(i, dt); + return true; + } + + /*! 
get fast access to first vertex buffer */ + __forceinline float * getCompactVertexArray () const { + return (float*) vertices0.getPtr(); + } + + /* gets version info of topology */ + unsigned int getTopologyVersion() const { + return triangles.modCounter; + } + + /* returns true if topology changed */ + bool topologyChanged(unsigned int otherVersion) const { + return triangles.isModified(otherVersion); // || numPrimitivesChanged; + } + + /* returns the projected area */ + __forceinline float projectedPrimitiveArea(const size_t i) const { + const Triangle& tri = triangle(i); + const Vec3fa v0 = vertex(tri.v[0]); + const Vec3fa v1 = vertex(tri.v[1]); + const Vec3fa v2 = vertex(tri.v[2]); + return areaProjectedTriangle(v0,v1,v2); + } + + public: + BufferView triangles; //!< array of triangles + BufferView vertices0; //!< fast access to first vertex buffer + vector> vertices; //!< vertex array for each timestep + vector vertexAttribs; //!< vertex attributes + }; + + namespace isa + { + struct TriangleMeshISA : public TriangleMesh + { + TriangleMeshISA (Device* device) + : TriangleMesh(device) {} + + PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, const BBox1f& t0t1, const range& r, size_t k, unsigned int geomID) const + { + PrimInfoMB pinfo(empty); + for (size_t j=r.begin(); jnumTimeSegments(),this->time_range,this->numTimeSegments(),geomID,unsigned(j)); + pinfo.add_primref(prim); + prims[k++] = prim; + } + return pinfo; + } + }; + } + + DECLARE_ISA_FUNCTION(TriangleMesh*, createTriangleMesh, Device*); +} diff --git a/thirdparty/embree/kernels/common/scene_user_geometry.h b/thirdparty/embree/kernels/common/scene_user_geometry.h new file mode 100644 index 000000000000..8d11ed698605 --- /dev/null +++ b/thirdparty/embree/kernels/common/scene_user_geometry.h @@ -0,0 +1,77 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "accelset.h" + +namespace embree +{ + /*! User geometry with user defined intersection functions */ + struct UserGeometry : public AccelSet + { + /*! 
type of this geometry */ + static const Geometry::GTypeMask geom_type = Geometry::MTY_USER_GEOMETRY; + + public: + UserGeometry (Device* device, unsigned int items = 0, unsigned int numTimeSteps = 1); + virtual void setMask (unsigned mask); + virtual void setBoundsFunction (RTCBoundsFunction bounds, void* userPtr); + virtual void setIntersectFunctionN (RTCIntersectFunctionN intersect); + virtual void setOccludedFunctionN (RTCOccludedFunctionN occluded); + virtual void build() {} + virtual void addElementsToCount (GeometryCounts & counts) const; + }; + + namespace isa + { + struct UserGeometryISA : public UserGeometry + { + UserGeometryISA (Device* device) + : UserGeometry(device) {} + + PrimInfo createPrimRefArray(mvector& prims, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, size_t itime, const range& r, size_t k, unsigned int geomID) const + { + PrimInfo pinfo(empty); + for (size_t j=r.begin(); j& prims, const BBox1f& t0t1, const range& r, size_t k, unsigned int geomID) const + { + PrimInfoMB pinfo(empty); + for (size_t j=r.begin(); jnumTimeSegments(),this->time_range,this->numTimeSegments(),geomID,unsigned(j)); + pinfo.add_primref(prim); + prims[k++] = prim; + } + return pinfo; + } + }; + } + + DECLARE_ISA_FUNCTION(UserGeometry*, createUserGeometry, Device*); +} diff --git a/thirdparty/embree/kernels/common/stack_item.h b/thirdparty/embree/kernels/common/stack_item.h new file mode 100644 index 000000000000..533c3853652e --- /dev/null +++ b/thirdparty/embree/kernels/common/stack_item.h @@ -0,0 +1,125 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" + +namespace embree +{ + /*! An item on the stack holds the node ID and distance of that node. */ + template + struct __aligned(16) StackItemT + { + /*! assert that the xchg function works */ + static_assert(sizeof(T) <= 12, "sizeof(T) <= 12 failed"); + + __forceinline StackItemT() {} + + __forceinline StackItemT(T &ptr, unsigned &dist) : ptr(ptr), dist(dist) {} + + /*! use SSE instructions to swap stack items */ + __forceinline static void xchg(StackItemT& a, StackItemT& b) + { + const vfloat4 sse_a = vfloat4::load((float*)&a); + const vfloat4 sse_b = vfloat4::load((float*)&b); + vfloat4::store(&a,sse_b); + vfloat4::store(&b,sse_a); + } + + /*! Sort 2 stack items. */ + __forceinline friend void sort(StackItemT& s1, StackItemT& s2) { + if (s2.dist < s1.dist) xchg(s2,s1); + } + + /*! Sort 3 stack items. */ + __forceinline friend void sort(StackItemT& s1, StackItemT& s2, StackItemT& s3) + { + if (s2.dist < s1.dist) xchg(s2,s1); + if (s3.dist < s2.dist) xchg(s3,s2); + if (s2.dist < s1.dist) xchg(s2,s1); + } + + /*! Sort 4 stack items. */ + __forceinline friend void sort(StackItemT& s1, StackItemT& s2, StackItemT& s3, StackItemT& s4) + { + if (s2.dist < s1.dist) xchg(s2,s1); + if (s4.dist < s3.dist) xchg(s4,s3); + if (s3.dist < s1.dist) xchg(s3,s1); + if (s4.dist < s2.dist) xchg(s4,s2); + if (s3.dist < s2.dist) xchg(s3,s2); + } + + /*! use SSE instructions to swap stack items */ + __forceinline static void cmp_xchg(vint4& a, vint4& b) + { +#if defined(__AVX512VL__) + const vboolf4 mask(shuffle<2,2,2,2>(b) < shuffle<2,2,2,2>(a)); +#else + const vboolf4 mask0(b < a); + const vboolf4 mask(shuffle<2,2,2,2>(mask0)); +#endif + const vint4 c = select(mask,b,a); + const vint4 d = select(mask,a,b); + a = c; + b = d; + } + + /*! Sort 3 stack items. 
*/ + __forceinline static void sort3(vint4& s1, vint4& s2, vint4& s3) + { + cmp_xchg(s2,s1); + cmp_xchg(s3,s2); + cmp_xchg(s2,s1); + } + + /*! Sort 4 stack items. */ + __forceinline static void sort4(vint4& s1, vint4& s2, vint4& s3, vint4& s4) + { + cmp_xchg(s2,s1); + cmp_xchg(s4,s3); + cmp_xchg(s3,s1); + cmp_xchg(s4,s2); + cmp_xchg(s3,s2); + } + + + /*! Sort N stack items. */ + __forceinline friend void sort(StackItemT* begin, StackItemT* end) + { + for (StackItemT* i = begin+1; i != end; ++i) + { + const vfloat4 item = vfloat4::load((float*)i); + const unsigned dist = i->dist; + StackItemT* j = i; + + while ((j != begin) && ((j-1)->dist < dist)) + { + vfloat4::store(j, vfloat4::load((float*)(j-1))); + --j; + } + + vfloat4::store(j, item); + } + } + + public: + T ptr; + unsigned dist; + }; + + /*! An item on the stack holds the node ID and active ray mask. */ + template + struct __aligned(8) StackItemMaskT + { + T ptr; + size_t mask; + }; + + struct __aligned(8) StackItemMaskCoherent + { + size_t mask; + size_t parent; + size_t child; + }; +} diff --git a/thirdparty/embree/kernels/common/stat.cpp b/thirdparty/embree/kernels/common/stat.cpp new file mode 100644 index 000000000000..b73c3a8c7663 --- /dev/null +++ b/thirdparty/embree/kernels/common/stat.cpp @@ -0,0 +1,128 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "stat.h" + +namespace embree +{ + Stat Stat::instance; + + Stat::Stat () { + } + + Stat::~Stat () + { +#ifdef EMBREE_STAT_COUNTERS + Stat::print(std::cout); +#endif + } + + void Stat::print(std::ostream& cout) + { + Counters& cntrs = instance.cntrs; + Counters::Data& data = instance.cntrs.code; + //Counters::Data& data = instance.cntrs.active; + + /* print absolute numbers */ + cout << "--------- ABSOLUTE ---------" << std::endl; + cout << " #normal_travs = " << float(data.normal.travs )*1E-6 << "M" << std::endl; + cout << " #nodes = " << float(data.normal.trav_nodes )*1E-6 << "M" << std::endl; + cout << " #nodes_xfm = " << float(data.normal.trav_xfm_nodes )*1E-6 << "M" << std::endl; + cout << " #leaves = " << float(data.normal.trav_leaves )*1E-6 << "M" << std::endl; + cout << " #prims = " << float(data.normal.trav_prims )*1E-6 << "M" << std::endl; + cout << " #prim_hits = " << float(data.normal.trav_prim_hits )*1E-6 << "M" << std::endl; + + cout << " #stack nodes = " << float(data.normal.trav_stack_nodes )*1E-6 << "M" << std::endl; + cout << " #stack pop = " << float(data.normal.trav_stack_pop )*1E-6 << "M" << std::endl; + + size_t normal_box_hits = 0; + size_t weighted_box_hits = 0; + for (size_t i=0;i travs; + std::atomic trav_nodes; + std::atomic trav_leaves; + std::atomic trav_prims; + std::atomic trav_prim_hits; + std::atomic trav_hit_boxes[SIZE_HISTOGRAM+1]; + std::atomic trav_stack_pop; + std::atomic trav_stack_nodes; + std::atomic trav_xfm_nodes; + + } normal, shadow, point_query; + } all, active, code; + + std::atomic user[10]; + }; + + public: + + static __forceinline Counters& get() { + return instance.cntrs; + } + + static void clear() { + instance.cntrs.clear(); + } + + static void print(embree_ostream cout); + + private: + Counters cntrs; + + private: + static Stat instance; + }; +} diff --git a/thirdparty/embree/kernels/common/state.cpp b/thirdparty/embree/kernels/common/state.cpp new file mode 100644 index 000000000000..e9a912d7667d --- /dev/null +++ b/thirdparty/embree/kernels/common/state.cpp @@ -0,0 +1,528 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include 
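The compare-exchange networks in StackItemT above keep BVH stack entries ordered by their entry distance. A stand-alone sketch of the same idea; the Item type and the helper names are simplified stand-ins, not the real embree types:

#include <utility>

// Simplified stand-in for a traversal stack entry; the real StackItemT<T>
// stores the node reference plus the distance to that node.
struct Item { void* ptr; unsigned dist; };

static inline void xchg(Item& a, Item& b) { std::swap(a, b); }

// Same three-swap network as sort(s1,s2,s3) above: afterwards
// s1.dist <= s2.dist <= s3.dist, so the caller can push the farthest node
// first and keep the nearest node on top of the traversal stack.
static inline void sort3(Item& s1, Item& s2, Item& s3)
{
  if (s2.dist < s1.dist) xchg(s2, s1);
  if (s3.dist < s2.dist) xchg(s3, s2);
  if (s2.dist < s1.dist) xchg(s2, s1);
}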
"state.h" +#include "../../common/lexers/streamfilters.h" + +namespace embree +{ + MutexSys g_printMutex; + + State::ErrorHandler State::g_errorHandler; + + State::ErrorHandler::ErrorHandler() + : thread_error(createTls()) {} + + State::ErrorHandler::~ErrorHandler() + { + Lock lock(errors_mutex); + for (size_t i=0; i lock(errors_mutex); + stored_error = new RTCError(RTC_ERROR_NONE); + thread_errors.push_back(stored_error); + setTls(thread_error,stored_error); + return stored_error; + } + + State::State () + : enabled_cpu_features(getCPUFeatures()), + enabled_builder_cpu_features(enabled_cpu_features), + frequency_level(FREQUENCY_SIMD256) + { + tri_accel = "default"; + tri_builder = "default"; + tri_traverser = "default"; + + tri_accel_mb = "default"; + tri_builder_mb = "default"; + tri_traverser_mb = "default"; + + quad_accel = "default"; + quad_builder = "default"; + quad_traverser = "default"; + + quad_accel_mb = "default"; + quad_builder_mb = "default"; + quad_traverser_mb = "default"; + + line_accel = "default"; + line_builder = "default"; + line_traverser = "default"; + + line_accel_mb = "default"; + line_builder_mb = "default"; + line_traverser_mb = "default"; + + hair_accel = "default"; + hair_builder = "default"; + hair_traverser = "default"; + + hair_accel_mb = "default"; + hair_builder_mb = "default"; + hair_traverser_mb = "default"; + + object_accel = "default"; + object_builder = "default"; + object_accel_min_leaf_size = 1; + object_accel_max_leaf_size = 1; + + object_accel_mb = "default"; + object_builder_mb = "default"; + object_accel_mb_min_leaf_size = 1; + object_accel_mb_max_leaf_size = 1; + + max_spatial_split_replications = 1.2f; + useSpatialPreSplits = false; + + tessellation_cache_size = 128*1024*1024; + + subdiv_accel = "default"; + subdiv_accel_mb = "default"; + + grid_accel = "default"; + grid_builder = "default"; + grid_accel_mb = "default"; + grid_builder_mb = "default"; + + instancing_open_min = 0; + instancing_block_size = 0; + instancing_open_factor = 8.0f; + instancing_open_max_depth = 32; + instancing_open_max = 50000000; + + ignore_config_files = false; + float_exceptions = false; + quality_flags = -1; + scene_flags = -1; + verbose = 0; + benchmark = 0; + + numThreads = 0; + numUserThreads = 0; + +#if TASKING_INTERNAL + set_affinity = true; +#else + set_affinity = false; +#endif + /* per default enable affinity on KNL */ + if (hasISA(AVX512KNL)) set_affinity = true; + + start_threads = false; + enable_selockmemoryprivilege = false; +#if defined(__LINUX__) + hugepages = true; +#else + hugepages = false; +#endif + hugepages_success = true; + + alloc_main_block_size = 0; + alloc_num_main_slots = 0; + alloc_thread_block_size = 0; + alloc_single_thread_alloc = -1; + + error_function = nullptr; + error_function_userptr = nullptr; + + memory_monitor_function = nullptr; + memory_monitor_userptr = nullptr; + } + + State::~State() { + } + + bool State::hasISA(const int isa) { + return (enabled_cpu_features & isa) == isa; + } + + bool State::checkISASupport() { + return (getCPUFeatures() & enabled_cpu_features) == enabled_cpu_features; + } + + void State::verify() + { + /* verify that calculations stay in range */ + assert(rcp(min_rcp_input)*FLT_LARGE+FLT_LARGE < 0.01f*FLT_MAX); + + /* here we verify that CPP files compiled for a specific ISA only + * call that same or lower ISA version of non-inlined class member + * functions */ +#if defined(DEBUG) +#if defined(EMBREE_TARGET_SSE2) + assert(sse2::getISA() <= SSE2); +#endif +#if defined(EMBREE_TARGET_SSE42) + 
assert(sse42::getISA() <= SSE42); +#endif +#if defined(EMBREE_TARGET_AVX) + assert(avx::getISA() <= AVX); +#endif +#if defined(EMBREE_TARGET_AVX2) + assert(avx2::getISA() <= AVX2); +#endif +#if defined (EMBREE_TARGET_AVX512KNL) + assert(avx512knl::getISA() <= AVX512KNL); +#endif +#if defined (EMBREE_TARGET_AVX512SKX) + assert(avx512skx::getISA() <= AVX512SKX); +#endif +#endif + } + + const char* symbols[3] = { "=", ",", "|" }; + + bool State::parseFile(const FileName& fileName) + { + FILE* f = fopen(fileName.c_str(),"r"); + if (!f) return false; + Ref > file = new FileStream(f,fileName); + + std::vector syms; + for (size_t i=0; i cin = new TokenStream(new LineCommentFilter(file,"#"), + TokenStream::alpha+TokenStream::ALPHA+TokenStream::numbers+"_.", + TokenStream::separators,syms); + parse(cin); + return true; + } + + void State::parseString(const char* cfg) + { + if (cfg == nullptr) return; + + std::vector syms; + for (size_t i=0; i cin = new TokenStream(new StrStream(cfg), + TokenStream::alpha+TokenStream::ALPHA+TokenStream::numbers+"_.", + TokenStream::separators,syms); + parse(cin); + } + + int string_to_cpufeatures(const std::string& isa) + { + if (isa == "sse" ) return SSE; + else if (isa == "sse2") return SSE2; + else if (isa == "sse3") return SSE3; + else if (isa == "ssse3") return SSSE3; + else if (isa == "sse41") return SSE41; + else if (isa == "sse4.1") return SSE41; + else if (isa == "sse42") return SSE42; + else if (isa == "sse4.2") return SSE42; + else if (isa == "avx") return AVX; + else if (isa == "avxi") return AVXI; + else if (isa == "avx2") return AVX2; + else if (isa == "avx512knl") return AVX512KNL; + else if (isa == "avx512skx") return AVX512SKX; + else return SSE2; + } + + void State::parse(Ref cin) + { + /* parse until end of stream */ + while (cin->peek() != Token::Eof()) + { + const Token tok = cin->get(); + + if (tok == Token::Id("threads") && cin->trySymbol("=")) + numThreads = cin->get().Int(); + + else if (tok == Token::Id("user_threads")&& cin->trySymbol("=")) + numUserThreads = cin->get().Int(); + + else if (tok == Token::Id("set_affinity")&& cin->trySymbol("=")) + set_affinity = cin->get().Int(); + + else if (tok == Token::Id("affinity")&& cin->trySymbol("=")) + set_affinity = cin->get().Int(); + + else if (tok == Token::Id("start_threads")&& cin->trySymbol("=")) + start_threads = cin->get().Int(); + + else if (tok == Token::Id("isa") && cin->trySymbol("=")) { + std::string isa = toLowerCase(cin->get().Identifier()); + enabled_cpu_features = string_to_cpufeatures(isa); + enabled_builder_cpu_features = enabled_cpu_features; + } + + else if (tok == Token::Id("max_isa") && cin->trySymbol("=")) { + std::string isa = toLowerCase(cin->get().Identifier()); + enabled_cpu_features &= string_to_cpufeatures(isa); + enabled_builder_cpu_features &= enabled_cpu_features; + } + + else if (tok == Token::Id("max_builder_isa") && cin->trySymbol("=")) { + std::string isa = toLowerCase(cin->get().Identifier()); + enabled_builder_cpu_features &= string_to_cpufeatures(isa); + } + + else if (tok == Token::Id("frequency_level") && cin->trySymbol("=")) { + std::string freq = cin->get().Identifier(); + if (freq == "simd128") frequency_level = FREQUENCY_SIMD128; + else if (freq == "simd256") frequency_level = FREQUENCY_SIMD256; + else if (freq == "simd512") frequency_level = FREQUENCY_SIMD512; + } + + else if (tok == Token::Id("enable_selockmemoryprivilege") && cin->trySymbol("=")) { + enable_selockmemoryprivilege = cin->get().Int(); + } + else if (tok == Token::Id("hugepages") && 
cin->trySymbol("=")) { + hugepages = cin->get().Int(); + } + + else if (tok == Token::Id("ignore_config_files") && cin->trySymbol("=")) + ignore_config_files = cin->get().Int(); + else if (tok == Token::Id("float_exceptions") && cin->trySymbol("=")) + float_exceptions = cin->get().Int(); + + else if ((tok == Token::Id("tri_accel") || tok == Token::Id("accel")) && cin->trySymbol("=")) + tri_accel = cin->get().Identifier(); + else if ((tok == Token::Id("tri_builder") || tok == Token::Id("builder")) && cin->trySymbol("=")) + tri_builder = cin->get().Identifier(); + else if ((tok == Token::Id("tri_traverser") || tok == Token::Id("traverser")) && cin->trySymbol("=")) + tri_traverser = cin->get().Identifier(); + + else if ((tok == Token::Id("tri_accel_mb") || tok == Token::Id("accel_mb")) && cin->trySymbol("=")) + tri_accel_mb = cin->get().Identifier(); + else if ((tok == Token::Id("tri_builder_mb") || tok == Token::Id("builder_mb")) && cin->trySymbol("=")) + tri_builder_mb = cin->get().Identifier(); + else if ((tok == Token::Id("tri_traverser_mb") || tok == Token::Id("traverser_mb")) && cin->trySymbol("=")) + tri_traverser_mb = cin->get().Identifier(); + + else if ((tok == Token::Id("quad_accel")) && cin->trySymbol("=")) + quad_accel = cin->get().Identifier(); + else if ((tok == Token::Id("quad_builder")) && cin->trySymbol("=")) + quad_builder = cin->get().Identifier(); + else if ((tok == Token::Id("quad_traverser")) && cin->trySymbol("=")) + quad_traverser = cin->get().Identifier(); + + else if ((tok == Token::Id("quad_accel_mb")) && cin->trySymbol("=")) + quad_accel_mb = cin->get().Identifier(); + else if ((tok == Token::Id("quad_builder_mb")) && cin->trySymbol("=")) + quad_builder_mb = cin->get().Identifier(); + else if ((tok == Token::Id("quad_traverser_mb")) && cin->trySymbol("=")) + quad_traverser_mb = cin->get().Identifier(); + + else if ((tok == Token::Id("line_accel")) && cin->trySymbol("=")) + line_accel = cin->get().Identifier(); + else if ((tok == Token::Id("line_builder")) && cin->trySymbol("=")) + line_builder = cin->get().Identifier(); + else if ((tok == Token::Id("line_traverser")) && cin->trySymbol("=")) + line_traverser = cin->get().Identifier(); + + else if ((tok == Token::Id("line_accel_mb")) && cin->trySymbol("=")) + line_accel_mb = cin->get().Identifier(); + else if ((tok == Token::Id("line_builder_mb")) && cin->trySymbol("=")) + line_builder_mb = cin->get().Identifier(); + else if ((tok == Token::Id("line_traverser_mb")) && cin->trySymbol("=")) + line_traverser_mb = cin->get().Identifier(); + + else if (tok == Token::Id("hair_accel") && cin->trySymbol("=")) + hair_accel = cin->get().Identifier(); + else if (tok == Token::Id("hair_builder") && cin->trySymbol("=")) + hair_builder = cin->get().Identifier(); + else if (tok == Token::Id("hair_traverser") && cin->trySymbol("=")) + hair_traverser = cin->get().Identifier(); + + else if (tok == Token::Id("hair_accel_mb") && cin->trySymbol("=")) + hair_accel_mb = cin->get().Identifier(); + else if (tok == Token::Id("hair_builder_mb") && cin->trySymbol("=")) + hair_builder_mb = cin->get().Identifier(); + else if (tok == Token::Id("hair_traverser_mb") && cin->trySymbol("=")) + hair_traverser_mb = cin->get().Identifier(); + + else if (tok == Token::Id("object_accel") && cin->trySymbol("=")) + object_accel = cin->get().Identifier(); + else if (tok == Token::Id("object_builder") && cin->trySymbol("=")) + object_builder = cin->get().Identifier(); + else if (tok == Token::Id("object_accel_min_leaf_size") && cin->trySymbol("=")) + 
object_accel_min_leaf_size = cin->get().Int(); + else if (tok == Token::Id("object_accel_max_leaf_size") && cin->trySymbol("=")) + object_accel_max_leaf_size = cin->get().Int(); + + else if (tok == Token::Id("object_accel_mb") && cin->trySymbol("=")) + object_accel_mb = cin->get().Identifier(); + else if (tok == Token::Id("object_builder_mb") && cin->trySymbol("=")) + object_builder_mb = cin->get().Identifier(); + else if (tok == Token::Id("object_accel_mb_min_leaf_size") && cin->trySymbol("=")) + object_accel_mb_min_leaf_size = cin->get().Int(); + else if (tok == Token::Id("object_accel_mb_max_leaf_size") && cin->trySymbol("=")) + object_accel_mb_max_leaf_size = cin->get().Int(); + + else if (tok == Token::Id("instancing_open_min") && cin->trySymbol("=")) + instancing_open_min = cin->get().Int(); + else if (tok == Token::Id("instancing_block_size") && cin->trySymbol("=")) { + instancing_block_size = cin->get().Int(); + instancing_open_factor = 0.0f; + } + else if (tok == Token::Id("instancing_open_max_depth") && cin->trySymbol("=")) + instancing_open_max_depth = cin->get().Int(); + else if (tok == Token::Id("instancing_open_factor") && cin->trySymbol("=")) { + instancing_block_size = 0; + instancing_open_factor = cin->get().Float(); + } + else if (tok == Token::Id("instancing_open_max") && cin->trySymbol("=")) + instancing_open_max = cin->get().Int(); + + else if (tok == Token::Id("subdiv_accel") && cin->trySymbol("=")) + subdiv_accel = cin->get().Identifier(); + else if (tok == Token::Id("subdiv_accel_mb") && cin->trySymbol("=")) + subdiv_accel_mb = cin->get().Identifier(); + + else if (tok == Token::Id("grid_accel") && cin->trySymbol("=")) + grid_accel = cin->get().Identifier(); + else if (tok == Token::Id("grid_accel_mb") && cin->trySymbol("=")) + grid_accel_mb = cin->get().Identifier(); + + else if (tok == Token::Id("verbose") && cin->trySymbol("=")) + verbose = cin->get().Int(); + else if (tok == Token::Id("benchmark") && cin->trySymbol("=")) + benchmark = cin->get().Int(); + + else if (tok == Token::Id("quality")) { + if (cin->trySymbol("=")) { + Token flag = cin->get(); + if (flag == Token::Id("low")) quality_flags = RTC_BUILD_QUALITY_LOW; + else if (flag == Token::Id("medium")) quality_flags = RTC_BUILD_QUALITY_MEDIUM; + else if (flag == Token::Id("high")) quality_flags = RTC_BUILD_QUALITY_HIGH; + } + } + + else if (tok == Token::Id("scene_flags")) { + scene_flags = 0; + if (cin->trySymbol("=")) { + do { + Token flag = cin->get(); + if (flag == Token::Id("dynamic") ) scene_flags |= RTC_SCENE_FLAG_DYNAMIC; + else if (flag == Token::Id("compact")) scene_flags |= RTC_SCENE_FLAG_COMPACT; + else if (flag == Token::Id("robust")) scene_flags |= RTC_SCENE_FLAG_ROBUST; + } while (cin->trySymbol("|")); + } + } + + else if (tok == Token::Id("max_spatial_split_replications") && cin->trySymbol("=")) + max_spatial_split_replications = cin->get().Float(); + + else if (tok == Token::Id("presplits") && cin->trySymbol("=")) + useSpatialPreSplits = cin->get().Int() != 0 ? 
true : false; + + else if (tok == Token::Id("tessellation_cache_size") && cin->trySymbol("=")) + tessellation_cache_size = size_t(cin->get().Float()*1024.0f*1024.0f); + else if (tok == Token::Id("cache_size") && cin->trySymbol("=")) + tessellation_cache_size = size_t(cin->get().Float()*1024.0f*1024.0f); + + else if (tok == Token::Id("alloc_main_block_size") && cin->trySymbol("=")) + alloc_main_block_size = cin->get().Int(); + else if (tok == Token::Id("alloc_num_main_slots") && cin->trySymbol("=")) + alloc_num_main_slots = cin->get().Int(); + else if (tok == Token::Id("alloc_thread_block_size") && cin->trySymbol("=")) + alloc_thread_block_size = cin->get().Int(); + else if (tok == Token::Id("alloc_single_thread_alloc") && cin->trySymbol("=")) + alloc_single_thread_alloc = cin->get().Int(); + + cin->trySymbol(","); // optional , separator + } + } + + bool State::verbosity(size_t N) { + return N <= verbose; + } + + void State::print() + { + std::cout << "general:" << std::endl; + std::cout << " build threads = " << numThreads << std::endl; + std::cout << " build user threads = " << numUserThreads << std::endl; + std::cout << " start_threads = " << start_threads << std::endl; + std::cout << " affinity = " << set_affinity << std::endl; + std::cout << " frequency_level = "; + switch (frequency_level) { + case FREQUENCY_SIMD128: std::cout << "simd128" << std::endl; break; + case FREQUENCY_SIMD256: std::cout << "simd256" << std::endl; break; + case FREQUENCY_SIMD512: std::cout << "simd512" << std::endl; break; + default: std::cout << "error" << std::endl; break; + } + + std::cout << " hugepages = "; + if (!hugepages) std::cout << "disabled" << std::endl; + else if (hugepages_success) std::cout << "enabled" << std::endl; + else std::cout << "failed" << std::endl; + + std::cout << " verbosity = " << verbose << std::endl; + std::cout << " cache_size = " << float(tessellation_cache_size)*1E-6 << " MB" << std::endl; + std::cout << " max_spatial_split_replications = " << max_spatial_split_replications << std::endl; + + std::cout << "triangles:" << std::endl; + std::cout << " accel = " << tri_accel << std::endl; + std::cout << " builder = " << tri_builder << std::endl; + std::cout << " traverser = " << tri_traverser << std::endl; + + std::cout << "motion blur triangles:" << std::endl; + std::cout << " accel = " << tri_accel_mb << std::endl; + std::cout << " builder = " << tri_builder_mb << std::endl; + std::cout << " traverser = " << tri_traverser_mb << std::endl; + + std::cout << "quads:" << std::endl; + std::cout << " accel = " << quad_accel << std::endl; + std::cout << " builder = " << quad_builder << std::endl; + std::cout << " traverser = " << quad_traverser << std::endl; + + std::cout << "motion blur quads:" << std::endl; + std::cout << " accel = " << quad_accel_mb << std::endl; + std::cout << " builder = " << quad_builder_mb << std::endl; + std::cout << " traverser = " << quad_traverser_mb << std::endl; + + std::cout << "line segments:" << std::endl; + std::cout << " accel = " << line_accel << std::endl; + std::cout << " builder = " << line_builder << std::endl; + std::cout << " traverser = " << line_traverser << std::endl; + + std::cout << "motion blur line segments:" << std::endl; + std::cout << " accel = " << line_accel_mb << std::endl; + std::cout << " builder = " << line_builder_mb << std::endl; + std::cout << " traverser = " << line_traverser_mb << std::endl; + + std::cout << "hair:" << std::endl; + std::cout << " accel = " << hair_accel << std::endl; + std::cout << " builder = " << 
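State::parse() above consumes key=value pairs such as threads, isa, verbose and quality. In an application these typically arrive through the device configuration string; a minimal usage sketch, with illustrative key/value choices:

#include <embree3/rtcore.h>

int main()
{
  // The configuration string is tokenized by State::parseString()/parse();
  // "threads", "isa" and "verbose" are keys handled by the parser above.
  RTCDevice device = rtcNewDevice("threads=4,isa=avx2,verbose=1");
  if (!device) return 1;
  /* ... create scenes and geometries here ... */
  rtcReleaseDevice(device);
  return 0;
}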
hair_builder << std::endl; + std::cout << " traverser = " << hair_traverser << std::endl; + + std::cout << "motion blur hair:" << std::endl; + std::cout << " accel = " << hair_accel_mb << std::endl; + std::cout << " builder = " << hair_builder_mb << std::endl; + std::cout << " traverser = " << hair_traverser_mb << std::endl; + + std::cout << "subdivision surfaces:" << std::endl; + std::cout << " accel = " << subdiv_accel << std::endl; + + std::cout << "grids:" << std::endl; + std::cout << " accel = " << grid_accel << std::endl; + std::cout << " builder = " << grid_builder << std::endl; + + std::cout << "motion blur grids:" << std::endl; + std::cout << " accel = " << grid_accel_mb << std::endl; + std::cout << " builder = " << grid_builder_mb << std::endl; + + std::cout << "object_accel:" << std::endl; + std::cout << " min_leaf_size = " << object_accel_min_leaf_size << std::endl; + std::cout << " max_leaf_size = " << object_accel_max_leaf_size << std::endl; + + std::cout << "object_accel_mb:" << std::endl; + std::cout << " min_leaf_size = " << object_accel_mb_min_leaf_size << std::endl; + std::cout << " max_leaf_size = " << object_accel_mb_max_leaf_size << std::endl; + } +} diff --git a/thirdparty/embree/kernels/common/state.h b/thirdparty/embree/kernels/common/state.h new file mode 100644 index 000000000000..d0fccc023f8a --- /dev/null +++ b/thirdparty/embree/kernels/common/state.h @@ -0,0 +1,197 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "default.h" + +namespace embree +{ + /* mutex to make printing to cout thread safe */ + extern MutexSys g_printMutex; + + struct State : public RefCount + { + public: + /*! state construction */ + State (); + + /*! state destruction */ + ~State(); + + /*! verifies that state is correct */ + void verify(); + + /*! parses state from a configuration file */ + bool parseFile(const FileName& fileName); + + /*! parses the state from a string */ + void parseString(const char* cfg); + + /*! parses the state from a stream */ + void parse(Ref cin); + + /*! prints the state */ + void print(); + + /*! checks if verbosity level is at least N */ + bool verbosity(size_t N); + + /*! checks if some particular ISA is enabled */ + bool hasISA(const int isa); + + /*! 
check whether selected ISA is supported by the HW */ + bool checkISASupport(); + + public: + std::string tri_accel; //!< acceleration structure to use for triangles + std::string tri_builder; //!< builder to use for triangles + std::string tri_traverser; //!< traverser to use for triangles + + public: + std::string tri_accel_mb; //!< acceleration structure to use for motion blur triangles + std::string tri_builder_mb; //!< builder to use for motion blur triangles + std::string tri_traverser_mb; //!< traverser to use for triangles + + public: + std::string quad_accel; //!< acceleration structure to use for quads + std::string quad_builder; //!< builder to use for quads + std::string quad_traverser; //!< traverser to use for quads + + public: + std::string quad_accel_mb; //!< acceleration structure to use for motion blur quads + std::string quad_builder_mb; //!< builder to use for motion blur quads + std::string quad_traverser_mb; //!< traverser to use for motion blur quads + + public: + std::string line_accel; //!< acceleration structure to use for line segments + std::string line_builder; //!< builder to use for line segments + std::string line_traverser; //!< traverser to use for line segments + + public: + std::string line_accel_mb; //!< acceleration structure to use for motion blur line segments + std::string line_builder_mb; //!< builder to use for motion blur line segments + std::string line_traverser_mb; //!< traverser to use for motion blur line segments + + public: + std::string hair_accel; //!< hair acceleration structure to use + std::string hair_builder; //!< builder to use for hair + std::string hair_traverser; //!< traverser to use for hair + + public: + std::string hair_accel_mb; //!< acceleration structure to use for motion blur hair + std::string hair_builder_mb; //!< builder to use for motion blur hair + std::string hair_traverser_mb; //!< traverser to use for motion blur hair + + public: + std::string object_accel; //!< acceleration structure for user geometries + std::string object_builder; //!< builder for user geometries + int object_accel_min_leaf_size; //!< minimum leaf size for object acceleration structure + int object_accel_max_leaf_size; //!< maximum leaf size for object acceleration structure + + public: + std::string object_accel_mb; //!< acceleration structure for user geometries + std::string object_builder_mb; //!< builder for user geometries + int object_accel_mb_min_leaf_size; //!< minimum leaf size for mblur object acceleration structure + int object_accel_mb_max_leaf_size; //!< maximum leaf size for mblur object acceleration structure + + public: + std::string subdiv_accel; //!< acceleration structure to use for subdivision surfaces + std::string subdiv_accel_mb; //!< acceleration structure to use for subdivision surfaces + + public: + std::string grid_accel; //!< acceleration structure to use for grids + std::string grid_builder; //!< builder for grids + std::string grid_accel_mb; //!< acceleration structure to use for motion blur grids + std::string grid_builder_mb; //!< builder for motion blur grids + + public: + float max_spatial_split_replications; //!< maximally replications*N many primitives in accel for spatial splits + bool useSpatialPreSplits; //!< use spatial pre-splits instead of the full spatial split builder + size_t tessellation_cache_size; //!< size of the shared tessellation cache + + public: + size_t instancing_open_min; //!< instancing opens tree to minimally that number of subtrees + size_t instancing_block_size; //!< instancing opens 
tree up to average block size of primitives + float instancing_open_factor; //!< instancing opens tree up to x times the number of instances + size_t instancing_open_max_depth; //!< maximum open depth for geometries + size_t instancing_open_max; //!< instancing opens tree to maximally that number of subtrees + + public: + bool ignore_config_files; //!< if true no more config files get parse + bool float_exceptions; //!< enable floating point exceptions + int quality_flags; + int scene_flags; + size_t verbose; //!< verbosity of output + size_t benchmark; //!< true + + public: + size_t numThreads; //!< number of threads to use in builders + size_t numUserThreads; //!< number of user provided threads to use in builders + bool set_affinity; //!< sets affinity for worker threads + bool start_threads; //!< true when threads should be started at device creation time + int enabled_cpu_features; //!< CPU ISA features to use + int enabled_builder_cpu_features; //!< CPU ISA features to use for builders only + enum FREQUENCY_LEVEL { + FREQUENCY_SIMD128, + FREQUENCY_SIMD256, + FREQUENCY_SIMD512 + } frequency_level; //!< frequency level the app wants to run on (default is SIMD256) + bool enable_selockmemoryprivilege; //!< configures the SeLockMemoryPrivilege under Windows to enable huge pages + bool hugepages; //!< true if huge pages should get used + bool hugepages_success; //!< status for enabling huge pages + + public: + size_t alloc_main_block_size; //!< main allocation block size (shared between threads) + int alloc_num_main_slots; //!< number of such shared blocks to be used to allocate + size_t alloc_thread_block_size; //!< size of thread local allocator block size + int alloc_single_thread_alloc; //!< in single mode nodes and leaves use same thread local allocator + + public: + + /*! checks if we can use AVX */ + bool canUseAVX() { + return hasISA(AVX) && frequency_level != FREQUENCY_SIMD128; + } + + /*! checks if we can use AVX2 */ + bool canUseAVX2() { + return hasISA(AVX2) && frequency_level != FREQUENCY_SIMD128; + } + + struct ErrorHandler + { + public: + ErrorHandler(); + ~ErrorHandler(); + RTCError* error(); + + public: + tls_t thread_error; + std::vector thread_errors; + MutexSys errors_mutex; + }; + ErrorHandler errorHandler; + static ErrorHandler g_errorHandler; + + public: + void setErrorFunction(RTCErrorFunction fptr, void* uptr) + { + error_function = fptr; + error_function_userptr = uptr; + } + + RTCErrorFunction error_function; + void* error_function_userptr; + + public: + void setMemoryMonitorFunction(RTCMemoryMonitorFunction fptr, void* uptr) + { + memory_monitor_function = fptr; + memory_monitor_userptr = uptr; + } + + RTCMemoryMonitorFunction memory_monitor_function; + void* memory_monitor_userptr; + }; +} diff --git a/thirdparty/embree/kernels/common/vector.h b/thirdparty/embree/kernels/common/vector.h new file mode 100644 index 000000000000..b47876224034 --- /dev/null +++ b/thirdparty/embree/kernels/common/vector.h @@ -0,0 +1,76 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "default.h" + +namespace embree +{ + /*! invokes the memory monitor callback */ + struct MemoryMonitorInterface { + virtual void memoryMonitor(ssize_t bytes, bool post) = 0; + }; + + /*! 
allocator that performs aligned monitored allocations */ + template + struct aligned_monitored_allocator + { + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + + __forceinline aligned_monitored_allocator(MemoryMonitorInterface* device) + : device(device), hugepages(false) {} + + __forceinline pointer allocate( size_type n ) + { + if (n) { + assert(device); + device->memoryMonitor(n*sizeof(T),false); + } + if (n*sizeof(value_type) >= 14 * PAGE_SIZE_2M) + { + pointer p = (pointer) os_malloc(n*sizeof(value_type),hugepages); + assert(p); + return p; + } + return (pointer) alignedMalloc(n*sizeof(value_type),alignment); + } + + __forceinline void deallocate( pointer p, size_type n ) + { + if (p) + { + if (n*sizeof(value_type) >= 14 * PAGE_SIZE_2M) + os_free(p,n*sizeof(value_type),hugepages); + else + alignedFree(p); + } + else assert(n == 0); + + if (n) { + assert(device); + device->memoryMonitor(-ssize_t(n)*sizeof(T),true); + } + } + + __forceinline void construct( pointer p, const_reference val ) { + new (p) T(val); + } + + __forceinline void destroy( pointer p ) { + p->~T(); + } + + private: + MemoryMonitorInterface* device; + bool hugepages; + }; + + /*! monitored vector */ + template + using mvector = vector_t::value> >; +} diff --git a/thirdparty/embree/kernels/config.h b/thirdparty/embree/kernels/config.h new file mode 100644 index 000000000000..80a8ab2a564d --- /dev/null +++ b/thirdparty/embree/kernels/config.h @@ -0,0 +1,76 @@ + +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +/* #undef EMBREE_RAY_MASK */ +/* #undef EMBREE_STAT_COUNTERS */ +/* #undef EMBREE_BACKFACE_CULLING */ +/* #undef EMBREE_BACKFACE_CULLING_CURVES */ +#define EMBREE_FILTER_FUNCTION +/* #undef EMBREE_IGNORE_INVALID_RAYS */ +#define EMBREE_GEOMETRY_TRIANGLE +/* #undef EMBREE_GEOMETRY_QUAD */ +/* #undef EMBREE_GEOMETRY_CURVE */ +/* #undef EMBREE_GEOMETRY_SUBDIVISION */ +/* #undef EMBREE_GEOMETRY_USER */ +/* #undef EMBREE_GEOMETRY_INSTANCE */ +/* #undef EMBREE_GEOMETRY_GRID */ +/* #undef EMBREE_GEOMETRY_POINT */ +/* #undef EMBREE_RAY_PACKETS */ +/* #undef EMBREE_COMPACT_POLYS */ + +#define EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR 2.0 + +#if defined(EMBREE_GEOMETRY_TRIANGLE) + #define IF_ENABLED_TRIS(x) x +#else + #define IF_ENABLED_TRIS(x) +#endif + +#if defined(EMBREE_GEOMETRY_QUAD) + #define IF_ENABLED_QUADS(x) x +#else + #define IF_ENABLED_QUADS(x) +#endif + +#if defined(EMBREE_GEOMETRY_CURVE) || defined(EMBREE_GEOMETRY_POINT) + #define IF_ENABLED_CURVES_OR_POINTS(x) x +#else + #define IF_ENABLED_CURVES_OR_POINTS(x) +#endif + +#if defined(EMBREE_GEOMETRY_CURVE) + #define IF_ENABLED_CURVES(x) x +#else + #define IF_ENABLED_CURVES(x) +#endif + +#if defined(EMBREE_GEOMETRY_POINT) + #define IF_ENABLED_POINTS(x) x +#else + #define IF_ENABLED_POINTS(x) +#endif + +#if defined(EMBREE_GEOMETRY_SUBDIVISION) + #define IF_ENABLED_SUBDIV(x) x +#else + #define IF_ENABLED_SUBDIV(x) +#endif + +#if defined(EMBREE_GEOMETRY_USER) + #define IF_ENABLED_USER(x) x +#else + #define IF_ENABLED_USER(x) +#endif + +#if defined(EMBREE_GEOMETRY_INSTANCE) + #define IF_ENABLED_INSTANCE(x) x +#else + #define IF_ENABLED_INSTANCE(x) +#endif + +#if defined(EMBREE_GEOMETRY_GRID) + #define IF_ENABLED_GRIDS(x) x +#else + #define IF_ENABLED_GRIDS(x) +#endif diff --git a/thirdparty/embree/kernels/geometry/cone.h 
b/thirdparty/embree/kernels/geometry/cone.h new file mode 100644 index 000000000000..961ef86160aa --- /dev/null +++ b/thirdparty/embree/kernels/geometry/cone.h @@ -0,0 +1,321 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" + +namespace embree +{ + namespace isa + { + struct Cone + { + const Vec3fa p0; //!< start position of cone + const Vec3fa p1; //!< end position of cone + const float r0; //!< start radius of cone + const float r1; //!< end radius of cone + + __forceinline Cone(const Vec3fa& p0, const float r0, const Vec3fa& p1, const float r1) + : p0(p0), p1(p1), r0(r0), r1(r1) {} + + __forceinline bool intersect(const Vec3fa& org, const Vec3fa& dir, + BBox1f& t_o, + float& u0_o, Vec3fa& Ng0_o, + float& u1_o, Vec3fa& Ng1_o) const + { + /* calculate quadratic equation to solve */ + const Vec3fa v0 = p0-org; + const Vec3fa v1 = p1-org; + + const float rl = rcp_length(v1-v0); + const Vec3fa P0 = v0, dP = (v1-v0)*rl; + const float dr = (r1-r0)*rl; + const Vec3fa O = -P0, dO = dir; + + const float dOdO = dot(dO,dO); + const float OdO = dot(dO,O); + const float OO = dot(O,O); + const float dOz = dot(dP,dO); + const float Oz = dot(dP,O); + + const float R = r0 + Oz*dr; + const float A = dOdO - sqr(dOz) * (1.0f+sqr(dr)); + const float B = 2.0f * (OdO - dOz*(Oz + R*dr)); + const float C = OO - (sqr(Oz) + sqr(R)); + + /* we miss the cone if determinant is smaller than zero */ + const float D = B*B - 4.0f*A*C; + if (D < 0.0f) return false; + + /* special case for rays that are "parallel" to the cone */ + const float eps = float(1<<8)*float(ulp)*max(abs(dOdO),abs(sqr(dOz))); + if (unlikely(abs(A) < eps)) + { + /* cylinder case */ + if (abs(dr) < 16.0f*float(ulp)) { + if (C <= 0.0f) { t_o = BBox1f(neg_inf,pos_inf); return true; } + else { t_o = BBox1f(pos_inf,neg_inf); return false; } + } + + /* cone case */ + else + { + /* if we hit the negative cone there cannot be a hit */ + const float t = -C/B; + const float z0 = Oz+t*dOz; + const float z0r = r0+z0*dr; + if (z0r < 0.0f) return false; + + /* test if we start inside or outside the cone */ + if (dOz*dr > 0.0f) t_o = BBox1f(t,pos_inf); + else t_o = BBox1f(neg_inf,t); + } + } + + /* standard case for "non-parallel" rays */ + else + { + const float Q = sqrt(D); + const float rcp_2A = rcp(2.0f*A); + t_o.lower = (-B-Q)*rcp_2A; + t_o.upper = (-B+Q)*rcp_2A; + + /* standard case where both hits are on same cone */ + if (likely(A > 0.0f)) { + const float z0 = Oz+t_o.lower*dOz; + const float z0r = r0+z0*dr; + if (z0r < 0.0f) return false; + } + + /* special case where the hits are on the positive and negative cone */ + else + { + /* depending on the ray direction and the open direction + * of the cone we have a hit from inside or outside the + * cone */ + if (dOz*dr > 0) t_o.upper = pos_inf; + else t_o.lower = neg_inf; + } + } + + /* calculates u and Ng for near hit */ + { + u0_o = (Oz+t_o.lower*dOz)*rl; + const Vec3fa Pr = t_o.lower*dir; + const Vec3fa Pl = v0 + u0_o*(v1-v0); + const Vec3fa R = normalize(Pr-Pl); + const Vec3fa U = (p1-p0)+(r1-r0)*R; + const Vec3fa V = cross(p1-p0,R); + Ng0_o = cross(V,U); + } + + /* calculates u and Ng for far hit */ + { + u1_o = (Oz+t_o.upper*dOz)*rl; + const Vec3fa Pr = t_o.upper*dir; + const Vec3fa Pl = v0 + u1_o*(v1-v0); + const Vec3fa R = normalize(Pr-Pl); + const Vec3fa U = (p1-p0)+(r1-r0)*R; + const Vec3fa V = cross(p1-p0,R); + Ng1_o = cross(V,U); + } + return true; + } + + __forceinline bool intersect(const Vec3fa& org, const Vec3fa& 
dir, BBox1f& t_o) const + { + float u0_o; Vec3fa Ng0_o; float u1_o; Vec3fa Ng1_o; + return intersect(org,dir,t_o,u0_o,Ng0_o,u1_o,Ng1_o); + } + + static bool verify(const size_t id, const Cone& cone, const Ray& ray, bool shouldhit, const float t0, const float t1) + { + float eps = 0.001f; + BBox1f t; bool hit; + hit = cone.intersect(ray.org,ray.dir,t); + + bool failed = hit != shouldhit; + if (shouldhit) failed |= std::isinf(t0) ? t0 != t.lower : (t0 == -1E6) ? t.lower > -1E6f : abs(t0-t.lower) > eps; + if (shouldhit) failed |= std::isinf(t1) ? t1 != t.upper : (t1 == +1E6) ? t.upper < +1E6f : abs(t1-t.upper) > eps; + if (!failed) return true; + embree_cout << "Cone test " << id << " failed: cone = " << cone << ", ray = " << ray << ", hit = " << hit << ", t = " << t << embree_endl; + return false; + } + + /* verify cone class */ + static bool verify() + { + bool passed = true; + const Cone cone0(Vec3fa(0.0f,0.0f,0.0f),0.0f,Vec3fa(1.0f,0.0f,0.0f),1.0f); + passed &= verify(0,cone0,Ray(Vec3fa(-2.0f,1.0f,0.0f),Vec3fa(+1.0f,+0.0f,+0.0f),0.0f,float(inf)),true,3.0f,pos_inf); + passed &= verify(1,cone0,Ray(Vec3fa(+2.0f,1.0f,0.0f),Vec3fa(-1.0f,+0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,1.0f); + passed &= verify(2,cone0,Ray(Vec3fa(-1.0f,0.0f,2.0f),Vec3fa(+0.0f,+0.0f,-1.0f),0.0f,float(inf)),false,0.0f,0.0f); + passed &= verify(3,cone0,Ray(Vec3fa(+1.0f,0.0f,2.0f),Vec3fa(+0.0f,+0.0f,-1.0f),0.0f,float(inf)),true,1.0f,3.0f); + passed &= verify(4,cone0,Ray(Vec3fa(-1.0f,0.0f,0.0f),Vec3fa(+1.0f,+0.0f,+0.0f),0.0f,float(inf)),true,1.0f,pos_inf); + passed &= verify(5,cone0,Ray(Vec3fa(+1.0f,0.0f,0.0f),Vec3fa(-1.0f,+0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,1.0f); + passed &= verify(6,cone0,Ray(Vec3fa(+0.0f,0.0f,1.0f),Vec3fa(+0.0f,+0.0f,-1.0f),0.0f,float(inf)),true,1.0f,1.0f); + passed &= verify(7,cone0,Ray(Vec3fa(+0.0f,1.0f,0.0f),Vec3fa(-1.0f,-1.0f,+0.0f),0.0f,float(inf)),false,0.0f,0.0f); + passed &= verify(8,cone0,Ray(Vec3fa(+0.0f,1.0f,0.0f),Vec3fa(+1.0f,-1.0f,+0.0f),0.0f,float(inf)),true,0.5f,+1E6); + passed &= verify(9,cone0,Ray(Vec3fa(+0.0f,1.0f,0.0f),Vec3fa(-1.0f,+1.0f,+0.0f),0.0f,float(inf)),true,-1E6,-0.5f); + const Cone cone1(Vec3fa(0.0f,0.0f,0.0f),1.0f,Vec3fa(1.0f,0.0f,0.0f),0.0f); + passed &= verify(10,cone1,Ray(Vec3fa(-2.0f,1.0f,0.0f),Vec3fa(+1.0f,+0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,2.0f); + passed &= verify(11,cone1,Ray(Vec3fa(-1.0f,0.0f,2.0f),Vec3fa(+0.0f,+0.0f,-1.0f),0.0f,float(inf)),true,0.0f,4.0f); + const Cone cylinder(Vec3fa(0.0f,0.0f,0.0f),1.0f,Vec3fa(1.0f,0.0f,0.0f),1.0f); + passed &= verify(12,cylinder,Ray(Vec3fa(-2.0f,1.0f,0.0f),Vec3fa( 0.0f,-1.0f,+0.0f),0.0f,float(inf)),true,0.0f,2.0f); + passed &= verify(13,cylinder,Ray(Vec3fa(+2.0f,1.0f,0.0f),Vec3fa( 0.0f,-1.0f,+0.0f),0.0f,float(inf)),true,0.0f,2.0f); + passed &= verify(14,cylinder,Ray(Vec3fa(+2.0f,1.0f,2.0f),Vec3fa( 0.0f,-1.0f,+0.0f),0.0f,float(inf)),false,0.0f,0.0f); + passed &= verify(15,cylinder,Ray(Vec3fa(+0.0f,0.0f,0.0f),Vec3fa( 1.0f, 0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,pos_inf); + passed &= verify(16,cylinder,Ray(Vec3fa(+0.0f,0.0f,0.0f),Vec3fa(-1.0f, 0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,pos_inf); + passed &= verify(17,cylinder,Ray(Vec3fa(+0.0f,2.0f,0.0f),Vec3fa( 1.0f, 0.0f,+0.0f),0.0f,float(inf)),false,pos_inf,neg_inf); + passed &= verify(18,cylinder,Ray(Vec3fa(+0.0f,2.0f,0.0f),Vec3fa(-1.0f, 0.0f,+0.0f),0.0f,float(inf)),false,pos_inf,neg_inf); + return passed; + } + + /*! 
output operator */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const Cone& c) { + return cout << "Cone { p0 = " << c.p0 << ", r0 = " << c.r0 << ", p1 = " << c.p1 << ", r1 = " << c.r1 << "}"; + } + }; + + template + struct ConeN + { + typedef Vec3> Vec3vfN; + + const Vec3vfN p0; //!< start position of cone + const Vec3vfN p1; //!< end position of cone + const vfloat r0; //!< start radius of cone + const vfloat r1; //!< end radius of cone + + __forceinline ConeN(const Vec3vfN& p0, const vfloat& r0, const Vec3vfN& p1, const vfloat& r1) + : p0(p0), p1(p1), r0(r0), r1(r1) {} + + __forceinline Cone operator[] (const size_t i) const + { + assert(i intersect(const Vec3fa& org, const Vec3fa& dir, + BBox>& t_o, + vfloat& u0_o, Vec3vfN& Ng0_o, + vfloat& u1_o, Vec3vfN& Ng1_o) const + { + /* calculate quadratic equation to solve */ + const Vec3vfN v0 = p0-Vec3vfN(org); + const Vec3vfN v1 = p1-Vec3vfN(org); + + const vfloat rl = rcp_length(v1-v0); + const Vec3vfN P0 = v0, dP = (v1-v0)*rl; + const vfloat dr = (r1-r0)*rl; + const Vec3vfN O = -P0, dO = dir; + + const vfloat dOdO = dot(dO,dO); + const vfloat OdO = dot(dO,O); + const vfloat OO = dot(O,O); + const vfloat dOz = dot(dP,dO); + const vfloat Oz = dot(dP,O); + + const vfloat R = r0 + Oz*dr; + const vfloat A = dOdO - sqr(dOz) * (vfloat(1.0f)+sqr(dr)); + const vfloat B = 2.0f * (OdO - dOz*(Oz + R*dr)); + const vfloat C = OO - (sqr(Oz) + sqr(R)); + + /* we miss the cone if determinant is smaller than zero */ + const vfloat D = B*B - 4.0f*A*C; + vbool valid = D >= 0.0f; + if (none(valid)) return valid; + + /* special case for rays that are "parallel" to the cone */ + const vfloat eps = float(1<<8)*float(ulp)*max(abs(dOdO),abs(sqr(dOz))); + const vbool validt = valid & (abs(A) < eps); + const vbool validf = valid & !(abs(A) < eps); + if (unlikely(any(validt))) + { + const vboolx validtt = validt & (abs(dr) < 16.0f*float(ulp)); + const vboolx validtf = validt & (abs(dr) >= 16.0f*float(ulp)); + + /* cylinder case */ + if (unlikely(any(validtt))) + { + t_o.lower = select(validtt, select(C <= 0.0f, vfloat(neg_inf), vfloat(pos_inf)), t_o.lower); + t_o.upper = select(validtt, select(C <= 0.0f, vfloat(pos_inf), vfloat(neg_inf)), t_o.upper); + valid &= !validtt | C <= 0.0f; + } + + /* cone case */ + if (any(validtf)) + { + /* if we hit the negative cone there cannot be a hit */ + const vfloat t = -C/B; + const vfloat z0 = Oz+t*dOz; + const vfloat z0r = r0+z0*dr; + valid &= !validtf | z0r >= 0.0f; + + /* test if we start inside or outside the cone */ + t_o.lower = select(validtf, select(dOz*dr > 0.0f, t, vfloat(neg_inf)), t_o.lower); + t_o.upper = select(validtf, select(dOz*dr > 0.0f, vfloat(pos_inf), t), t_o.upper); + } + } + + /* standard case for "non-parallel" rays */ + if (likely(any(validf))) + { + const vfloat Q = sqrt(D); + const vfloat rcp_2A = 0.5f*rcp(A); + t_o.lower = select(validf, (-B-Q)*rcp_2A, t_o.lower); + t_o.upper = select(validf, (-B+Q)*rcp_2A, t_o.upper); + + /* standard case where both hits are on same cone */ + const vbool validft = validf & A>0.0f; + const vbool validff = validf & !(A>0.0f); + if (any(validft)) { + const vfloat z0 = Oz+t_o.lower*dOz; + const vfloat z0r = r0+z0*dr; + valid &= !validft | z0r >= 0.0f; + } + + /* special case where the hits are on the positive and negative cone */ + if (any(validff)) { + /* depending on the ray direction and the open direction + * of the cone we have a hit from inside or outside the + * cone */ + t_o.lower = select(validff, select(dOz*dr > 0.0f, t_o.lower, 
float(neg_inf)), t_o.lower); + t_o.upper = select(validff, select(dOz*dr > 0.0f, float(pos_inf), t_o.upper), t_o.upper); + } + } + + /* calculates u and Ng for near hit */ + { + u0_o = (Oz+t_o.lower*dOz)*rl; + const Vec3vfN Pr = t_o.lower*Vec3vfN(dir); + const Vec3vfN Pl = v0 + u0_o*(v1-v0); + const Vec3vfN R = normalize(Pr-Pl); + const Vec3vfN U = (p1-p0)+(r1-r0)*R; + const Vec3vfN V = cross(p1-p0,R); + Ng0_o = cross(V,U); + } + + /* calculates u and Ng for far hit */ + { + u1_o = (Oz+t_o.upper*dOz)*rl; + const Vec3vfN Pr = t_o.lower*Vec3vfN(dir); + const Vec3vfN Pl = v0 + u1_o*(v1-v0); + const Vec3vfN R = normalize(Pr-Pl); + const Vec3vfN U = (p1-p0)+(r1-r0)*R; + const Vec3vfN V = cross(p1-p0,R); + Ng1_o = cross(V,U); + } + return valid; + } + + __forceinline vbool intersect(const Vec3fa& org, const Vec3fa& dir, BBox>& t_o) const + { + vfloat u0_o; Vec3vfN Ng0_o; vfloat u1_o; Vec3vfN Ng1_o; + return intersect(org,dir,t_o,u0_o,Ng0_o,u1_o,Ng1_o); + } + }; + } +} + diff --git a/thirdparty/embree/kernels/geometry/coneline_intersector.h b/thirdparty/embree/kernels/geometry/coneline_intersector.h new file mode 100644 index 000000000000..0902baff7db0 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/coneline_intersector.h @@ -0,0 +1,209 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + namespace isa + { + namespace __coneline_internal + { + template + static __forceinline bool intersectCone(const vbool& valid_i, + const Vec3vf& ray_org_in, const Vec3vf& ray_dir, + const vfloat& ray_tnear, const ray_tfar_func& ray_tfar, + const Vec4vf& v0, const Vec4vf& v1, + const vbool& cL, const vbool& cR, + const Epilog& epilog) + { + vbool valid = valid_i; + + /* move ray origin closer to make calculations numerically stable */ + const vfloat dOdO = sqr(ray_dir); + const vfloat rcp_dOdO = rcp(dOdO); + const Vec3vf center = vfloat(0.5f)*(v0.xyz()+v1.xyz()); + const vfloat dt = dot(center-ray_org_in,ray_dir)*rcp_dOdO; + const Vec3vf ray_org = ray_org_in + dt*ray_dir; + + const Vec3vf dP = v1.xyz() - v0.xyz(); + const Vec3vf p0 = ray_org - v0.xyz(); + const Vec3vf p1 = ray_org - v1.xyz(); + + const vfloat dPdP = sqr(dP); + const vfloat dP0 = dot(p0,dP); + const vfloat dP1 = dot(p1,dP); + const vfloat dOdP = dot(ray_dir,dP); + + // intersect cone body + const vfloat dr = v0.w - v1.w; + const vfloat hy = dPdP + sqr(dr); + const vfloat dO0 = dot(ray_dir,p0); + const vfloat OO = sqr(p0); + const vfloat dPdP2 = sqr(dPdP); + const vfloat dPdPr0 = dPdP*v0.w; + + const vfloat A = dPdP2 - sqr(dOdP)*hy; + const vfloat B = dPdP2*dO0 - dP0*dOdP*hy + dPdPr0*(dr*dOdP); + const vfloat C = dPdP2*OO - sqr(dP0)*hy + dPdPr0*(2.0f*dr*dP0 - dPdPr0); + + const vfloat D = B*B - A*C; + valid &= D >= 0.0f; + if (unlikely(none(valid))) { + return false; + } + + /* standard case for "non-parallel" rays */ + const vfloat Q = sqrt(D); + const vfloat rcp_A = rcp(A); + /* special case for rays that are "parallel" to the cone - assume miss */ + const vbool isParallel = abs(A) <= min_rcp_input; + + vfloat t_cone_lower = select (isParallel, neg_inf, (-B-Q)*rcp_A); + vfloat t_cone_upper = select (isParallel, pos_inf, (-B+Q)*rcp_A); + const vfloat y_lower = dP0 + t_cone_lower*dOdP; + const vfloat y_upper = dP0 + t_cone_upper*dOdP; + t_cone_lower = select(valid & y_lower > 0.0f & y_lower < dPdP, t_cone_lower, pos_inf); + t_cone_upper = select(valid & y_upper > 0.0f & y_upper < dPdP, 
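Editor's note: the intersectCone helper that starts here first advances the ray origin to the point on the ray nearest the segment midpoint (the dt term) and intersects in that local frame, which keeps the quadratic coefficients small; dt is then added back to every reported hit distance. A minimal sketch of that re-centering step, using the editor's own helper types (not embree code):

// Editor's illustration of the origin re-centering trick (not embree code).
struct V3 { float x, y, z; };
static float dot(V3 a, V3 b) { return a.x*b.x + a.y*b.y + a.z*b.z; }
static V3 add(V3 a, V3 b) { return {a.x + b.x, a.y + b.y, a.z + b.z}; }
static V3 sub(V3 a, V3 b) { return {a.x - b.x, a.y - b.y, a.z - b.z}; }
static V3 mul(V3 a, float s) { return {a.x*s, a.y*s, a.z*s}; }

float recenter(V3& org, V3 dir, V3 seg_mid)
{
  // Parameter of the point on the ray nearest to the segment midpoint.
  float dt = dot(sub(seg_mid, org), dir) / dot(dir, dir);
  org = add(org, mul(dir, dt));   // new origin lies close to the primitive
  return dt;                      // caller adds dt back to every reported t
}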
t_cone_upper, neg_inf); + + const vbool hitDisk0 = valid & cL; + const vbool hitDisk1 = valid & cR; + const vfloat rcp_dOdP = rcp(dOdP); + const vfloat t_disk0 = select (hitDisk0, select (sqr(p0*dOdP-ray_dir*dP0)<(sqr(v0.w)*sqr(dOdP)), -dP0*rcp_dOdP, pos_inf), pos_inf); + const vfloat t_disk1 = select (hitDisk1, select (sqr(p1*dOdP-ray_dir*dP1)<(sqr(v1.w)*sqr(dOdP)), -dP1*rcp_dOdP, pos_inf), pos_inf); + const vfloat t_disk_lower = min(t_disk0, t_disk1); + const vfloat t_disk_upper = max(t_disk0, t_disk1); + + const vfloat t_lower = min(t_cone_lower, t_disk_lower); + const vfloat t_upper = max(t_cone_upper, select(t_lower==t_disk_lower, + select(t_disk_upper==vfloat(pos_inf),neg_inf,t_disk_upper), + select(t_disk_lower==vfloat(pos_inf),neg_inf,t_disk_lower))); + + const vbool valid_lower = valid & ray_tnear <= dt+t_lower & dt+t_lower <= ray_tfar() & t_lower != vfloat(pos_inf); + const vbool valid_upper = valid & ray_tnear <= dt+t_upper & dt+t_upper <= ray_tfar() & t_upper != vfloat(neg_inf); + + const vbool valid_first = valid_lower | valid_upper; + if (unlikely(none(valid_first))) + return false; + + const vfloat t_first = select(valid_lower, t_lower, t_upper); + const vfloat y_first = select(valid_lower, y_lower, y_upper); + + const vfloat rcp_dPdP = rcp(dPdP); + const Vec3vf dP2drr0dP = dPdP*dr*v0.w*dP; + const Vec3vf dPhy = dP*hy; + const vbool cone_hit_first = valid & (t_first == t_cone_lower | t_first == t_cone_upper); + const vbool disk0_hit_first = valid & (t_first == t_disk0); + const Vec3vf Ng_first = select(cone_hit_first, dPdP2*(p0+t_first*ray_dir)+dP2drr0dP-dPhy*y_first, select(disk0_hit_first, -dP, dP)); + const vfloat u_first = select(cone_hit_first, y_first*rcp_dPdP, select(disk0_hit_first, vfloat(zero), vfloat(one))); + + /* invoke intersection filter for first hit */ + RoundLineIntersectorHitM hit(u_first,zero,dt+t_first,Ng_first); + const bool is_hit_first = epilog(valid_first, hit); + + /* check for possible second hits before potentially accepted hit */ + const vfloat t_second = t_upper; + const vfloat y_second = y_upper; + const vbool valid_second = valid_lower & valid_upper & (dt+t_upper <= ray_tfar()); + if (unlikely(none(valid_second))) + return is_hit_first; + + /* invoke intersection filter for second hit */ + const vbool cone_hit_second = t_second == t_cone_lower | t_second == t_cone_upper; + const vbool disk0_hit_second = t_second == t_disk0; + const Vec3vf Ng_second = select(cone_hit_second, dPdP2*(p0+t_second*ray_dir)+dP2drr0dP-dPhy*y_second, select(disk0_hit_second, -dP, dP)); + const vfloat u_second = select(cone_hit_second, y_second*rcp_dPdP, select(disk0_hit_first, vfloat(zero), vfloat(one))); + + hit = RoundLineIntersectorHitM(u_second,zero,dt+t_second,Ng_second); + const bool is_hit_second = epilog(valid_second, hit); + + return is_hit_first | is_hit_second; + } + } + + template + struct ConeLineIntersectorHitM + { + __forceinline ConeLineIntersectorHitM() {} + + __forceinline ConeLineIntersectorHitM(const vfloat& u, const vfloat& v, const vfloat& t, const Vec3vf& Ng) + : vu(u), vv(v), vt(t), vNg(Ng) {} + + __forceinline void finalize() {} + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + public: + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + struct ConeCurveIntersector1 + { + typedef CurvePrecalculations1 Precalculations; + + struct 
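Editor's note: besides the cone body, the interval logic above folds in the optional end-cap disks (the cL/cR flags and the t_disk0/t_disk1 terms). The cap test itself is an ordinary ray/disk intersection; a standalone sketch with names invented by the editor (not embree code):

// Editor's sketch of a ray/disk (end-cap) test: intersect the cap plane,
// then accept the hit only if it lies within the cap radius.
bool disk_hit(const float org[3], const float dir[3],
              const float center[3], const float normal[3], float radius,
              float& t)
{
  auto dot3 = [](const float a[3], const float b[3]) {
    return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
  };
  float denom = dot3(dir, normal);
  if (denom == 0.0f) return false;          // ray parallel to the cap plane
  float co[3] = { center[0] - org[0], center[1] - org[1], center[2] - org[2] };
  t = dot3(co, normal) / denom;
  float d[3] = { org[0] + t*dir[0] - center[0],
                 org[1] + t*dir[1] - center[1],
                 org[2] + t*dir[2] - center[2] };
  return dot3(d, d) <= radius * radius;     // hit point inside the cap radius
}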
ray_tfar { + Ray& ray; + __forceinline ray_tfar(Ray& ray) : ray(ray) {} + __forceinline vfloat operator() () const { return ray.tfar; }; + }; + + template + static __forceinline bool intersect(const vbool& valid_i, + Ray& ray, + IntersectContext* context, + const LineSegments* geom, + const Precalculations& pre, + const Vec4vf& v0i, const Vec4vf& v1i, + const vbool& cL, const vbool& cR, + const Epilog& epilog) + { + const Vec3vf ray_org(ray.org.x, ray.org.y, ray.org.z); + const Vec3vf ray_dir(ray.dir.x, ray.dir.y, ray.dir.z); + const vfloat ray_tnear(ray.tnear()); + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec4vf v1 = enlargeRadiusToMinWidth(context,geom,ray_org,v1i); + return __coneline_internal::intersectCone(valid_i,ray_org,ray_dir,ray_tnear,ray_tfar(ray),v0,v1,cL,cR,epilog); + } + }; + + template + struct ConeCurveIntersectorK + { + typedef CurvePrecalculationsK Precalculations; + + struct ray_tfar { + RayK& ray; + size_t k; + __forceinline ray_tfar(RayK& ray, size_t k) : ray(ray), k(k) {} + __forceinline vfloat operator() () const { return ray.tfar[k]; }; + }; + + template + static __forceinline bool intersect(const vbool& valid_i, + RayK& ray, size_t k, + IntersectContext* context, + const LineSegments* geom, + const Precalculations& pre, + const Vec4vf& v0i, const Vec4vf& v1i, + const vbool& cL, const vbool& cR, + const Epilog& epilog) + { + const Vec3vf ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const Vec3vf ray_dir(ray.dir.x[k], ray.dir.y[k], ray.dir.z[k]); + const vfloat ray_tnear = ray.tnear()[k]; + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec4vf v1 = enlargeRadiusToMinWidth(context,geom,ray_org,v1i); + return __coneline_internal::intersectCone(valid_i,ray_org,ray_dir,ray_tnear,ray_tfar(ray,k),v0,v1,cL,cR,epilog); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/conelinei_intersector.h b/thirdparty/embree/kernels/geometry/conelinei_intersector.h new file mode 100644 index 000000000000..d47218eb8b78 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/conelinei_intersector.h @@ -0,0 +1,141 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "coneline_intersector.h" +#include "intersector_epilog.h" + +namespace embree +{ + namespace isa + { + template + struct ConeCurveMiIntersector1 + { + typedef LineMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom); + const vbool valid = line.template valid(); + ConeCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,cL,cR,Intersect1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom); + const vbool valid = line.template valid(); + return ConeCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,cL,cR,Occluded1EpilogM(ray,context,line.geomID(),line.primID())); + return false; + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const 
Primitive& line) + { + return PrimitivePointQuery1::pointQuery(query, context, line); + } + }; + + template + struct ConeCurveMiMBIntersector1 + { + typedef LineMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom,ray.time()); + const vbool valid = line.template valid(); + ConeCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,cL,cR,Intersect1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom,ray.time()); + const vbool valid = line.template valid(); + return ConeCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,cL,cR,Occluded1EpilogM(ray,context,line.geomID(),line.primID())); + return false; + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& line) + { + return PrimitivePointQuery1::pointQuery(query, context, line); + } + }; + + template + struct ConeCurveMiIntersectorK + { + typedef LineMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom); + const vbool valid = line.template valid(); + ConeCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,cL,cR,Intersect1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom); + const vbool valid = line.template valid(); + return ConeCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,cL,cR,Occluded1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + }; + + template + struct ConeCurveMiMBIntersectorK + { + typedef LineMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom,ray.time()[k]); + const vbool valid = line.template valid(); + ConeCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,cL,cR,Intersect1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; + vbool cL,cR; + line.gather(v0,v1,cL,cR,geom,ray.time()[k]); + const vbool valid = line.template valid(); + 
return ConeCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,cL,cR,Occluded1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curveNi.h b/thirdparty/embree/kernels/geometry/curveNi.h new file mode 100644 index 000000000000..00ca9b8b6503 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curveNi.h @@ -0,0 +1,222 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + template + struct CurveNi + { + struct Type : public PrimitiveType { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* Returns maximum number of stored primitives */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+M-1)/M; } + + static __forceinline size_t bytes(size_t N) + { + const size_t f = N/M, r = N%M; + static_assert(sizeof(CurveNi) == 22+25*M, "internal data layout issue"); + return f*sizeof(CurveNi) + (r!=0)*(22 + 25*r); + } + + public: + + /*! Default constructor. */ + __forceinline CurveNi () {} + + /*! fill curve from curve list */ + __forceinline void fill(const PrimRef* prims, size_t& begin, size_t _end, Scene* scene) + { + size_t end = min(begin+M,_end); + N = (unsigned char)(end-begin); + const unsigned int geomID0 = prims[begin].geomID(); + this->geomID(N) = geomID0; + ty = (unsigned char) scene->get(geomID0)->getType(); + + /* encode all primitives */ + BBox3fa bounds = empty; + for (size_t i=0; iget(geomID)->vbounds(primID)); + } + + /* calculate offset and scale */ + Vec3fa loffset = bounds.lower; + float lscale = reduce_min(256.0f/(bounds.size()*sqrt(3.0f))); + if (bounds.size() == Vec3fa(zero)) lscale = 0.0f; + *this->offset(N) = loffset; + *this->scale(N) = lscale; + + /* encode all primitives */ + for (size_t i=0; iget(geomID)->computeAlignedSpace(primID); + + const LinearSpace3fa space3(trunc(126.0f*space2.vx),trunc(126.0f*space2.vy),trunc(126.0f*space2.vz)); + const BBox3fa bounds = scene->get(geomID)->vbounds(loffset,lscale,max(length(space3.vx),length(space3.vy),length(space3.vz)),space3.transposed(),primID); + + bounds_vx_x(N)[i] = (char) space3.vx.x; + bounds_vx_y(N)[i] = (char) space3.vx.y; + bounds_vx_z(N)[i] = (char) space3.vx.z; + bounds_vx_lower(N)[i] = (short) clamp(floor(bounds.lower.x),-32767.0f,32767.0f); + bounds_vx_upper(N)[i] = (short) clamp(ceil (bounds.upper.x),-32767.0f,32767.0f); + assert(-32767.0f <= floor(bounds.lower.x) && floor(bounds.lower.x) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.upper.x) && ceil (bounds.upper.x) <= 32767.0f); + + bounds_vy_x(N)[i] = (char) space3.vy.x; + bounds_vy_y(N)[i] = (char) space3.vy.y; + bounds_vy_z(N)[i] = (char) space3.vy.z; + bounds_vy_lower(N)[i] = (short) clamp(floor(bounds.lower.y),-32767.0f,32767.0f); + bounds_vy_upper(N)[i] = (short) clamp(ceil (bounds.upper.y),-32767.0f,32767.0f); + assert(-32767.0f <= floor(bounds.lower.y) && floor(bounds.lower.y) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.upper.y) && ceil (bounds.upper.y) <= 32767.0f); + + bounds_vz_x(N)[i] = (char) space3.vz.x; + bounds_vz_y(N)[i] = (char) space3.vz.y; + bounds_vz_z(N)[i] = (char) space3.vz.z; + bounds_vz_lower(N)[i] = (short) 
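Editor's note: the CurveNi::fill encoding above stores, for up to M curves, a shared offset/scale plus per-primitive oriented bounds quantized to 8-bit axis vectors and 16-bit extents. The quantization idea, reduced to one dimension in a standalone sketch (editor's names, not embree code): subtract the shared offset, scale into the 16-bit range, and round outward so the stored box never shrinks.

#include <algorithm>
#include <cmath>
#include <cstdint>

struct Box1      { float lo, hi; };
struct QuantBox1 { int16_t lo, hi; };

// Quantize conservatively: floor the lower bound, ceil the upper bound,
// clamp to the representable 16-bit range.
QuantBox1 quantize(Box1 b, float offset, float scale)
{
  float lo = std::floor((b.lo - offset) * scale);
  float hi = std::ceil ((b.hi - offset) * scale);
  lo = std::max(-32767.0f, std::min(32767.0f, lo));
  hi = std::max(-32767.0f, std::min(32767.0f, hi));
  return { (int16_t)lo, (int16_t)hi };
}

// Traversal works in the scaled space (the ray is transformed instead of the
// box); the inverse mapping is shown only for clarity.
Box1 dequantize(QuantBox1 q, float offset, float scale)
{
  return { offset + q.lo / scale, offset + q.hi / scale };
}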
clamp(floor(bounds.lower.z),-32767.0f,32767.0f); + bounds_vz_upper(N)[i] = (short) clamp(ceil (bounds.upper.z),-32767.0f,32767.0f); + assert(-32767.0f <= floor(bounds.lower.z) && floor(bounds.lower.z) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.upper.z) && ceil (bounds.upper.z) <= 32767.0f); + + this->primID(N)[i] = primID; + } + } + + template + __forceinline static typename BVH::NodeRef createLeaf (BVH* bvh, const PrimRef* prims, const range& set, const Allocator& alloc) + { + size_t start = set.begin(); + size_t items = CurveNi::blocks(set.size()); + size_t numbytes = CurveNi::bytes(set.size()); + CurveNi* accel = (CurveNi*) alloc.malloc1(numbytes,BVH::byteAlignment); + for (size_t i=0; iscene); + } + return bvh->encodeLeaf((char*)accel,items); + }; + + public: + + // 27.6 - 46 bytes per primitive + unsigned char ty; + unsigned char N; + unsigned char data[4+25*M+16]; + + /* + struct Layout + { + unsigned int geomID; + unsigned int primID[N]; + + char bounds_vx_x[N]; + char bounds_vx_y[N]; + char bounds_vx_z[N]; + short bounds_vx_lower[N]; + short bounds_vx_upper[N]; + + char bounds_vy_x[N]; + char bounds_vy_y[N]; + char bounds_vy_z[N]; + short bounds_vy_lower[N]; + short bounds_vy_upper[N]; + + char bounds_vz_x[N]; + char bounds_vz_y[N]; + char bounds_vz_z[N]; + short bounds_vz_lower[N]; + short bounds_vz_upper[N]; + + Vec3f offset; + float scale; + }; + */ + + __forceinline unsigned int& geomID(size_t N) { return *(unsigned int*)((char*)this+2); } + __forceinline const unsigned int& geomID(size_t N) const { return *(unsigned int*)((char*)this+2); } + + __forceinline unsigned int* primID(size_t N) { return (unsigned int*)((char*)this+6); } + __forceinline const unsigned int* primID(size_t N) const { return (unsigned int*)((char*)this+6); } + + __forceinline char* bounds_vx_x(size_t N) { return (char*)((char*)this+6+4*N); } + __forceinline const char* bounds_vx_x(size_t N) const { return (char*)((char*)this+6+4*N); } + + __forceinline char* bounds_vx_y(size_t N) { return (char*)((char*)this+6+5*N); } + __forceinline const char* bounds_vx_y(size_t N) const { return (char*)((char*)this+6+5*N); } + + __forceinline char* bounds_vx_z(size_t N) { return (char*)((char*)this+6+6*N); } + __forceinline const char* bounds_vx_z(size_t N) const { return (char*)((char*)this+6+6*N); } + + __forceinline short* bounds_vx_lower(size_t N) { return (short*)((char*)this+6+7*N); } + __forceinline const short* bounds_vx_lower(size_t N) const { return (short*)((char*)this+6+7*N); } + + __forceinline short* bounds_vx_upper(size_t N) { return (short*)((char*)this+6+9*N); } + __forceinline const short* bounds_vx_upper(size_t N) const { return (short*)((char*)this+6+9*N); } + + __forceinline char* bounds_vy_x(size_t N) { return (char*)((char*)this+6+11*N); } + __forceinline const char* bounds_vy_x(size_t N) const { return (char*)((char*)this+6+11*N); } + + __forceinline char* bounds_vy_y(size_t N) { return (char*)((char*)this+6+12*N); } + __forceinline const char* bounds_vy_y(size_t N) const { return (char*)((char*)this+6+12*N); } + + __forceinline char* bounds_vy_z(size_t N) { return (char*)((char*)this+6+13*N); } + __forceinline const char* bounds_vy_z(size_t N) const { return (char*)((char*)this+6+13*N); } + + __forceinline short* bounds_vy_lower(size_t N) { return (short*)((char*)this+6+14*N); } + __forceinline const short* bounds_vy_lower(size_t N) const { return (short*)((char*)this+6+14*N); } + + __forceinline short* bounds_vy_upper(size_t N) { return (short*)((char*)this+6+16*N); } + __forceinline 
const short* bounds_vy_upper(size_t N) const { return (short*)((char*)this+6+16*N); } + + __forceinline char* bounds_vz_x(size_t N) { return (char*)((char*)this+6+18*N); } + __forceinline const char* bounds_vz_x(size_t N) const { return (char*)((char*)this+6+18*N); } + + __forceinline char* bounds_vz_y(size_t N) { return (char*)((char*)this+6+19*N); } + __forceinline const char* bounds_vz_y(size_t N) const { return (char*)((char*)this+6+19*N); } + + __forceinline char* bounds_vz_z(size_t N) { return (char*)((char*)this+6+20*N); } + __forceinline const char* bounds_vz_z(size_t N) const { return (char*)((char*)this+6+20*N); } + + __forceinline short* bounds_vz_lower(size_t N) { return (short*)((char*)this+6+21*N); } + __forceinline const short* bounds_vz_lower(size_t N) const { return (short*)((char*)this+6+21*N); } + + __forceinline short* bounds_vz_upper(size_t N) { return (short*)((char*)this+6+23*N); } + __forceinline const short* bounds_vz_upper(size_t N) const { return (short*)((char*)this+6+23*N); } + + __forceinline Vec3f* offset(size_t N) { return (Vec3f*)((char*)this+6+25*N); } + __forceinline const Vec3f* offset(size_t N) const { return (Vec3f*)((char*)this+6+25*N); } + + __forceinline float* scale(size_t N) { return (float*)((char*)this+6+25*N+12); } + __forceinline const float* scale(size_t N) const { return (float*)((char*)this+6+25*N+12); } + + __forceinline char* end(size_t N) { return (char*)this+6+25*N+16; } + __forceinline const char* end(size_t N) const { return (char*)this+6+25*N+16; } + }; + + template + typename CurveNi::Type CurveNi::type; + + typedef CurveNi<4> Curve4i; + typedef CurveNi<8> Curve8i; +} diff --git a/thirdparty/embree/kernels/geometry/curveNi_intersector.h b/thirdparty/embree/kernels/geometry/curveNi_intersector.h new file mode 100644 index 000000000000..0f9038c9fc96 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curveNi_intersector.h @@ -0,0 +1,569 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "curveNi.h" + +namespace embree +{ + namespace isa + { + template + struct CurveNiIntersector1 + { + typedef CurveNi Primitive; + typedef Vec3vf Vec3vfM; + typedef LinearSpace3LinearSpace3vfM; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline vbool intersect(Ray& ray, const Primitive& prim, vfloat& tNear_o) + { + const size_t N = prim.N; + const vfloat4 offset_scale = vfloat4::loadu(prim.offset(N)); + const Vec3fa offset = Vec3fa(offset_scale); + const Vec3fa scale = Vec3fa(shuffle<3,3,3,3>(offset_scale)); + const Vec3fa org1 = (ray.org-offset)*scale; + const Vec3fa dir1 = ray.dir*scale; + + const LinearSpace3vfM space(vfloat::load(prim.bounds_vx_x(N)), vfloat::load(prim.bounds_vx_y(N)), vfloat::load(prim.bounds_vx_z(N)), + vfloat::load(prim.bounds_vy_x(N)), vfloat::load(prim.bounds_vy_y(N)), vfloat::load(prim.bounds_vy_z(N)), + vfloat::load(prim.bounds_vz_x(N)), vfloat::load(prim.bounds_vz_y(N)), vfloat::load(prim.bounds_vz_z(N))); + + const Vec3vfM dir2 = xfmVector(space,Vec3vfM(dir1)); + const Vec3vfM org2 = xfmPoint (space,Vec3vfM(org1)); + const Vec3vfM rcp_dir2 = rcp_safe(dir2); + + const vfloat t_lower_x = (vfloat::load(prim.bounds_vx_lower(N))-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_upper_x = (vfloat::load(prim.bounds_vx_upper(N))-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_lower_y = (vfloat::load(prim.bounds_vy_lower(N))-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_upper_y = 
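Editor's note: the per-leaf culling in CurveNiIntersector1::intersect transforms the ray into the quantized frame (offset, scale, 8-bit axes) and then performs a slab test against the 16-bit bounds, widening the resulting interval by a few ulps (the round_down/round_up factors) so rounding error cannot cull a true hit. A scalar sketch of that conservative slab test, as the editor's illustration (not embree code):

#include <algorithm>
#include <cfloat>

bool slab_hit(const float org[3], const float rcp_dir[3],
              const float lower[3], const float upper[3],
              float t_min, float t_max)
{
  const float round_down = 1.0f - 3.0f * FLT_EPSILON;  // widen the interval slightly
  const float round_up   = 1.0f + 3.0f * FLT_EPSILON;
  float t_near = t_min, t_far = t_max;
  for (int a = 0; a < 3; ++a) {
    float t0 = (lower[a] - org[a]) * rcp_dir[a];
    float t1 = (upper[a] - org[a]) * rcp_dir[a];
    t_near = std::max(t_near, std::min(t0, t1));
    t_far  = std::min(t_far,  std::max(t0, t1));
  }
  return round_down * t_near <= round_up * t_far;      // non-empty interval: possible hit
}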
(vfloat::load(prim.bounds_vy_upper(N))-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_lower_z = (vfloat::load(prim.bounds_vz_lower(N))-vfloat(org2.z))*vfloat(rcp_dir2.z); + const vfloat t_upper_z = (vfloat::load(prim.bounds_vz_upper(N))-vfloat(org2.z))*vfloat(rcp_dir2.z); + + const vfloat round_up (1.0f+3.0f*float(ulp)); + const vfloat round_down(1.0f-3.0f*float(ulp)); + const vfloat tNear = round_down*max(mini(t_lower_x,t_upper_x),mini(t_lower_y,t_upper_y),mini(t_lower_z,t_upper_z),vfloat(ray.tnear())); + const vfloat tFar = round_up *min(maxi(t_lower_x,t_upper_x),maxi(t_lower_y,t_upper_y),maxi(t_lower_z,t_upper_z),vfloat(ray.tfar)); + tNear_o = tNear; + return (vint(step) < vint(prim.N)) & (tNear <= tFar); + } + + template + static __forceinline void intersect_t(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID)); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_t(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID)); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + if (Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + + template + static __forceinline void intersect_n(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + + unsigned int vertexID = geom->curve(primID); + Vec3ff a0,a1,a2,a3; Vec3fa n0,n1,n2,n3; geom->gather(a0,a1,a2,a3,n0,n1,n2,n3,vertexID); + + size_t 
mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,n0,n1,n2,n3,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_n(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + + unsigned int vertexID = geom->curve(primID); + Vec3ff a0,a1,a2,a3; Vec3fa n0,n1,n2,n3; geom->gather(a0,a1,a2,a3,n0,n1,n2,n3,vertexID); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + if (Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,n0,n1,n2,n3,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + + template + static __forceinline void intersect_h(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID)); + Intersector().intersect(pre,ray,context,geom,primID,p0,t0,p1,t1,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_h(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID)); + if (Intersector().intersect(pre,ray,context,geom,primID,p0,t0,p1,t1,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + + template + static __forceinline void intersect_hn(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const 
unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; Vec3fa n0,dn0,n1,dn1; geom->gather_hermite(p0,t0,n0,dn0,p1,t1,n1,dn1,geom->curve(primID)); + Intersector().intersect(pre,ray,context,geom,primID,p0,t0,p1,t1,n0,dn0,n1,dn1,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_hn(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; Vec3fa n0,dn0,n1,dn1; geom->gather_hermite(p0,t0,n0,dn0,p1,t1,n1,dn1,geom->curve(primID)); + if (Intersector().intersect(pre,ray,context,geom,primID,p0,t0,p1,t1,n0,dn0,n1,dn1,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + }; + + template + struct CurveNiIntersectorK + { + typedef CurveNi Primitive; + typedef Vec3vf Vec3vfM; + typedef LinearSpace3LinearSpace3vfM; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline vbool intersect(RayK& ray, const size_t k, const Primitive& prim, vfloat& tNear_o) + { + const size_t N = prim.N; + const vfloat4 offset_scale = vfloat4::loadu(prim.offset(N)); + const Vec3fa offset = Vec3fa(offset_scale); + const Vec3fa scale = Vec3fa(shuffle<3,3,3,3>(offset_scale)); + + const Vec3fa ray_org(ray.org.x[k],ray.org.y[k],ray.org.z[k]); + const Vec3fa ray_dir(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]); + const Vec3fa org1 = (ray_org-offset)*scale; + const Vec3fa dir1 = ray_dir*scale; + + const LinearSpace3vfM space(vfloat::load(prim.bounds_vx_x(N)), vfloat::load(prim.bounds_vx_y(N)), vfloat::load(prim.bounds_vx_z(N)), + vfloat::load(prim.bounds_vy_x(N)), vfloat::load(prim.bounds_vy_y(N)), vfloat::load(prim.bounds_vy_z(N)), + vfloat::load(prim.bounds_vz_x(N)), vfloat::load(prim.bounds_vz_y(N)), vfloat::load(prim.bounds_vz_z(N))); + + const Vec3vfM dir2 = xfmVector(space,Vec3vfM(dir1)); + const Vec3vfM org2 = xfmPoint (space,Vec3vfM(org1)); + const Vec3vfM rcp_dir2 = rcp_safe(dir2); + + const vfloat t_lower_x = (vfloat::load(prim.bounds_vx_lower(N))-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_upper_x = (vfloat::load(prim.bounds_vx_upper(N))-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_lower_y = (vfloat::load(prim.bounds_vy_lower(N))-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_upper_y = (vfloat::load(prim.bounds_vy_upper(N))-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_lower_z = (vfloat::load(prim.bounds_vz_lower(N))-vfloat(org2.z))*vfloat(rcp_dir2.z); + const vfloat t_upper_z = (vfloat::load(prim.bounds_vz_upper(N))-vfloat(org2.z))*vfloat(rcp_dir2.z); + + const vfloat round_up (1.0f+3.0f*float(ulp)); + const vfloat round_down(1.0f-3.0f*float(ulp)); + const vfloat tNear = round_down*max(mini(t_lower_x,t_upper_x),mini(t_lower_y,t_upper_y),mini(t_lower_z,t_upper_z),vfloat(ray.tnear()[k])); + const vfloat tFar = round_up *min(maxi(t_lower_x,t_upper_x),maxi(t_lower_y,t_upper_y),maxi(t_lower_z,t_upper_z),vfloat(ray.tfar[k])); + tNear_o = tNear; + return (vint(step) < vint(prim.N)) & (tNear <= tFar); + } + + 
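Editor's note: each intersect_*/occluded_* variant here walks the SIMD hit mask with a bit-scan-and-clear loop and prefetches the vertex data of the next surviving candidates while the current curve is intersected. A compact sketch of that pattern (editor's illustration; __builtin_ctz and __builtin_prefetch are GCC/Clang builtins used only for this example):

#include <cstdint>

// Return the index of the lowest set bit and clear it.
inline unsigned bscf(uint32_t& mask)
{
  unsigned i = (unsigned)__builtin_ctz(mask);
  mask &= mask - 1;
  return i;
}

template <typename Process>
void for_each_hit(uint32_t mask, const void* const* vertex_ptrs, Process process)
{
  while (mask) {
    unsigned i = bscf(mask);
    if (mask)                                  // hide latency: prefetch the next candidate
      __builtin_prefetch(vertex_ptrs[__builtin_ctz(mask)]);
    process(i);                                // intersect the current candidate
  }
}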
template + static __forceinline void intersect_t(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID)); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_t(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID)); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + if (Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + + template + static __forceinline void intersect_n(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + + unsigned int vertexID = geom->curve(primID); + Vec3ff a0,a1,a2,a3; Vec3fa n0,n1,n2,n3; geom->gather(a0,a1,a2,a3,n0,n1,n2,n3,vertexID); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,n0,n1,n2,n3,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_n(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid 
= intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + + unsigned int vertexID = geom->curve(primID); + Vec3ff a0,a1,a2,a3; Vec3fa n0,n1,n2,n3; geom->gather(a0,a1,a2,a3,n0,n1,n2,n3,vertexID); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + const unsigned int primID1 = prim.primID(N)[i1]; + geom->prefetchL1_vertices(geom->curve(primID1)); + if (mask1) { + const size_t i2 = bsf(mask1); + const unsigned int primID2 = prim.primID(N)[i2]; + geom->prefetchL2_vertices(geom->curve(primID2)); + } + } + + if (Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,n0,n1,n2,n3,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + + template + static __forceinline void intersect_h(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID)); + Intersector().intersect(pre,ray,k,context,geom,primID,p0,t0,p1,t1,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_h(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID)); + if (Intersector().intersect(pre,ray,k,context,geom,primID,p0,t0,p1,t1,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + + template + static __forceinline void intersect_hn(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; Vec3fa n0,dn0,n1,dn1; geom->gather_hermite(p0,t0,n0,dn0,p1,t1,n1,dn1,geom->curve(primID)); + Intersector().intersect(pre,ray,k,context,geom,primID,p0,t0,p1,t1,n0,dn0,n1,dn1,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_hn(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; 
+ vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; Vec3fa n0,dn0,n1,dn1; geom->gather_hermite(p0,t0,n0,dn0,p1,t1,n1,dn1,geom->curve(primID)); + if (Intersector().intersect(pre,ray,k,context,geom,primID,p0,t0,p1,t1,n0,dn0,n1,dn1,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curveNi_mb.h b/thirdparty/embree/kernels/geometry/curveNi_mb.h new file mode 100644 index 000000000000..d2e1926220f2 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curveNi_mb.h @@ -0,0 +1,278 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + template + struct CurveNiMB + { + struct Type : public PrimitiveType { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* Returns maximum number of stored primitives */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+M-1)/M; } + + static __forceinline size_t bytes(size_t N) + { + const size_t f = N/M, r = N%M; + static_assert(sizeof(CurveNiMB) == 6+37*M+24, "internal data layout issue"); + return f*sizeof(CurveNiMB) + (r!=0)*(6+37*r+24); + } + + public: + + /*! Default constructor. */ + __forceinline CurveNiMB () {} + + /*! 
fill curve from curve list */ + __forceinline LBBox3fa fillMB(const PrimRefMB* prims, size_t& begin, size_t _end, Scene* scene, const BBox1f time_range) + { + size_t end = min(begin+M,_end); + N = (unsigned char)(end-begin); + const unsigned int geomID0 = prims[begin].geomID(); + this->geomID(N) = geomID0; + ty = (unsigned char) scene->get(geomID0)->getType(); + + /* encode all primitives */ + LBBox3fa lbounds = empty; + for (size_t i=0; iget(geomID)->vlinearBounds(primID,time_range)); + } + BBox3fa bounds = lbounds.bounds(); + + /* calculate offset and scale */ + Vec3fa loffset = bounds.lower; + float lscale = reduce_min(256.0f/(bounds.size()*sqrt(3.0f))); + if (bounds.size() == Vec3fa(zero)) lscale = 0.0f; + *this->offset(N) = loffset; + *this->scale(N) = lscale; + this->time_offset(N) = time_range.lower; + this->time_scale(N) = 1.0f/time_range.size(); + + /* encode all primitives */ + for (size_t i=0; iget(geomID)->computeAlignedSpaceMB(primID,time_range); + + const LinearSpace3fa space3(trunc(126.0f*space2.vx),trunc(126.0f*space2.vy),trunc(126.0f*space2.vz)); + const LBBox3fa bounds = scene->get(geomID)->vlinearBounds(loffset,lscale,max(length(space3.vx),length(space3.vy),length(space3.vz)),space3.transposed(),primID,time_range); + + // NOTE: this weird (char) (short) cast works around VS2015 Win32 compiler bug + bounds_vx_x(N)[i] = (char) (short) space3.vx.x; + bounds_vx_y(N)[i] = (char) (short) space3.vx.y; + bounds_vx_z(N)[i] = (char) (short) space3.vx.z; + bounds_vx_lower0(N)[i] = (short) clamp(floor(bounds.bounds0.lower.x),-32767.0f,32767.0f); + bounds_vx_upper0(N)[i] = (short) clamp(ceil (bounds.bounds0.upper.x),-32767.0f,32767.0f); + bounds_vx_lower1(N)[i] = (short) clamp(floor(bounds.bounds1.lower.x),-32767.0f,32767.0f); + bounds_vx_upper1(N)[i] = (short) clamp(ceil (bounds.bounds1.upper.x),-32767.0f,32767.0f); + assert(-32767.0f <= floor(bounds.bounds0.lower.x) && floor(bounds.bounds0.lower.x) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.bounds0.upper.x) && ceil (bounds.bounds0.upper.x) <= 32767.0f); + assert(-32767.0f <= floor(bounds.bounds1.lower.x) && floor(bounds.bounds1.lower.x) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.bounds1.upper.x) && ceil (bounds.bounds1.upper.x) <= 32767.0f); + + bounds_vy_x(N)[i] = (char) (short) space3.vy.x; + bounds_vy_y(N)[i] = (char) (short) space3.vy.y; + bounds_vy_z(N)[i] = (char) (short) space3.vy.z; + bounds_vy_lower0(N)[i] = (short) clamp(floor(bounds.bounds0.lower.y),-32767.0f,32767.0f); + bounds_vy_upper0(N)[i] = (short) clamp(ceil (bounds.bounds0.upper.y),-32767.0f,32767.0f); + bounds_vy_lower1(N)[i] = (short) clamp(floor(bounds.bounds1.lower.y),-32767.0f,32767.0f); + bounds_vy_upper1(N)[i] = (short) clamp(ceil (bounds.bounds1.upper.y),-32767.0f,32767.0f); + assert(-32767.0f <= floor(bounds.bounds0.lower.y) && floor(bounds.bounds0.lower.y) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.bounds0.upper.y) && ceil (bounds.bounds0.upper.y) <= 32767.0f); + assert(-32767.0f <= floor(bounds.bounds1.lower.y) && floor(bounds.bounds1.lower.y) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.bounds1.upper.y) && ceil (bounds.bounds1.upper.y) <= 32767.0f); + + bounds_vz_x(N)[i] = (char) (short) space3.vz.x; + bounds_vz_y(N)[i] = (char) (short) space3.vz.y; + bounds_vz_z(N)[i] = (char) (short) space3.vz.z; + bounds_vz_lower0(N)[i] = (short) clamp(floor(bounds.bounds0.lower.z),-32767.0f,32767.0f); + bounds_vz_upper0(N)[i] = (short) clamp(ceil (bounds.bounds0.upper.z),-32767.0f,32767.0f); + bounds_vz_lower1(N)[i] = (short) 
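Editor's note: CurveNiMB::fillMB above stores two quantized bound sets per primitive (lower0/upper0 at the start of the leaf's time range, lower1/upper1 at its end) together with time_offset and time_scale. At traversal time the ray time is normalized into that range and the slab bounds are linearly interpolated before the usual slab test. A one-dimensional sketch of that interpolation (editor's names, not embree code):

struct MBBounds1 { float lower0, upper0, lower1, upper1; };

inline float lerp1(float a, float b, float t) { return a + t * (b - a); }

// Interpolate the stored keyframe bounds at the ray's time.
inline void bounds_at_time(const MBBounds1& b,
                           float ray_time, float time_offset, float time_scale,
                           float& lower, float& upper)
{
  float ltime = (ray_time - time_offset) * time_scale;  // 0..1 inside the leaf's time range
  lower = lerp1(b.lower0, b.lower1, ltime);
  upper = lerp1(b.upper0, b.upper1, ltime);
}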
clamp(floor(bounds.bounds1.lower.z),-32767.0f,32767.0f); + bounds_vz_upper1(N)[i] = (short) clamp(ceil (bounds.bounds1.upper.z),-32767.0f,32767.0f); + assert(-32767.0f <= floor(bounds.bounds0.lower.z) && floor(bounds.bounds0.lower.z) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.bounds0.upper.z) && ceil (bounds.bounds0.upper.z) <= 32767.0f); + assert(-32767.0f <= floor(bounds.bounds1.lower.z) && floor(bounds.bounds1.lower.z) <= 32767.0f); + assert(-32767.0f <= ceil (bounds.bounds1.upper.z) && ceil (bounds.bounds1.upper.z) <= 32767.0f); + + this->primID(N)[i] = primID; + } + + return lbounds; + } + + template + __forceinline static typename BVH::NodeRecordMB4D createLeafMB(BVH* bvh, const SetMB& prims, const Allocator& alloc) + { + size_t start = prims.begin(); + size_t end = prims.end(); + size_t items = CurveNiMB::blocks(prims.size()); + size_t numbytes = CurveNiMB::bytes(prims.size()); + CurveNiMB* accel = (CurveNiMB*) alloc.malloc1(numbytes,BVH::byteAlignment); + const typename BVH::NodeRef node = bvh->encodeLeaf((char*)accel,items); + + LBBox3fa bounds = empty; + for (size_t i=0; idata(),start,end,bvh->scene,prims.time_range)); + + return typename BVH::NodeRecordMB4D(node,bounds,prims.time_range); + }; + + + public: + + // 27.6 - 46 bytes per primitive + unsigned char ty; + unsigned char N; + unsigned char data[4+37*M+24]; + + /* + struct Layout + { + unsigned int geomID; + unsigned int primID[N]; + + char bounds_vx_x[N]; + char bounds_vx_y[N]; + char bounds_vx_z[N]; + short bounds_vx_lower0[N]; + short bounds_vx_upper0[N]; + short bounds_vx_lower1[N]; + short bounds_vx_upper1[N]; + + char bounds_vy_x[N]; + char bounds_vy_y[N]; + char bounds_vy_z[N]; + short bounds_vy_lower0[N]; + short bounds_vy_upper0[N]; + short bounds_vy_lower1[N]; + short bounds_vy_upper1[N]; + + char bounds_vz_x[N]; + char bounds_vz_y[N]; + char bounds_vz_z[N]; + short bounds_vz_lower0[N]; + short bounds_vz_upper0[N]; + short bounds_vz_lower1[N]; + short bounds_vz_upper1[N]; + + Vec3f offset; + float scale; + + float time_offset; + float time_scale; + }; + */ + + __forceinline unsigned int& geomID(size_t N) { return *(unsigned int*)((char*)this+2); } + __forceinline const unsigned int& geomID(size_t N) const { return *(unsigned int*)((char*)this+2); } + + __forceinline unsigned int* primID(size_t N) { return (unsigned int*)((char*)this+6); } + __forceinline const unsigned int* primID(size_t N) const { return (unsigned int*)((char*)this+6); } + + __forceinline char* bounds_vx_x(size_t N) { return (char*)((char*)this+6+4*N); } + __forceinline const char* bounds_vx_x(size_t N) const { return (char*)((char*)this+6+4*N); } + + __forceinline char* bounds_vx_y(size_t N) { return (char*)((char*)this+6+5*N); } + __forceinline const char* bounds_vx_y(size_t N) const { return (char*)((char*)this+6+5*N); } + + __forceinline char* bounds_vx_z(size_t N) { return (char*)((char*)this+6+6*N); } + __forceinline const char* bounds_vx_z(size_t N) const { return (char*)((char*)this+6+6*N); } + + __forceinline short* bounds_vx_lower0(size_t N) { return (short*)((char*)this+6+7*N); } + __forceinline const short* bounds_vx_lower0(size_t N) const { return (short*)((char*)this+6+7*N); } + + __forceinline short* bounds_vx_upper0(size_t N) { return (short*)((char*)this+6+9*N); } + __forceinline const short* bounds_vx_upper0(size_t N) const { return (short*)((char*)this+6+9*N); } + + __forceinline short* bounds_vx_lower1(size_t N) { return (short*)((char*)this+6+11*N); } + __forceinline const short* bounds_vx_lower1(size_t N) const { 
return (short*)((char*)this+6+11*N); } + + __forceinline short* bounds_vx_upper1(size_t N) { return (short*)((char*)this+6+13*N); } + __forceinline const short* bounds_vx_upper1(size_t N) const { return (short*)((char*)this+6+13*N); } + + __forceinline char* bounds_vy_x(size_t N) { return (char*)((char*)this+6+15*N); } + __forceinline const char* bounds_vy_x(size_t N) const { return (char*)((char*)this+6+15*N); } + + __forceinline char* bounds_vy_y(size_t N) { return (char*)((char*)this+6+16*N); } + __forceinline const char* bounds_vy_y(size_t N) const { return (char*)((char*)this+6+16*N); } + + __forceinline char* bounds_vy_z(size_t N) { return (char*)((char*)this+6+17*N); } + __forceinline const char* bounds_vy_z(size_t N) const { return (char*)((char*)this+6+17*N); } + + __forceinline short* bounds_vy_lower0(size_t N) { return (short*)((char*)this+6+18*N); } + __forceinline const short* bounds_vy_lower0(size_t N) const { return (short*)((char*)this+6+18*N); } + + __forceinline short* bounds_vy_upper0(size_t N) { return (short*)((char*)this+6+20*N); } + __forceinline const short* bounds_vy_upper0(size_t N) const { return (short*)((char*)this+6+20*N); } + + __forceinline short* bounds_vy_lower1(size_t N) { return (short*)((char*)this+6+22*N); } + __forceinline const short* bounds_vy_lower1(size_t N) const { return (short*)((char*)this+6+22*N); } + + __forceinline short* bounds_vy_upper1(size_t N) { return (short*)((char*)this+6+24*N); } + __forceinline const short* bounds_vy_upper1(size_t N) const { return (short*)((char*)this+6+24*N); } + + __forceinline char* bounds_vz_x(size_t N) { return (char*)((char*)this+6+26*N); } + __forceinline const char* bounds_vz_x(size_t N) const { return (char*)((char*)this+6+26*N); } + + __forceinline char* bounds_vz_y(size_t N) { return (char*)((char*)this+6+27*N); } + __forceinline const char* bounds_vz_y(size_t N) const { return (char*)((char*)this+6+27*N); } + + __forceinline char* bounds_vz_z(size_t N) { return (char*)((char*)this+6+28*N); } + __forceinline const char* bounds_vz_z(size_t N) const { return (char*)((char*)this+6+28*N); } + + __forceinline short* bounds_vz_lower0(size_t N) { return (short*)((char*)this+6+29*N); } + __forceinline const short* bounds_vz_lower0(size_t N) const { return (short*)((char*)this+6+29*N); } + + __forceinline short* bounds_vz_upper0(size_t N) { return (short*)((char*)this+6+31*N); } + __forceinline const short* bounds_vz_upper0(size_t N) const { return (short*)((char*)this+6+31*N); } + + __forceinline short* bounds_vz_lower1(size_t N) { return (short*)((char*)this+6+33*N); } + __forceinline const short* bounds_vz_lower1(size_t N) const { return (short*)((char*)this+6+33*N); } + + __forceinline short* bounds_vz_upper1(size_t N) { return (short*)((char*)this+6+35*N); } + __forceinline const short* bounds_vz_upper1(size_t N) const { return (short*)((char*)this+6+35*N); } + + __forceinline Vec3f* offset(size_t N) { return (Vec3f*)((char*)this+6+37*N); } + __forceinline const Vec3f* offset(size_t N) const { return (Vec3f*)((char*)this+6+37*N); } + + __forceinline float* scale(size_t N) { return (float*)((char*)this+6+37*N+12); } + __forceinline const float* scale(size_t N) const { return (float*)((char*)this+6+37*N+12); } + + __forceinline float& time_offset(size_t N) { return *(float*)((char*)this+6+37*N+16); } + __forceinline const float& time_offset(size_t N) const { return *(float*)((char*)this+6+37*N+16); } + + __forceinline float& time_scale(size_t N) { return *(float*)((char*)this+6+37*N+20); } + __forceinline 
const float& time_scale(size_t N) const { return *(float*)((char*)this+6+37*N+20); } + + __forceinline char* end(size_t N) { return (char*)this+6+37*N+24; } + __forceinline const char* end(size_t N) const { return (char*)this+6+37*N+24; } + }; + + template + typename CurveNiMB::Type CurveNiMB::type; + + typedef CurveNiMB<4> Curve4iMB; + typedef CurveNiMB<8> Curve8iMB; +} diff --git a/thirdparty/embree/kernels/geometry/curveNi_mb_intersector.h b/thirdparty/embree/kernels/geometry/curveNi_mb_intersector.h new file mode 100644 index 000000000000..0cbc76466834 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curveNi_mb_intersector.h @@ -0,0 +1,516 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "curveNi_mb.h" +#include "../subdiv/linear_bezier_patch.h" + +namespace embree +{ + namespace isa + { + template + struct CurveNiMBIntersector1 + { + typedef CurveNiMB Primitive; + typedef Vec3vf Vec3vfM; + typedef LinearSpace3LinearSpace3vfM; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline vbool intersect(Ray& ray, const Primitive& prim, vfloat& tNear_o) + { + const size_t N = prim.N; + const vfloat4 offset_scale = vfloat4::loadu(prim.offset(N)); + const Vec3fa offset = Vec3fa(offset_scale); + const Vec3fa scale = Vec3fa(shuffle<3,3,3,3>(offset_scale)); + const Vec3fa org1 = (ray.org-offset)*scale; + const Vec3fa dir1 = ray.dir*scale; + + const LinearSpace3vfM space(vfloat::load(prim.bounds_vx_x(N)), vfloat::load(prim.bounds_vx_y(N)), vfloat::load(prim.bounds_vx_z(N)), + vfloat::load(prim.bounds_vy_x(N)), vfloat::load(prim.bounds_vy_y(N)), vfloat::load(prim.bounds_vy_z(N)), + vfloat::load(prim.bounds_vz_x(N)), vfloat::load(prim.bounds_vz_y(N)), vfloat::load(prim.bounds_vz_z(N))); + + const Vec3vfM dir2 = xfmVector(space,Vec3vfM(dir1)); + const Vec3vfM org2 = xfmPoint (space,Vec3vfM(org1)); + const Vec3vfM rcp_dir2 = rcp_safe(dir2); + + const vfloat ltime = (ray.time()-prim.time_offset(N))*prim.time_scale(N); + const vfloat vx_lower0 = vfloat::load(prim.bounds_vx_lower0(N)); + const vfloat vx_lower1 = vfloat::load(prim.bounds_vx_lower1(N)); + const vfloat vx_lower = madd(ltime,vx_lower1-vx_lower0,vx_lower0); + const vfloat vx_upper0 = vfloat::load(prim.bounds_vx_upper0(N)); + const vfloat vx_upper1 = vfloat::load(prim.bounds_vx_upper1(N)); + const vfloat vx_upper = madd(ltime,vx_upper1-vx_upper0,vx_upper0); + + const vfloat vy_lower0 = vfloat::load(prim.bounds_vy_lower0(N)); + const vfloat vy_lower1 = vfloat::load(prim.bounds_vy_lower1(N)); + const vfloat vy_lower = madd(ltime,vy_lower1-vy_lower0,vy_lower0); + const vfloat vy_upper0 = vfloat::load(prim.bounds_vy_upper0(N)); + const vfloat vy_upper1 = vfloat::load(prim.bounds_vy_upper1(N)); + const vfloat vy_upper = madd(ltime,vy_upper1-vy_upper0,vy_upper0); + + const vfloat vz_lower0 = vfloat::load(prim.bounds_vz_lower0(N)); + const vfloat vz_lower1 = vfloat::load(prim.bounds_vz_lower1(N)); + const vfloat vz_lower = madd(ltime,vz_lower1-vz_lower0,vz_lower0); + const vfloat vz_upper0 = vfloat::load(prim.bounds_vz_upper0(N)); + const vfloat vz_upper1 = vfloat::load(prim.bounds_vz_upper1(N)); + const vfloat vz_upper = madd(ltime,vz_upper1-vz_upper0,vz_upper0); + + const vfloat t_lower_x = (vx_lower-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_upper_x = (vx_upper-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_lower_y = (vy_lower-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_upper_y = (vy_upper-vfloat(org2.y))*vfloat(rcp_dir2.y); + const 
vfloat t_lower_z = (vz_lower-vfloat(org2.z))*vfloat(rcp_dir2.z); + const vfloat t_upper_z = (vz_upper-vfloat(org2.z))*vfloat(rcp_dir2.z); + + const vfloat round_up (1.0f+3.0f*float(ulp)); + const vfloat round_down(1.0f-3.0f*float(ulp)); + const vfloat tNear = round_down*max(mini(t_lower_x,t_upper_x),mini(t_lower_y,t_upper_y),mini(t_lower_z,t_upper_z),vfloat(ray.tnear())); + const vfloat tFar = round_up *min(maxi(t_lower_x,t_upper_x),maxi(t_lower_y,t_upper_y),maxi(t_lower_z,t_upper_z),vfloat(ray.tfar)); + tNear_o = tNear; + return (vint(step) < vint(prim.N)) & (tNear <= tFar); + } + + template + static __forceinline void intersect_t(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID),ray.time()); + + Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_t(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID),ray.time()); + + if (Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + + template + static __forceinline void intersect_n(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedCurve(context, ray.org, primID,ray.time()); + Intersector().intersect(pre,ray,context,geom,primID,curve,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_n(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedCurve(context, ray.org, primID,ray.time()); + + if 
(Intersector().intersect(pre,ray,context,geom,primID,curve,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + + template + static __forceinline void intersect_h(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID),ray.time()); + Intersector().intersect(pre,ray,context,geom,primID,p0,t0,p1,t1,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_h(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID),ray.time()); + if (Intersector().intersect(pre,ray,context,geom,primID,p0,t0,p1,t1,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + + template + static __forceinline void intersect_hn(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedHermiteCurve(context, ray.org, primID,ray.time()); + Intersector().intersect(pre,ray,context,geom,primID,curve,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_hn(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedHermiteCurve(context, ray.org, primID,ray.time()); + if (Intersector().intersect(pre,ray,context,geom,primID,curve,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + }; + + template + struct CurveNiMBIntersectorK + { + typedef CurveNiMB Primitive; + typedef Vec3vf Vec3vfM; + typedef LinearSpace3LinearSpace3vfM; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline vbool 
intersect(RayK& ray, const size_t k, const Primitive& prim, vfloat& tNear_o) + { + const size_t N = prim.N; + const vfloat4 offset_scale = vfloat4::loadu(prim.offset(N)); + const Vec3fa offset = Vec3fa(offset_scale); + const Vec3fa scale = Vec3fa(shuffle<3,3,3,3>(offset_scale)); + + const Vec3fa ray_org(ray.org.x[k],ray.org.y[k],ray.org.z[k]); + const Vec3fa ray_dir(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]); + const Vec3fa org1 = (ray_org-offset)*scale; + const Vec3fa dir1 = ray_dir*scale; + + const LinearSpace3vfM space(vfloat::load(prim.bounds_vx_x(N)), vfloat::load(prim.bounds_vx_y(N)), vfloat::load(prim.bounds_vx_z(N)), + vfloat::load(prim.bounds_vy_x(N)), vfloat::load(prim.bounds_vy_y(N)), vfloat::load(prim.bounds_vy_z(N)), + vfloat::load(prim.bounds_vz_x(N)), vfloat::load(prim.bounds_vz_y(N)), vfloat::load(prim.bounds_vz_z(N))); + + const Vec3vfM dir2 = xfmVector(space,Vec3vfM(dir1)); + const Vec3vfM org2 = xfmPoint (space,Vec3vfM(org1)); + const Vec3vfM rcp_dir2 = rcp_safe(dir2); + + const vfloat ltime = (ray.time()[k]-prim.time_offset(N))*prim.time_scale(N); + const vfloat vx_lower0 = vfloat::load(prim.bounds_vx_lower0(N)); + const vfloat vx_lower1 = vfloat::load(prim.bounds_vx_lower1(N)); + const vfloat vx_lower = madd(ltime,vx_lower1-vx_lower0,vx_lower0); + const vfloat vx_upper0 = vfloat::load(prim.bounds_vx_upper0(N)); + const vfloat vx_upper1 = vfloat::load(prim.bounds_vx_upper1(N)); + const vfloat vx_upper = madd(ltime,vx_upper1-vx_upper0,vx_upper0); + + const vfloat vy_lower0 = vfloat::load(prim.bounds_vy_lower0(N)); + const vfloat vy_lower1 = vfloat::load(prim.bounds_vy_lower1(N)); + const vfloat vy_lower = madd(ltime,vy_lower1-vy_lower0,vy_lower0); + const vfloat vy_upper0 = vfloat::load(prim.bounds_vy_upper0(N)); + const vfloat vy_upper1 = vfloat::load(prim.bounds_vy_upper1(N)); + const vfloat vy_upper = madd(ltime,vy_upper1-vy_upper0,vy_upper0); + + const vfloat vz_lower0 = vfloat::load(prim.bounds_vz_lower0(N)); + const vfloat vz_lower1 = vfloat::load(prim.bounds_vz_lower1(N)); + const vfloat vz_lower = madd(ltime,vz_lower1-vz_lower0,vz_lower0); + const vfloat vz_upper0 = vfloat::load(prim.bounds_vz_upper0(N)); + const vfloat vz_upper1 = vfloat::load(prim.bounds_vz_upper1(N)); + const vfloat vz_upper = madd(ltime,vz_upper1-vz_upper0,vz_upper0); + + const vfloat t_lower_x = (vx_lower-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_upper_x = (vx_upper-vfloat(org2.x))*vfloat(rcp_dir2.x); + const vfloat t_lower_y = (vy_lower-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_upper_y = (vy_upper-vfloat(org2.y))*vfloat(rcp_dir2.y); + const vfloat t_lower_z = (vz_lower-vfloat(org2.z))*vfloat(rcp_dir2.z); + const vfloat t_upper_z = (vz_upper-vfloat(org2.z))*vfloat(rcp_dir2.z); + + const vfloat round_up (1.0f+3.0f*float(ulp)); + const vfloat round_down(1.0f-3.0f*float(ulp)); + const vfloat tNear = round_down*max(mini(t_lower_x,t_upper_x),mini(t_lower_y,t_upper_y),mini(t_lower_z,t_upper_z),vfloat(ray.tnear()[k])); + const vfloat tFar = round_up *min(maxi(t_lower_x,t_upper_x),maxi(t_lower_y,t_upper_y),maxi(t_lower_z,t_upper_z),vfloat(ray.tfar[k])); + tNear_o = tNear; + return (vint(step) < vint(prim.N)) & (tNear <= tFar); + } + + template + static __forceinline void intersect_t(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + 
STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID),ray.time()[k]); + + Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_t(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff a0,a1,a2,a3; geom->gather(a0,a1,a2,a3,geom->curve(primID),ray.time()[k]); + + if (Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + + template + static __forceinline void intersect_n(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const Vec3fa ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedCurve(context, ray_org, primID,ray.time()[k]); + Intersector().intersect(pre,ray,k,context,geom,primID,curve,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_n(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const Vec3fa ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedCurve(context, ray_org, primID,ray.time()[k]); + + if (Intersector().intersect(pre,ray,k,context,geom,primID,curve,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + + template + static __forceinline void intersect_h(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; 
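+          // Motion-blur Hermite path: fetch the segment's end positions (p0,p1) and tangents (t0,t1),
+          // interpolated at this ray's time value, before handing them to the templated curve intersector.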
geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID),ray.time()[k]); + Intersector().intersect(pre,ray,k,context,geom,primID,p0,t0,p1,t1,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_h(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + Vec3ff p0,t0,p1,t1; geom->gather_hermite(p0,t0,p1,t1,geom->curve(primID),ray.time()[k]); + if (Intersector().intersect(pre,ray,k,context,geom,primID,p0,t0,p1,t1,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + + template + static __forceinline void intersect_hn(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const Vec3fa ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedHermiteCurve(context, ray_org, primID,ray.time()[k]); + Intersector().intersect(pre,ray,k,context,geom,primID,curve,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_hn(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = context->scene->get(geomID); + const Vec3fa ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const TensorLinearCubicBezierSurface3fa curve = geom->getNormalOrientedHermiteCurve(context, ray_org, primID,ray.time()[k]); + if (Intersector().intersect(pre,ray,k,context,geom,primID,curve,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curveNv.h b/thirdparty/embree/kernels/geometry/curveNv.h new file mode 100644 index 000000000000..6eb5e30b39aa --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curveNv.h @@ -0,0 +1,101 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "curveNi.h" + +namespace embree +{ + template + struct CurveNv : public CurveNi + { + using CurveNi::N; + + struct Type : public PrimitiveType { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* Returns maximum 
number of stored primitives */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+M-1)/M; } + + static __forceinline size_t bytes(size_t N) + { + const size_t f = N/M, r = N%M; + static_assert(sizeof(CurveNv) == 22+25*M+4*16*M, "internal data layout issue"); + return f*sizeof(CurveNv) + (r!=0)*(22 + 25*r + 4*16*r); + } + + public: + + /*! Default constructor. */ + __forceinline CurveNv () {} + + /*! fill curve from curve list */ + __forceinline void fill(const PrimRef* prims, size_t& begin, size_t _end, Scene* scene) + { + size_t end = min(begin+M,_end); + size_t N = end-begin; + + /* encode all primitives */ + for (size_t i=0; iget(geomID); + const unsigned vtxID = mesh->curve(primID); + Vec3fa::storeu(&this->vertices(i,N)[0],mesh->vertex(vtxID+0)); + Vec3fa::storeu(&this->vertices(i,N)[1],mesh->vertex(vtxID+1)); + Vec3fa::storeu(&this->vertices(i,N)[2],mesh->vertex(vtxID+2)); + Vec3fa::storeu(&this->vertices(i,N)[3],mesh->vertex(vtxID+3)); + } + } + + template + __forceinline static typename BVH::NodeRef createLeaf (BVH* bvh, const PrimRef* prims, const range& set, const Allocator& alloc) + { + if (set.size() == 0) + return BVH::emptyNode; + + /* fall back to CurveNi for oriented curves */ + unsigned int geomID = prims[set.begin()].geomID(); + if (bvh->scene->get(geomID)->getCurveType() == Geometry::GTY_SUBTYPE_ORIENTED_CURVE) { + return CurveNi::createLeaf(bvh,prims,set,alloc); + } + if (bvh->scene->get(geomID)->getCurveBasis() == Geometry::GTY_BASIS_HERMITE) { + return CurveNi::createLeaf(bvh,prims,set,alloc); + } + + size_t start = set.begin(); + size_t items = CurveNv::blocks(set.size()); + size_t numbytes = CurveNv::bytes(set.size()); + CurveNv* accel = (CurveNv*) alloc.malloc1(numbytes,BVH::byteAlignment); + for (size_t i=0; i::fill(prims,start,set.end(),bvh->scene); + accel[i].CurveNi::fill(prims,start,set.end(),bvh->scene); + } + return bvh->encodeLeaf((char*)accel,items); + }; + + public: + unsigned char data[4*16*M]; + __forceinline Vec3fa* vertices(size_t i, size_t N) { return (Vec3fa*)CurveNi::end(N)+4*i; } + __forceinline const Vec3fa* vertices(size_t i, size_t N) const { return (Vec3fa*)CurveNi::end(N)+4*i; } + }; + + template + typename CurveNv::Type CurveNv::type; + + typedef CurveNv<4> Curve4v; + typedef CurveNv<8> Curve8v; +} diff --git a/thirdparty/embree/kernels/geometry/curveNv_intersector.h b/thirdparty/embree/kernels/geometry/curveNv_intersector.h new file mode 100644 index 000000000000..e20da2882efd --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curveNv_intersector.h @@ -0,0 +1,181 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "curveNv.h" +#include "curveNi_intersector.h" + +namespace embree +{ + namespace isa + { + template + struct CurveNvIntersector1 : public CurveNiIntersector1 + { + typedef CurveNv Primitive; + typedef CurvePrecalculations1 Precalculations; + + template + static __forceinline void intersect_t(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = CurveNiIntersector1::intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = 
(CurveGeometry*) context->scene->get(geomID); + const Vec3ff a0 = Vec3ff::loadu(&prim.vertices(i,N)[0]); + const Vec3ff a1 = Vec3ff::loadu(&prim.vertices(i,N)[1]); + const Vec3ff a2 = Vec3ff::loadu(&prim.vertices(i,N)[2]); + const Vec3ff a3 = Vec3ff::loadu(&prim.vertices(i,N)[3]); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + prefetchL1(&prim.vertices(i1,N)[0]); + prefetchL1(&prim.vertices(i1,N)[4]); + if (mask1) { + const size_t i2 = bsf(mask1); + prefetchL2(&prim.vertices(i2,N)[0]); + prefetchL2(&prim.vertices(i2,N)[4]); + } + } + + Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,Epilog(ray,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + } + + template + static __forceinline bool occluded_t(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = CurveNiIntersector1::intersect(ray,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = (CurveGeometry*) context->scene->get(geomID); + const Vec3ff a0 = Vec3ff::loadu(&prim.vertices(i,N)[0]); + const Vec3ff a1 = Vec3ff::loadu(&prim.vertices(i,N)[1]); + const Vec3ff a2 = Vec3ff::loadu(&prim.vertices(i,N)[2]); + const Vec3ff a3 = Vec3ff::loadu(&prim.vertices(i,N)[3]); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + prefetchL1(&prim.vertices(i1,N)[0]); + prefetchL1(&prim.vertices(i1,N)[4]); + if (mask1) { + const size_t i2 = bsf(mask1); + prefetchL2(&prim.vertices(i2,N)[0]); + prefetchL2(&prim.vertices(i2,N)[4]); + } + } + + if (Intersector().intersect(pre,ray,context,geom,primID,a0,a1,a2,a3,Epilog(ray,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar)); + } + return false; + } + }; + + template + struct CurveNvIntersectorK : public CurveNiIntersectorK + { + typedef CurveNv Primitive; + typedef CurvePrecalculationsK Precalculations; + + template + static __forceinline void intersect_t(Precalculations& pre, RayHitK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + vbool valid = CurveNiIntersectorK::intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(normal.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = (CurveGeometry*) context->scene->get(geomID); + const Vec3ff a0 = Vec3ff::loadu(&prim.vertices(i,N)[0]); + const Vec3ff a1 = Vec3ff::loadu(&prim.vertices(i,N)[1]); + const Vec3ff a2 = Vec3ff::loadu(&prim.vertices(i,N)[2]); + const Vec3ff a3 = Vec3ff::loadu(&prim.vertices(i,N)[3]); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + prefetchL1(&prim.vertices(i1,N)[0]); + prefetchL1(&prim.vertices(i1,N)[4]); + if (mask1) { + const size_t i2 = bsf(mask1); + prefetchL2(&prim.vertices(i2,N)[0]); + prefetchL2(&prim.vertices(i2,N)[4]); + } + } + + Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,Epilog(ray,k,context,geomID,primID)); + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + } + + template + static __forceinline bool occluded_t(Precalculations& pre, RayK& ray, const size_t k, IntersectContext* context, const Primitive& prim) + { + vfloat tNear; + 
vbool valid = CurveNiIntersectorK::intersect(ray,k,prim,tNear); + + const size_t N = prim.N; + size_t mask = movemask(valid); + while (mask) + { + const size_t i = bscf(mask); + STAT3(shadow.trav_prims,1,1,1); + const unsigned int geomID = prim.geomID(N); + const unsigned int primID = prim.primID(N)[i]; + const CurveGeometry* geom = (CurveGeometry*) context->scene->get(geomID); + const Vec3ff a0 = Vec3ff::loadu(&prim.vertices(i,N)[0]); + const Vec3ff a1 = Vec3ff::loadu(&prim.vertices(i,N)[1]); + const Vec3ff a2 = Vec3ff::loadu(&prim.vertices(i,N)[2]); + const Vec3ff a3 = Vec3ff::loadu(&prim.vertices(i,N)[3]); + + size_t mask1 = mask; + const size_t i1 = bscf(mask1); + if (mask) { + prefetchL1(&prim.vertices(i1,N)[0]); + prefetchL1(&prim.vertices(i1,N)[4]); + if (mask1) { + const size_t i2 = bsf(mask1); + prefetchL2(&prim.vertices(i2,N)[0]); + prefetchL2(&prim.vertices(i2,N)[4]); + } + } + + if (Intersector().intersect(pre,ray,k,context,geom,primID,a0,a1,a2,a3,Epilog(ray,k,context,geomID,primID))) + return true; + + mask &= movemask(tNear <= vfloat(ray.tfar[k])); + } + return false; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector.h b/thirdparty/embree/kernels/geometry/curve_intersector.h new file mode 100644 index 000000000000..204958f7cc36 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector.h @@ -0,0 +1,98 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "../subdiv/bezier_curve.h" +#include "../common/primref.h" +#include "bezier_hair_intersector.h" +#include "bezier_ribbon_intersector.h" +#include "bezier_curve_intersector.h" +#include "oriented_curve_intersector.h" +#include "../bvh/node_intersector1.h" + +// FIXME: this file seems replicate of curve_intersector_virtual.h + +namespace embree +{ + namespace isa + { + struct VirtualCurveIntersector1 + { + typedef unsigned char Primitive; + typedef CurvePrecalculations1 Precalculations; + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurvePrimitive::Intersectors& leafIntersector = ((VirtualCurvePrimitive*) This->leafIntersector)->vtbl[ty]; + leafIntersector.intersect<1>(&pre,&ray,context,prim); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurvePrimitive::Intersectors& leafIntersector = ((VirtualCurvePrimitive*) This->leafIntersector)->vtbl[ty]; + return leafIntersector.occluded<1>(&pre,&ray,context,prim); + } + }; + + template + struct VirtualCurveIntersectorK + { + typedef unsigned char Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const vbool& valid_i, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurvePrimitive::Intersectors& leafIntersector = 
((VirtualCurvePrimitive*) This->leafIntersector)->vtbl[ty]; + size_t mask = movemask(valid_i); + while (mask) leafIntersector.intersect(&pre,&ray,bscf(mask),context,prim); + } + + static __forceinline vbool occluded(const vbool& valid_i, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurvePrimitive::Intersectors& leafIntersector = ((VirtualCurvePrimitive*) This->leafIntersector)->vtbl[ty]; + vbool valid_o = false; + size_t mask = movemask(valid_i); + while (mask) { + size_t k = bscf(mask); + if (leafIntersector.occluded(&pre,&ray,k,context,prim)) + set(valid_o, k); + } + return valid_o; + } + + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurvePrimitive::Intersectors& leafIntersector = ((VirtualCurvePrimitive*) This->leafIntersector)->vtbl[ty]; + leafIntersector.intersect(&pre,&ray,k,context,prim); + } + + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurvePrimitive::Intersectors& leafIntersector = ((VirtualCurvePrimitive*) This->leafIntersector)->vtbl[ty]; + return leafIntersector.occluded(&pre,&ray,k,context,prim); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector_distance.h b/thirdparty/embree/kernels/geometry/curve_intersector_distance.h new file mode 100644 index 000000000000..343cc8ff28f0 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector_distance.h @@ -0,0 +1,129 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + namespace isa + { + template + struct DistanceCurveHit + { + __forceinline DistanceCurveHit() {} + + __forceinline DistanceCurveHit(const vbool& valid, const vfloat& U, const vfloat& V, const vfloat& T, const int i, const int N, + const NativeCurve3fa& curve3D) + : U(U), V(V), T(T), i(i), N(N), curve3D(curve3D), valid(valid) {} + + __forceinline void finalize() + { + vu = (vfloat(step)+U+vfloat(float(i)))*(1.0f/float(N)); + vv = V; + vt = T; + } + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { + return curve3D.eval_du(vu[i]); + } + + public: + vfloat U; + vfloat V; + vfloat T; + int i, N; + NativeCurve3fa curve3D; + + public: + vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + }; + + template + struct DistanceCurve1Intersector1 + { + template + __forceinline bool intersect(const CurvePrecalculations1& pre,Ray& ray, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3fa& v0, const Vec3fa& v1, const Vec3fa& v2, const Vec3fa& v3, + const Epilog& epilog) + { + const int N = geom->tessellationRate; + + /* transform control points into ray space */ + const 
NativeCurve3fa curve3Di(v0,v1,v2,v3); + const NativeCurve3fa curve3D = enlargeRadiusToMinWidth(context,geom,ray.org,curve3Di); + const NativeCurve3fa curve2D = curve3D.xfm_pr(pre.ray_space,ray.org); + + /* evaluate the bezier curve */ + vboolx valid = vfloatx(step) < vfloatx(float(N)); + const Vec4vfx p0 = curve2D.template eval0(0,N); + const Vec4vfx p1 = curve2D.template eval1(0,N); + + /* approximative intersection with cone */ + const Vec4vfx v = p1-p0; + const Vec4vfx w = -p0; + const vfloatx d0 = madd(w.x,v.x,w.y*v.y); + const vfloatx d1 = madd(v.x,v.x,v.y*v.y); + const vfloatx u = clamp(d0*rcp(d1),vfloatx(zero),vfloatx(one)); + const Vec4vfx p = madd(u,v,p0); + const vfloatx t = p.z*pre.depth_scale; + const vfloatx d2 = madd(p.x,p.x,p.y*p.y); + const vfloatx r = p.w; + const vfloatx r2 = r*r; + valid &= (d2 <= r2) & (vfloatx(ray.tnear()) <= t) & (t <= vfloatx(ray.tfar)); + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) + valid &= t > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR)*r*pre.depth_scale; // ignore self intersections + + /* update hit information */ + bool ishit = false; + if (unlikely(any(valid))) { + DistanceCurveHit hit(valid,u,0.0f,t,0,N,curve3D); + ishit = ishit | epilog(valid,hit); + } + + if (unlikely(VSIZEX < N)) + { + /* process SIMD-size many segments per iteration */ + for (int i=VSIZEX; i(i,N); + const Vec4vfx p1 = curve2D.template eval1(i,N); + + /* approximative intersection with cone */ + const Vec4vfx v = p1-p0; + const Vec4vfx w = -p0; + const vfloatx d0 = madd(w.x,v.x,w.y*v.y); + const vfloatx d1 = madd(v.x,v.x,v.y*v.y); + const vfloatx u = clamp(d0*rcp(d1),vfloatx(zero),vfloatx(one)); + const Vec4vfx p = madd(u,v,p0); + const vfloatx t = p.z*pre.depth_scale; + const vfloatx d2 = madd(p.x,p.x,p.y*p.y); + const vfloatx r = p.w; + const vfloatx r2 = r*r; + valid &= (d2 <= r2) & (vfloatx(ray.tnear()) <= t) & (t <= vfloatx(ray.tfar)); + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) + valid &= t > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR)*r*pre.depth_scale; // ignore self intersections + + /* update hit information */ + if (unlikely(any(valid))) { + DistanceCurveHit hit(valid,u,0.0f,t,i,N,curve3D); + ishit = ishit | epilog(valid,hit); + } + } + } + return ishit; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector_oriented.h b/thirdparty/embree/kernels/geometry/curve_intersector_oriented.h new file mode 100644 index 000000000000..47531027fc3b --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector_oriented.h @@ -0,0 +1,417 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "curve_intersector_precalculations.h" +#include "curve_intersector_sweep.h" +#include "../subdiv/linear_bezier_patch.h" + +#define DBG(x) + +namespace embree +{ + namespace isa + { + template + struct TensorLinearCubicBezierSurfaceIntersector + { + const LinearSpace3fa& ray_space; + Ray& ray; + TensorLinearCubicBezierSurface3fa curve3d; + TensorLinearCubicBezierSurface2fa curve2d; + float eps; + const Epilog& epilog; + bool isHit; + + __forceinline TensorLinearCubicBezierSurfaceIntersector (const LinearSpace3fa& ray_space, Ray& ray, const TensorLinearCubicBezierSurface3fa& curve3d, const Epilog& epilog) + : ray_space(ray_space), ray(ray), curve3d(curve3d), epilog(epilog), isHit(false) + { + const TensorLinearCubicBezierSurface3fa curve3dray = curve3d.xfm(ray_space,ray.org); + curve2d = 
TensorLinearCubicBezierSurface2fa(CubicBezierCurve2fa(curve3dray.L),CubicBezierCurve2fa(curve3dray.R)); + const BBox2fa b2 = curve2d.bounds(); + eps = 8.0f*float(ulp)*reduce_max(max(abs(b2.lower),abs(b2.upper))); + } + + __forceinline Interval1f solve_linear(const float u0, const float u1, const float& p0, const float& p1) + { + if (p1 == p0) { + if (p0 == 0.0f) return Interval1f(u0,u1); + else return Interval1f(empty); + } + const float t = -p0/(p1-p0); + const float tt = lerp(u0,u1,t); + return Interval1f(tt); + } + + __forceinline void solve_linear(const float u0, const float u1, const Interval1f& p0, const Interval1f& p1, Interval1f& u) + { + if (sign(p0.lower) != sign(p0.upper)) u.extend(u0); + if (sign(p0.lower) != sign(p1.lower)) u.extend(solve_linear(u0,u1,p0.lower,p1.lower)); + if (sign(p0.upper) != sign(p1.upper)) u.extend(solve_linear(u0,u1,p0.upper,p1.upper)); + if (sign(p1.lower) != sign(p1.upper)) u.extend(u1); + } + + __forceinline Interval1f bezier_clipping(const CubicBezierCurve& curve) + { + Interval1f u = empty; + solve_linear(0.0f/3.0f,1.0f/3.0f,curve.v0,curve.v1,u); + solve_linear(0.0f/3.0f,2.0f/3.0f,curve.v0,curve.v2,u); + solve_linear(0.0f/3.0f,3.0f/3.0f,curve.v0,curve.v3,u); + solve_linear(1.0f/3.0f,2.0f/3.0f,curve.v1,curve.v2,u); + solve_linear(1.0f/3.0f,3.0f/3.0f,curve.v1,curve.v3,u); + solve_linear(2.0f/3.0f,3.0f/3.0f,curve.v2,curve.v3,u); + return intersect(u,Interval1f(0.0f,1.0f)); + } + + __forceinline Interval1f bezier_clipping(const LinearBezierCurve& curve) + { + Interval1f v = empty; + solve_linear(0.0f,1.0f,curve.v0,curve.v1,v); + return intersect(v,Interval1f(0.0f,1.0f)); + } + + __forceinline void solve_bezier_clipping(BBox1f cu, BBox1f cv, const TensorLinearCubicBezierSurface2fa& curve2) + { + BBox2fa bounds = curve2.bounds(); + if (bounds.upper.x < 0.0f) return; + if (bounds.upper.y < 0.0f) return; + if (bounds.lower.x > 0.0f) return; + if (bounds.lower.y > 0.0f) return; + + if (max(cu.size(),cv.size()) < 1E-4f) + { + const float u = cu.center(); + const float v = cv.center(); + TensorLinearCubicBezierSurface1f curve_z = curve3d.xfm(ray_space.row2(),ray.org); + const float t = curve_z.eval(u,v); + if (ray.tnear() <= t && t <= ray.tfar) { + const Vec3fa Ng = cross(curve3d.eval_du(u,v),curve3d.eval_dv(u,v)); + BezierCurveHit hit(t,u,v,Ng); + isHit |= epilog(hit); + } + return; + } + + const Vec2fa dv = curve2.axis_v(); + const TensorLinearCubicBezierSurface1f curve1v = curve2.xfm(dv); + LinearBezierCurve curve0v = curve1v.reduce_u(); + if (!curve0v.hasRoot()) return; + + const Interval1f v = bezier_clipping(curve0v); + if (isEmpty(v)) return; + TensorLinearCubicBezierSurface2fa curve2a = curve2.clip_v(v); + cv = BBox1f(lerp(cv.lower,cv.upper,v.lower),lerp(cv.lower,cv.upper,v.upper)); + + const Vec2fa du = curve2.axis_u(); + const TensorLinearCubicBezierSurface1f curve1u = curve2a.xfm(du); + CubicBezierCurve curve0u = curve1u.reduce_v(); + int roots = curve0u.maxRoots(); + if (roots == 0) return; + + if (roots == 1) + { + const Interval1f u = bezier_clipping(curve0u); + if (isEmpty(u)) return; + TensorLinearCubicBezierSurface2fa curve2b = curve2a.clip_u(u); + cu = BBox1f(lerp(cu.lower,cu.upper,u.lower),lerp(cu.lower,cu.upper,u.upper)); + solve_bezier_clipping(cu,cv,curve2b); + return; + } + + TensorLinearCubicBezierSurface2fa curve2l, curve2r; + curve2a.split_u(curve2l,curve2r); + solve_bezier_clipping(BBox1f(cu.lower,cu.center()),cv,curve2l); + solve_bezier_clipping(BBox1f(cu.center(),cu.upper),cv,curve2r); + } + + __forceinline bool 
solve_bezier_clipping() + { + solve_bezier_clipping(BBox1f(0.0f,1.0f),BBox1f(0.0f,1.0f),curve2d); + return isHit; + } + + __forceinline void solve_newton_raphson(BBox1f cu, BBox1f cv) + { + Vec2fa uv(cu.center(),cv.center()); + const Vec2fa dfdu = curve2d.eval_du(uv.x,uv.y); + const Vec2fa dfdv = curve2d.eval_dv(uv.x,uv.y); + const LinearSpace2fa rcp_J = rcp(LinearSpace2fa(dfdu,dfdv)); + solve_newton_raphson_loop(cu,cv,uv,dfdu,dfdv,rcp_J); + } + + __forceinline void solve_newton_raphson_loop(BBox1f cu, BBox1f cv, const Vec2fa& uv_in, const Vec2fa& dfdu, const Vec2fa& dfdv, const LinearSpace2fa& rcp_J) + { + Vec2fa uv = uv_in; + + for (size_t i=0; i<200; i++) + { + const Vec2fa f = curve2d.eval(uv.x,uv.y); + const Vec2fa duv = rcp_J*f; + uv -= duv; + + if (max(abs(f.x),abs(f.y)) < eps) + { + const float u = uv.x; + const float v = uv.y; + if (!(u >= 0.0f && u <= 1.0f)) return; // rejects NaNs + if (!(v >= 0.0f && v <= 1.0f)) return; // rejects NaNs + const TensorLinearCubicBezierSurface1f curve_z = curve3d.xfm(ray_space.row2(),ray.org); + const float t = curve_z.eval(u,v); + if (!(ray.tnear() <= t && t <= ray.tfar)) return; // rejects NaNs + const Vec3fa Ng = cross(curve3d.eval_du(u,v),curve3d.eval_dv(u,v)); + BezierCurveHit hit(t,u,v,Ng); + isHit |= epilog(hit); + return; + } + } + } + + __forceinline bool clip_v(BBox1f& cu, BBox1f& cv) + { + const Vec2fa dv = curve2d.eval_dv(cu.lower,cv.lower); + const TensorLinearCubicBezierSurface1f curve1v = curve2d.xfm(dv).clip(cu,cv); + LinearBezierCurve curve0v = curve1v.reduce_u(); + if (!curve0v.hasRoot()) return false; + Interval1f v = bezier_clipping(curve0v); + if (isEmpty(v)) return false; + v = intersect(v + Interval1f(-0.1f,+0.1f),Interval1f(0.0f,1.0f)); + cv = BBox1f(lerp(cv.lower,cv.upper,v.lower),lerp(cv.lower,cv.upper,v.upper)); + return true; + } + + __forceinline bool solve_krawczyk(bool very_small, BBox1f& cu, BBox1f& cv) + { + /* perform bezier clipping in v-direction to get tight v-bounds */ + TensorLinearCubicBezierSurface2fa curve2 = curve2d.clip(cu,cv); + const Vec2fa dv = curve2.axis_v(); + const TensorLinearCubicBezierSurface1f curve1v = curve2.xfm(dv); + LinearBezierCurve curve0v = curve1v.reduce_u(); + if (unlikely(!curve0v.hasRoot())) return true; + Interval1f v = bezier_clipping(curve0v); + if (unlikely(isEmpty(v))) return true; + v = intersect(v + Interval1f(-0.1f,+0.1f),Interval1f(0.0f,1.0f)); + curve2 = curve2.clip_v(v); + cv = BBox1f(lerp(cv.lower,cv.upper,v.lower),lerp(cv.lower,cv.upper,v.upper)); + + /* perform one newton raphson iteration */ + Vec2fa c(cu.center(),cv.center()); + Vec2fa f,dfdu,dfdv; curve2d.eval(c.x,c.y,f,dfdu,dfdv); + const LinearSpace2fa rcp_J = rcp(LinearSpace2fa(dfdu,dfdv)); + const Vec2fa c1 = c - rcp_J*f; + + /* calculate bounds of derivatives */ + const BBox2fa bounds_du = (1.0f/cu.size())*curve2.derivative_u().bounds(); + const BBox2fa bounds_dv = (1.0f/cv.size())*curve2.derivative_v().bounds(); + + /* calculate krawczyk test */ + LinearSpace2> I(Interval1f(1.0f), Interval1f(0.0f), + Interval1f(0.0f), Interval1f(1.0f)); + + LinearSpace2> G(Interval1f(bounds_du.lower.x,bounds_du.upper.x), Interval1f(bounds_dv.lower.x,bounds_dv.upper.x), + Interval1f(bounds_du.lower.y,bounds_du.upper.y), Interval1f(bounds_dv.lower.y,bounds_dv.upper.y)); + + const LinearSpace2 rcp_J2(rcp_J); + const LinearSpace2> rcp_Ji(rcp_J2); + + const Vec2 x(cu,cv); + const Vec2 K = Vec2(Vec2f(c1)) + (I - rcp_Ji*G)*(x-Vec2(Vec2f(c))); + + /* test if there is no solution */ + const Vec2 KK = intersect(K,x); + if 
(unlikely(isEmpty(KK.x) || isEmpty(KK.y))) return true; + + /* exit if convergence cannot get proven, but terminate if we are very small */ + if (unlikely(!subset(K,x) && !very_small)) return false; + + /* solve using newton raphson iteration of convergence is guarenteed */ + solve_newton_raphson_loop(cu,cv,c1,dfdu,dfdv,rcp_J); + return true; + } + + __forceinline void solve_newton_raphson_no_recursion(BBox1f cu, BBox1f cv) + { + if (!clip_v(cu,cv)) return; + return solve_newton_raphson(cu,cv); + } + + __forceinline void solve_newton_raphson_recursion(BBox1f cu, BBox1f cv) + { + unsigned int sptr = 0; + const unsigned int stack_size = 4; + unsigned int mask_stack[stack_size]; + BBox1f cu_stack[stack_size]; + BBox1f cv_stack[stack_size]; + goto entry; + + /* terminate if stack is empty */ + while (sptr) + { + /* pop from stack */ + { + sptr--; + size_t mask = mask_stack[sptr]; + cu = cu_stack[sptr]; + cv = cv_stack[sptr]; + const size_t i = bscf(mask); + mask_stack[sptr] = mask; + if (mask) sptr++; // there are still items on the stack + + /* process next element recurse into each hit curve segment */ + const float u0 = float(i+0)*(1.0f/(VSIZEX-1)); + const float u1 = float(i+1)*(1.0f/(VSIZEX-1)); + const BBox1f cui(lerp(cu.lower,cu.upper,u0),lerp(cu.lower,cu.upper,u1)); + cu = cui; + } + +#if 0 + solve_newton_raphson_no_recursion(cu,cv); + continue; + +#else + /* we assume convergence for small u ranges and verify using krawczyk */ + if (cu.size() < 1.0f/6.0f) { + const bool very_small = cu.size() < 0.001f || sptr >= stack_size; + if (solve_krawczyk(very_small,cu,cv)) { + continue; + } + } +#endif + + entry: + + /* split the curve into VSIZEX-1 segments in u-direction */ + vboolx valid = true; + TensorLinearCubicBezierSurface subcurves = curve2d.clip_v(cv).vsplit_u(valid,cu); + + /* slabs test in u-direction */ + Vec2vfx ndv = cross(subcurves.axis_v()); + BBox boundsv = subcurves.vxfm(ndv).bounds(); + valid &= boundsv.lower <= eps; + valid &= boundsv.upper >= -eps; + if (none(valid)) continue; + + /* slabs test in v-direction */ + Vec2vfx ndu = cross(subcurves.axis_u()); + BBox boundsu = subcurves.vxfm(ndu).bounds(); + valid &= boundsu.lower <= eps; + valid &= boundsu.upper >= -eps; + if (none(valid)) continue; + + /* push valid segments to stack */ + assert(sptr < stack_size); + mask_stack [sptr] = movemask(valid); + cu_stack [sptr] = cu; + cv_stack [sptr] = cv; + sptr++; + } + } + + __forceinline bool solve_newton_raphson_main() + { + BBox1f vu(0.0f,1.0f); + BBox1f vv(0.0f,1.0f); + solve_newton_raphson_recursion(vu,vv); + return isHit; + } + }; + + + template class SourceCurve> + struct OrientedCurve1Intersector1 + { + //template using Curve = SourceCurve; + typedef SourceCurve SourceCurve3ff; + typedef SourceCurve SourceCurve3fa; + + __forceinline OrientedCurve1Intersector1() {} + + __forceinline OrientedCurve1Intersector1(const Ray& ray, const void* ptr) {} + + template + __noinline bool intersect(const CurvePrecalculations1& pre, Ray& ray, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3ff& v0i, const Vec3ff& v1i, const Vec3ff& v2i, const Vec3ff& v3i, + const Vec3fa& n0i, const Vec3fa& n1i, const Vec3fa& n2i, const Vec3fa& n3i, + const Epilog& epilog) const + { + STAT3(normal.trav_prims,1,1,1); + + SourceCurve3ff ccurve(v0i,v1i,v2i,v3i); + SourceCurve3fa ncurve(n0i,n1i,n2i,n3i); + ccurve = enlargeRadiusToMinWidth(context,geom,ray.org,ccurve); + TensorLinearCubicBezierSurface3fa curve = 
TensorLinearCubicBezierSurface3fa::fromCenterAndNormalCurve(ccurve,ncurve); + //return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space,ray,curve,epilog).solve_bezier_clipping(); + return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space,ray,curve,epilog).solve_newton_raphson_main(); + } + + template + __noinline bool intersect(const CurvePrecalculations1& pre, Ray& ray, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const TensorLinearCubicBezierSurface3fa& curve, const Epilog& epilog) const + { + STAT3(normal.trav_prims,1,1,1); + //return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space,ray,curve,epilog).solve_bezier_clipping(); + return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space,ray,curve,epilog).solve_newton_raphson_main(); + } + }; + + template class SourceCurve, int K> + struct OrientedCurve1IntersectorK + { + //template using Curve = SourceCurve; + typedef SourceCurve SourceCurve3ff; + typedef SourceCurve SourceCurve3fa; + + struct Ray1 + { + __forceinline Ray1(RayK& ray, size_t k) + : org(ray.org.x[k],ray.org.y[k],ray.org.z[k]), dir(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]), _tnear(ray.tnear()[k]), tfar(ray.tfar[k]) {} + + Vec3fa org; + Vec3fa dir; + float _tnear; + float& tfar; + + __forceinline float& tnear() { return _tnear; } + //__forceinline float& tfar() { return _tfar; } + __forceinline const float& tnear() const { return _tnear; } + //__forceinline const float& tfar() const { return _tfar; } + }; + + template + __forceinline bool intersect(const CurvePrecalculationsK& pre, RayK& vray, size_t k, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3ff& v0i, const Vec3ff& v1i, const Vec3ff& v2i, const Vec3ff& v3i, + const Vec3fa& n0i, const Vec3fa& n1i, const Vec3fa& n2i, const Vec3fa& n3i, + const Epilog& epilog) + { + STAT3(normal.trav_prims,1,1,1); + Ray1 ray(vray,k); + SourceCurve3ff ccurve(v0i,v1i,v2i,v3i); + SourceCurve3fa ncurve(n0i,n1i,n2i,n3i); + ccurve = enlargeRadiusToMinWidth(context,geom,ray.org,ccurve); + TensorLinearCubicBezierSurface3fa curve = TensorLinearCubicBezierSurface3fa::fromCenterAndNormalCurve(ccurve,ncurve); + //return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space[k],ray,curve,epilog).solve_bezier_clipping(); + return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space[k],ray,curve,epilog).solve_newton_raphson_main(); + } + + template + __forceinline bool intersect(const CurvePrecalculationsK& pre, RayK& vray, size_t k, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const TensorLinearCubicBezierSurface3fa& curve, + const Epilog& epilog) + { + STAT3(normal.trav_prims,1,1,1); + Ray1 ray(vray,k); + //return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space[k],ray,curve,epilog).solve_bezier_clipping(); + return TensorLinearCubicBezierSurfaceIntersector(pre.ray_space[k],ray,curve,epilog).solve_newton_raphson_main(); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector_precalculations.h b/thirdparty/embree/kernels/geometry/curve_intersector_precalculations.h new file mode 100644 index 000000000000..6e9fc9192575 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector_precalculations.h @@ -0,0 +1,49 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "../common/geometry.h" + +namespace embree +{ + namespace isa + { + struct CurvePrecalculations1 + { + 
float depth_scale; + LinearSpace3fa ray_space; + + __forceinline CurvePrecalculations1() {} + + __forceinline CurvePrecalculations1(const Ray& ray, const void* ptr) + { + depth_scale = rsqrt(dot(ray.dir,ray.dir)); + LinearSpace3fa space = frame(depth_scale*ray.dir); + space.vz *= depth_scale; + ray_space = space.transposed(); + } + }; + + template + struct CurvePrecalculationsK + { + vfloat depth_scale; + LinearSpace3fa ray_space[K]; + + __forceinline CurvePrecalculationsK(const vbool& valid, const RayK& ray) + { + size_t mask = movemask(valid); + depth_scale = rsqrt(dot(ray.dir,ray.dir)); + while (mask) { + size_t k = bscf(mask); + Vec3fa ray_dir_k = Vec3fa(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]); + LinearSpace3fa ray_space_k = frame(depth_scale[k]*ray_dir_k); + ray_space_k.vz *= depth_scale[k]; + ray_space[k] = ray_space_k.transposed(); + } + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector_ribbon.h b/thirdparty/embree/kernels/geometry/curve_intersector_ribbon.h new file mode 100644 index 000000000000..a99cf99d56c5 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector_ribbon.h @@ -0,0 +1,214 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "quad_intersector.h" +#include "curve_intersector_precalculations.h" + +#define Bezier1Intersector1 RibbonCurve1Intersector1 +#define Bezier1IntersectorK RibbonCurve1IntersectorK + +namespace embree +{ + namespace isa + { + template + struct RibbonHit + { + __forceinline RibbonHit() {} + + __forceinline RibbonHit(const vbool& valid, const vfloat& U, const vfloat& V, const vfloat& T, const int i, const int N, + const NativeCurve3ff& curve3D) + : U(U), V(V), T(T), i(i), N(N), curve3D(curve3D), valid(valid) {} + + __forceinline void finalize() + { + vu = (vfloat(step)+U+vfloat(float(i)))*(1.0f/float(N)); + vv = V; + vt = T; + } + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { + return curve3D.eval_du(vu[i]); + } + + public: + vfloat U; + vfloat V; + vfloat T; + int i, N; + NativeCurve3ff curve3D; + + public: + vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + }; + + /* calculate squared distance of point p0 to line p1->p2 */ + __forceinline std::pair sqr_point_line_distance(const Vec2vfx& p0, const Vec2vfx& p1, const Vec2vfx& p2) + { + const vfloatx num = det(p2-p1,p1-p0); + const vfloatx den2 = dot(p2-p1,p2-p1); + return std::make_pair(num*num,den2); + } + + /* performs culling against a cylinder */ + __forceinline vboolx cylinder_culling_test(const Vec2vfx& p0, const Vec2vfx& p1, const Vec2vfx& p2, const vfloatx& r) + { + const std::pair d = sqr_point_line_distance(p0,p1,p2); + return d.first <= r*r*d.second; + } + + template + __forceinline bool intersect_ribbon(const Vec3fa& ray_org, const Vec3fa& ray_dir, const float ray_tnear, const float& ray_tfar, + const LinearSpace3fa& ray_space, const float& depth_scale, + const NativeCurve3ff& curve3D, const int N, + const Epilog& epilog) + { + /* transform control points into ray space */ + const NativeCurve3ff curve2D = curve3D.xfm_pr(ray_space,ray_org); + float eps = 4.0f*float(ulp)*reduce_max(max(abs(curve2D.v0),abs(curve2D.v1),abs(curve2D.v2),abs(curve2D.v3))); + + /* evaluate the bezier curve */ + bool ishit = false; + vboolx valid = vfloatx(step) < vfloatx(float(N)); + const Vec4vfx p0 = curve2D.template eval0(0,N); 
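+        // p0/p1 below are the start/end samples of up to VSIZEX curve segments evaluated in one SIMD batch;
+        // each segment is first culled against a bounding cylinder around it, then widened by its radius
+        // along the screen-space normals (nn0/nn1) into a flat quad (lp0,lp1,up1,up0) facing the ray,
+        // which intersect_quad_backface_culling tests against the ray interval [ray_tnear,ray_tfar].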
+ const Vec4vfx p1 = curve2D.template eval1(0,N); + valid &= cylinder_culling_test(zero,Vec2vfx(p0.x,p0.y),Vec2vfx(p1.x,p1.y),max(p0.w,p1.w)); + + if (any(valid)) + { + Vec3vfx dp0dt = curve2D.template derivative0(0,N); + Vec3vfx dp1dt = curve2D.template derivative1(0,N); + dp0dt = select(reduce_max(abs(dp0dt)) < vfloatx(eps),Vec3vfx(p1-p0),dp0dt); + dp1dt = select(reduce_max(abs(dp1dt)) < vfloatx(eps),Vec3vfx(p1-p0),dp1dt); + const Vec3vfx n0(dp0dt.y,-dp0dt.x,0.0f); + const Vec3vfx n1(dp1dt.y,-dp1dt.x,0.0f); + const Vec3vfx nn0 = normalize(n0); + const Vec3vfx nn1 = normalize(n1); + const Vec3vfx lp0 = madd(p0.w,nn0,Vec3vfx(p0)); + const Vec3vfx lp1 = madd(p1.w,nn1,Vec3vfx(p1)); + const Vec3vfx up0 = nmadd(p0.w,nn0,Vec3vfx(p0)); + const Vec3vfx up1 = nmadd(p1.w,nn1,Vec3vfx(p1)); + + vfloatx vu,vv,vt; + vboolx valid0 = intersect_quad_backface_culling(valid,zero,Vec3fa(0,0,1),ray_tnear,ray_tfar,lp0,lp1,up1,up0,vu,vv,vt); + + if (any(valid0)) + { + /* ignore self intersections */ + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) { + vfloatx r = lerp(p0.w, p1.w, vu); + valid0 &= vt > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR)*r*depth_scale; + } + + if (any(valid0)) + { + vv = madd(2.0f,vv,vfloatx(-1.0f)); + RibbonHit bhit(valid0,vu,vv,vt,0,N,curve3D); + ishit |= epilog(bhit.valid,bhit); + } + } + } + + if (unlikely(VSIZEX < N)) + { + /* process SIMD-size many segments per iteration */ + for (int i=VSIZEX; i(i,N); + const Vec4vfx p1 = curve2D.template eval1(i,N); + valid &= cylinder_culling_test(zero,Vec2vfx(p0.x,p0.y),Vec2vfx(p1.x,p1.y),max(p0.w,p1.w)); + if (none(valid)) continue; + + Vec3vfx dp0dt = curve2D.template derivative0(i,N); + Vec3vfx dp1dt = curve2D.template derivative1(i,N); + dp0dt = select(reduce_max(abs(dp0dt)) < vfloatx(eps),Vec3vfx(p1-p0),dp0dt); + dp1dt = select(reduce_max(abs(dp1dt)) < vfloatx(eps),Vec3vfx(p1-p0),dp1dt); + const Vec3vfx n0(dp0dt.y,-dp0dt.x,0.0f); + const Vec3vfx n1(dp1dt.y,-dp1dt.x,0.0f); + const Vec3vfx nn0 = normalize(n0); + const Vec3vfx nn1 = normalize(n1); + const Vec3vfx lp0 = madd(p0.w,nn0,Vec3vfx(p0)); + const Vec3vfx lp1 = madd(p1.w,nn1,Vec3vfx(p1)); + const Vec3vfx up0 = nmadd(p0.w,nn0,Vec3vfx(p0)); + const Vec3vfx up1 = nmadd(p1.w,nn1,Vec3vfx(p1)); + + vfloatx vu,vv,vt; + vboolx valid0 = intersect_quad_backface_culling(valid,zero,Vec3fa(0,0,1),ray_tnear,ray_tfar,lp0,lp1,up1,up0,vu,vv,vt); + + if (any(valid0)) + { + /* ignore self intersections */ + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) { + vfloatx r = lerp(p0.w, p1.w, vu); + valid0 &= vt > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR)*r*depth_scale; + } + + if (any(valid0)) + { + vv = madd(2.0f,vv,vfloatx(-1.0f)); + RibbonHit bhit(valid0,vu,vv,vt,i,N,curve3D); + ishit |= epilog(bhit.valid,bhit); + } + } + } + } + return ishit; + } + + template class NativeCurve> + struct RibbonCurve1Intersector1 + { + typedef NativeCurve NativeCurve3ff; + + template + __forceinline bool intersect(const CurvePrecalculations1& pre, Ray& ray, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3ff& v0, const Vec3ff& v1, const Vec3ff& v2, const Vec3ff& v3, + const Epilog& epilog) + { + const int N = geom->tessellationRate; + NativeCurve3ff curve(v0,v1,v2,v3); + curve = enlargeRadiusToMinWidth(context,geom,ray.org,curve); + return intersect_ribbon(ray.org,ray.dir,ray.tnear(),ray.tfar, + pre.ray_space,pre.depth_scale, + curve,N, + epilog); + } + }; + + template class NativeCurve, int K> + struct 
RibbonCurve1IntersectorK + { + typedef NativeCurve NativeCurve3ff; + + template + __forceinline bool intersect(const CurvePrecalculationsK& pre, RayK& ray, size_t k, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3ff& v0, const Vec3ff& v1, const Vec3ff& v2, const Vec3ff& v3, + const Epilog& epilog) + { + const int N = geom->tessellationRate; + const Vec3fa ray_org(ray.org.x[k],ray.org.y[k],ray.org.z[k]); + const Vec3fa ray_dir(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]); + NativeCurve3ff curve(v0,v1,v2,v3); + curve = enlargeRadiusToMinWidth(context,geom,ray_org,curve); + return intersect_ribbon(ray_org,ray_dir,ray.tnear()[k],ray.tfar[k], + pre.ray_space[k],pre.depth_scale[k], + curve,N, + epilog); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector_sweep.h b/thirdparty/embree/kernels/geometry/curve_intersector_sweep.h new file mode 100644 index 000000000000..883cedc3d2a4 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector_sweep.h @@ -0,0 +1,362 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "cylinder.h" +#include "plane.h" +#include "line_intersector.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + namespace isa + { + static const size_t numJacobianIterations = 5; +#if defined(__AVX__) + static const size_t numBezierSubdivisions = 2; +#else + static const size_t numBezierSubdivisions = 3; +#endif + + struct BezierCurveHit + { + __forceinline BezierCurveHit() {} + + __forceinline BezierCurveHit(const float t, const float u, const Vec3fa& Ng) + : t(t), u(u), v(0.0f), Ng(Ng) {} + + __forceinline BezierCurveHit(const float t, const float u, const float v, const Vec3fa& Ng) + : t(t), u(u), v(v), Ng(Ng) {} + + __forceinline void finalize() {} + + public: + float t; + float u; + float v; + Vec3fa Ng; + }; + + template + __forceinline bool intersect_bezier_iterative_debug(const Ray& ray, const float dt, const NativeCurve3ff& curve, size_t i, + const vfloatx& u, const BBox& tp, const BBox& h0, const BBox& h1, + const Vec3vfx& Ng, const Vec4vfx& dP0du, const Vec4vfx& dP3du, + const Epilog& epilog) + { + if (tp.lower[i]+dt > ray.tfar) return false; + Vec3fa Ng_o = Vec3fa(Ng.x[i],Ng.y[i],Ng.z[i]); + if (h0.lower[i] == tp.lower[i]) Ng_o = -Vec3fa(dP0du.x[i],dP0du.y[i],dP0du.z[i]); + if (h1.lower[i] == tp.lower[i]) Ng_o = +Vec3fa(dP3du.x[i],dP3du.y[i],dP3du.z[i]); + BezierCurveHit hit(tp.lower[i]+dt,u[i],Ng_o); + return epilog(hit); + } + + template + __forceinline bool intersect_bezier_iterative_jacobian(const Ray& ray, const float dt, const NativeCurve3ff& curve, float u, float t, const Epilog& epilog) + { + const Vec3fa org = zero; + const Vec3fa dir = ray.dir; + const float length_ray_dir = length(dir); + + /* error of curve evaluations is propertional to largest coordinate */ + const BBox3ff box = curve.bounds(); + const float P_err = 16.0f*float(ulp)*reduce_max(max(abs(box.lower),abs(box.upper))); + + for (size_t i=0; i= 0.0f && u <= 1.0f)) return false; // rejects NaNs + const Vec3fa R = normalize(Q-P); + const Vec3fa U = madd(Vec3fa(dPdu.w),R,dPdu); + const Vec3fa V = cross(dPdu,R); + BezierCurveHit hit(t,u,cross(V,U)); + return epilog(hit); + } + } + return false; + } + + template + bool intersect_bezier_recursive_jacobian(const Ray& ray, const float dt, const NativeCurve3ff& curve, + float u0, float u1, unsigned int depth, const Epilog& epilog) + { +#if defined(__AVX__) + typedef 
vbool8 vboolx; // maximally 8-wide to work around KNL issues + typedef vint8 vintx; + typedef vfloat8 vfloatx; +#else + typedef vbool4 vboolx; + typedef vint4 vintx; + typedef vfloat4 vfloatx; +#endif + typedef Vec3 Vec3vfx; + typedef Vec4 Vec4vfx; + + unsigned int maxDepth = numBezierSubdivisions; + bool found = false; + const Vec3fa org = zero; + const Vec3fa dir = ray.dir; + + unsigned int sptr = 0; + const unsigned int stack_size = numBezierSubdivisions+1; // +1 because of unstable workaround below + struct StackEntry { + vboolx valid; + vfloatx tlower; + float u0; + float u1; + unsigned int depth; + }; + StackEntry stack[stack_size]; + goto entry; + + /* terminate if stack is empty */ + while (sptr) + { + /* pop from stack */ + { + sptr--; + vboolx valid = stack[sptr].valid; + const vfloatx tlower = stack[sptr].tlower; + valid &= tlower+dt <= ray.tfar; + if (none(valid)) continue; + u0 = stack[sptr].u0; + u1 = stack[sptr].u1; + depth = stack[sptr].depth; + const size_t i = select_min(valid,tlower); clear(valid,i); + stack[sptr].valid = valid; + if (any(valid)) sptr++; // there are still items on the stack + + /* process next segment */ + const vfloatx vu0 = lerp(u0,u1,vfloatx(step)*(1.0f/(vfloatx::size-1))); + u0 = vu0[i+0]; + u1 = vu0[i+1]; + } + entry: + + /* subdivide curve */ + const float dscale = (u1-u0)*(1.0f/(3.0f*(vfloatx::size-1))); + const vfloatx vu0 = lerp(u0,u1,vfloatx(step)*(1.0f/(vfloatx::size-1))); + Vec4vfx P0, dP0du; curve.veval(vu0,P0,dP0du); dP0du = dP0du * Vec4vfx(dscale); + const Vec4vfx P3 = shift_right_1(P0); + const Vec4vfx dP3du = shift_right_1(dP0du); + const Vec4vfx P1 = P0 + dP0du; + const Vec4vfx P2 = P3 - dP3du; + + /* calculate bounding cylinders */ + const vfloatx rr1 = sqr_point_to_line_distance(Vec3vfx(dP0du),Vec3vfx(P3-P0)); + const vfloatx rr2 = sqr_point_to_line_distance(Vec3vfx(dP3du),Vec3vfx(P3-P0)); + const vfloatx maxr12 = sqrt(max(rr1,rr2)); + const vfloatx one_plus_ulp = 1.0f+2.0f*float(ulp); + const vfloatx one_minus_ulp = 1.0f-2.0f*float(ulp); + vfloatx r_outer = max(P0.w,P1.w,P2.w,P3.w)+maxr12; + vfloatx r_inner = min(P0.w,P1.w,P2.w,P3.w)-maxr12; + r_outer = one_plus_ulp*r_outer; + r_inner = max(0.0f,one_minus_ulp*r_inner); + const CylinderN cylinder_outer(Vec3vfx(P0),Vec3vfx(P3),r_outer); + const CylinderN cylinder_inner(Vec3vfx(P0),Vec3vfx(P3),r_inner); + vboolx valid = true; clear(valid,vfloatx::size-1); + + /* intersect with outer cylinder */ + BBox tc_outer; vfloatx u_outer0; Vec3vfx Ng_outer0; vfloatx u_outer1; Vec3vfx Ng_outer1; + valid &= cylinder_outer.intersect(org,dir,tc_outer,u_outer0,Ng_outer0,u_outer1,Ng_outer1); + if (none(valid)) continue; + + /* intersect with cap-planes */ + BBox tp(ray.tnear()-dt,ray.tfar-dt); + tp = embree::intersect(tp,tc_outer); + BBox h0 = HalfPlaneN(Vec3vfx(P0),+Vec3vfx(dP0du)).intersect(org,dir); + tp = embree::intersect(tp,h0); + BBox h1 = HalfPlaneN(Vec3vfx(P3),-Vec3vfx(dP3du)).intersect(org,dir); + tp = embree::intersect(tp,h1); + valid &= tp.lower <= tp.upper; + if (none(valid)) continue; + + /* clamp and correct u parameter */ + u_outer0 = clamp(u_outer0,vfloatx(0.0f),vfloatx(1.0f)); + u_outer1 = clamp(u_outer1,vfloatx(0.0f),vfloatx(1.0f)); + u_outer0 = lerp(u0,u1,(vfloatx(step)+u_outer0)*(1.0f/float(vfloatx::size))); + u_outer1 = lerp(u0,u1,(vfloatx(step)+u_outer1)*(1.0f/float(vfloatx::size))); + + /* intersect with inner cylinder */ + BBox tc_inner; + vfloatx u_inner0 = zero; Vec3vfx Ng_inner0 = zero; vfloatx u_inner1 = zero; Vec3vfx Ng_inner1 = zero; + const vboolx valid_inner = 
cylinder_inner.intersect(org,dir,tc_inner,u_inner0,Ng_inner0,u_inner1,Ng_inner1); + + /* at the unstable area we subdivide deeper */ + const vboolx unstable0 = (!valid_inner) | (abs(dot(Vec3vfx(Vec3fa(ray.dir)),Ng_inner0)) < 0.3f); + const vboolx unstable1 = (!valid_inner) | (abs(dot(Vec3vfx(Vec3fa(ray.dir)),Ng_inner1)) < 0.3f); + + /* subtract the inner interval from the current hit interval */ + BBox tp0, tp1; + subtract(tp,tc_inner,tp0,tp1); + vboolx valid0 = valid & (tp0.lower <= tp0.upper); + vboolx valid1 = valid & (tp1.lower <= tp1.upper); + if (none(valid0 | valid1)) continue; + + /* iterate over all first hits front to back */ + const vintx termDepth0 = select(unstable0,vintx(maxDepth+1),vintx(maxDepth)); + vboolx recursion_valid0 = valid0 & (depth < termDepth0); + valid0 &= depth >= termDepth0; + + while (any(valid0)) + { + const size_t i = select_min(valid0,tp0.lower); clear(valid0,i); + found = found | intersect_bezier_iterative_jacobian(ray,dt,curve,u_outer0[i],tp0.lower[i],epilog); + //found = found | intersect_bezier_iterative_debug (ray,dt,curve,i,u_outer0,tp0,h0,h1,Ng_outer0,dP0du,dP3du,epilog); + valid0 &= tp0.lower+dt <= ray.tfar; + } + valid1 &= tp1.lower+dt <= ray.tfar; + + /* iterate over all second hits front to back */ + const vintx termDepth1 = select(unstable1,vintx(maxDepth+1),vintx(maxDepth)); + vboolx recursion_valid1 = valid1 & (depth < termDepth1); + valid1 &= depth >= termDepth1; + while (any(valid1)) + { + const size_t i = select_min(valid1,tp1.lower); clear(valid1,i); + found = found | intersect_bezier_iterative_jacobian(ray,dt,curve,u_outer1[i],tp1.upper[i],epilog); + //found = found | intersect_bezier_iterative_debug (ray,dt,curve,i,u_outer1,tp1,h0,h1,Ng_outer1,dP0du,dP3du,epilog); + valid1 &= tp1.lower+dt <= ray.tfar; + } + + /* push valid segments to stack */ + recursion_valid0 &= tp0.lower+dt <= ray.tfar; + recursion_valid1 &= tp1.lower+dt <= ray.tfar; + const vboolx recursion_valid = recursion_valid0 | recursion_valid1; + if (any(recursion_valid)) + { + assert(sptr < stack_size); + stack[sptr].valid = recursion_valid; + stack[sptr].tlower = select(recursion_valid0,tp0.lower,tp1.lower); + stack[sptr].u0 = u0; + stack[sptr].u1 = u1; + stack[sptr].depth = depth+1; + sptr++; + } + } + return found; + } + + template class NativeCurve> + struct SweepCurve1Intersector1 + { + typedef NativeCurve NativeCurve3ff; + + template + __noinline bool intersect(const CurvePrecalculations1& pre, Ray& ray, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3ff& v0, const Vec3ff& v1, const Vec3ff& v2, const Vec3ff& v3, + const Epilog& epilog) + { + STAT3(normal.trav_prims,1,1,1); + + /* move ray closer to make intersection stable */ + NativeCurve3ff curve0(v0,v1,v2,v3); + curve0 = enlargeRadiusToMinWidth(context,geom,ray.org,curve0); + const float dt = dot(curve0.center()-ray.org,ray.dir)*rcp(dot(ray.dir,ray.dir)); + const Vec3ff ref(madd(Vec3fa(dt),ray.dir,ray.org),0.0f); + const NativeCurve3ff curve1 = curve0-ref; + return intersect_bezier_recursive_jacobian(ray,dt,curve1,0.0f,1.0f,1,epilog); + } + }; + + template class NativeCurve, int K> + struct SweepCurve1IntersectorK + { + typedef NativeCurve NativeCurve3ff; + + struct Ray1 + { + __forceinline Ray1(RayK& ray, size_t k) + : org(ray.org.x[k],ray.org.y[k],ray.org.z[k]), dir(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]), _tnear(ray.tnear()[k]), tfar(ray.tfar[k]) {} + + Vec3fa org; + Vec3fa dir; + float _tnear; + float& tfar; + + __forceinline float& tnear() { return _tnear; } + 
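/* Ray1 is a scalar view of lane k of a ray packet: org, dir and tnear are
   copied out, while tfar is bound by reference so that any hit accepted by the
   single-ray sweep code writes straight back into the packet's tfar[k]. */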
//__forceinline float& tfar() { return _tfar; } + __forceinline const float& tnear() const { return _tnear; } + //__forceinline const float& tfar() const { return _tfar; } + + }; + + template + __forceinline bool intersect(const CurvePrecalculationsK& pre, RayK& vray, size_t k, + IntersectContext* context, + const CurveGeometry* geom, const unsigned int primID, + const Vec3ff& v0, const Vec3ff& v1, const Vec3ff& v2, const Vec3ff& v3, + const Epilog& epilog) + { + STAT3(normal.trav_prims,1,1,1); + Ray1 ray(vray,k); + + /* move ray closer to make intersection stable */ + NativeCurve3ff curve0(v0,v1,v2,v3); + curve0 = enlargeRadiusToMinWidth(context,geom,ray.org,curve0); + const float dt = dot(curve0.center()-ray.org,ray.dir)*rcp(dot(ray.dir,ray.dir)); + const Vec3ff ref(madd(Vec3fa(dt),ray.dir,ray.org),0.0f); + const NativeCurve3ff curve1 = curve0-ref; + return intersect_bezier_recursive_jacobian(ray,dt,curve1,0.0f,1.0f,1,epilog); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/curve_intersector_virtual.h b/thirdparty/embree/kernels/geometry/curve_intersector_virtual.h new file mode 100644 index 000000000000..e1f42381306e --- /dev/null +++ b/thirdparty/embree/kernels/geometry/curve_intersector_virtual.h @@ -0,0 +1,671 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "../subdiv/bezier_curve.h" +#include "../common/primref.h" +#include "curve_intersector_precalculations.h" +#include "../bvh/node_intersector1.h" +#include "../bvh/node_intersector_packet.h" + +#include "intersector_epilog.h" + +#include "../subdiv/bezier_curve.h" +#include "../subdiv/bspline_curve.h" +#include "../subdiv/hermite_curve.h" +#include "../subdiv/catmullrom_curve.h" + +#include "spherei_intersector.h" +#include "disci_intersector.h" + +#include "linei_intersector.h" +#include "roundlinei_intersector.h" +#include "conelinei_intersector.h" + +#include "curveNi_intersector.h" +#include "curveNv_intersector.h" +#include "curveNi_mb_intersector.h" + +#include "curve_intersector_distance.h" +#include "curve_intersector_ribbon.h" +#include "curve_intersector_oriented.h" +#include "curve_intersector_sweep.h" + +namespace embree +{ + struct VirtualCurveIntersector + { + typedef void (*Intersect1Ty)(void* pre, void* ray, IntersectContext* context, const void* primitive); + typedef bool (*Occluded1Ty )(void* pre, void* ray, IntersectContext* context, const void* primitive); + + typedef void (*Intersect4Ty)(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + typedef bool (*Occluded4Ty) (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + + typedef void (*Intersect8Ty)(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + typedef bool (*Occluded8Ty) (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + + typedef void (*Intersect16Ty)(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + typedef bool (*Occluded16Ty) (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + + public: + struct Intersectors + { + Intersectors() {} // WARNING: Do not zero initialize this, as we otherwise get problems with thread unsafe local static variable initialization (e.g. on VS2013) in curve_intersector_virtual.cpp. 
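/* Each member of this struct is a function pointer specialized for one SIMD
   width (1, 4, 8 and 16 rays). VirtualCurveIntersector keeps one Intersectors
   entry per geometry type (vtbl[Geometry::GTY_END]); a leaf primitive stores
   its RTCGeometryType in its first byte, and the traversal reads that byte to
   select the entry and forward the call, roughly:

     RTCGeometryType ty = (RTCGeometryType)(*prim);
     ((VirtualCurveIntersector*)This->leafIntersector)->vtbl[ty].intersect<1>(&pre,&ray,context,prim);

   (see VirtualCurveIntersector1 and VirtualCurveIntersectorK below). */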
+ + template void intersect(void* pre, void* ray, IntersectContext* context, const void* primitive); + template bool occluded (void* pre, void* ray, IntersectContext* context, const void* primitive); + + template void intersect(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + template bool occluded (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive); + + public: + Intersect1Ty intersect1; + Occluded1Ty occluded1; + Intersect4Ty intersect4; + Occluded4Ty occluded4; + Intersect8Ty intersect8; + Occluded8Ty occluded8; + Intersect16Ty intersect16; + Occluded16Ty occluded16; + }; + + Intersectors vtbl[Geometry::GTY_END]; + }; + + template<> __forceinline void VirtualCurveIntersector::Intersectors::intersect<1> (void* pre, void* ray, IntersectContext* context, const void* primitive) { assert(intersect1); intersect1(pre,ray,context,primitive); } + template<> __forceinline bool VirtualCurveIntersector::Intersectors::occluded<1> (void* pre, void* ray, IntersectContext* context, const void* primitive) { assert(occluded1); return occluded1(pre,ray,context,primitive); } + + template<> __forceinline void VirtualCurveIntersector::Intersectors::intersect<4>(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive) { assert(intersect4); intersect4(pre,ray,k,context,primitive); } + template<> __forceinline bool VirtualCurveIntersector::Intersectors::occluded<4> (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive) { assert(occluded4); return occluded4(pre,ray,k,context,primitive); } + +#if defined(__AVX__) + template<> __forceinline void VirtualCurveIntersector::Intersectors::intersect<8>(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive) { assert(intersect8); intersect8(pre,ray,k,context,primitive); } + template<> __forceinline bool VirtualCurveIntersector::Intersectors::occluded<8> (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive) { assert(occluded8); return occluded8(pre,ray,k,context,primitive); } +#endif + +#if defined(__AVX512F__) + template<> __forceinline void VirtualCurveIntersector::Intersectors::intersect<16>(void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive) { assert(intersect16); intersect16(pre,ray,k,context,primitive); } + template<> __forceinline bool VirtualCurveIntersector::Intersectors::occluded<16> (void* pre, void* ray, size_t k, IntersectContext* context, const void* primitive) { assert(occluded16); return occluded16(pre,ray,k,context,primitive); } +#endif + + namespace isa + { + struct VirtualCurveIntersector1 + { + typedef unsigned char Primitive; + typedef CurvePrecalculations1 Precalculations; + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurveIntersector::Intersectors& leafIntersector = ((VirtualCurveIntersector*) This->leafIntersector)->vtbl[ty]; + leafIntersector.intersect<1>(&pre,&ray,context,prim); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = 
(RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurveIntersector::Intersectors& leafIntersector = ((VirtualCurveIntersector*) This->leafIntersector)->vtbl[ty]; + return leafIntersector.occluded<1>(&pre,&ray,context,prim); + } + }; + + template + struct VirtualCurveIntersectorK + { + typedef unsigned char Primitive; + typedef CurvePrecalculationsK Precalculations; + + template + static __forceinline void intersect(const vbool& valid_i, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurveIntersector::Intersectors& leafIntersector = ((VirtualCurveIntersector*) This->leafIntersector)->vtbl[ty]; + size_t mask = movemask(valid_i); + while (mask) leafIntersector.intersect(&pre,&ray,bscf(mask),context,prim); + } + + template + static __forceinline vbool occluded(const vbool& valid_i, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurveIntersector::Intersectors& leafIntersector = ((VirtualCurveIntersector*) This->leafIntersector)->vtbl[ty]; + vbool valid_o = false; + size_t mask = movemask(valid_i); + while (mask) { + size_t k = bscf(mask); + if (leafIntersector.occluded(&pre,&ray,k,context,prim)) + set(valid_o, k); + } + return valid_o; + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurveIntersector::Intersectors& leafIntersector = ((VirtualCurveIntersector*) This->leafIntersector)->vtbl[ty]; + leafIntersector.intersect(&pre,&ray,k,context,prim); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + assert(num == 1); + RTCGeometryType ty = (RTCGeometryType)(*prim); + assert(This->leafIntersector); + VirtualCurveIntersector::Intersectors& leafIntersector = ((VirtualCurveIntersector*) This->leafIntersector)->vtbl[ty]; + return leafIntersector.occluded(&pre,&ray,k,context,prim); + } + }; + + template + static VirtualCurveIntersector::Intersectors LinearRoundConeNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &RoundLinearCurveMiIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &RoundLinearCurveMiIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &RoundLinearCurveMiIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &RoundLinearCurveMiIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&RoundLinearCurveMiIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &RoundLinearCurveMiIntersectorK::occluded; +#endif +#if 
defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&RoundLinearCurveMiIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &RoundLinearCurveMiIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors LinearConeNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &ConeCurveMiIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &ConeCurveMiIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &ConeCurveMiIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &ConeCurveMiIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&ConeCurveMiIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &ConeCurveMiIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&ConeCurveMiIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &ConeCurveMiIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors LinearRoundConeNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &RoundLinearCurveMiMBIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &RoundLinearCurveMiMBIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &RoundLinearCurveMiMBIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &RoundLinearCurveMiMBIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&RoundLinearCurveMiMBIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &RoundLinearCurveMiMBIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&RoundLinearCurveMiMBIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &RoundLinearCurveMiMBIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors LinearConeNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &ConeCurveMiMBIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &ConeCurveMiMBIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &ConeCurveMiMBIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &ConeCurveMiMBIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&ConeCurveMiMBIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &ConeCurveMiMBIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&ConeCurveMiMBIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &ConeCurveMiMBIntersectorK::occluded; +#endif + 
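/* intersect8/occluded8 and intersect16/occluded16 are only filled in when the
   translation unit is compiled with AVX resp. AVX-512 support; the asserts in
   the dispatch helpers above guard against calling an unset entry from a wider
   ray packet than the build provides. */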
return intersectors; + } + + + template + static VirtualCurveIntersector::Intersectors LinearRibbonNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &FlatLinearCurveMiIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &FlatLinearCurveMiIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &FlatLinearCurveMiIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &FlatLinearCurveMiIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&FlatLinearCurveMiIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &FlatLinearCurveMiIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&FlatLinearCurveMiIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &FlatLinearCurveMiIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors LinearRibbonNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &FlatLinearCurveMiMBIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &FlatLinearCurveMiMBIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &FlatLinearCurveMiMBIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &FlatLinearCurveMiMBIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&FlatLinearCurveMiMBIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &FlatLinearCurveMiMBIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&FlatLinearCurveMiMBIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &FlatLinearCurveMiMBIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors SphereNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &SphereMiIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &SphereMiIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &SphereMiIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &SphereMiIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&SphereMiIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &SphereMiIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&SphereMiIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &SphereMiIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors SphereNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) 
&SphereMiMBIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &SphereMiMBIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &SphereMiMBIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &SphereMiMBIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&SphereMiMBIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &SphereMiMBIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&SphereMiMBIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &SphereMiMBIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors DiscNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &DiscMiIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &DiscMiIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &DiscMiIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &DiscMiIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&DiscMiIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &DiscMiIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&DiscMiIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &DiscMiIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors DiscNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &DiscMiMBIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &DiscMiMBIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &DiscMiMBIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &DiscMiMBIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&DiscMiMBIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &DiscMiMBIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&DiscMiMBIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &DiscMiMBIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors OrientedDiscNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &OrientedDiscMiIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &OrientedDiscMiIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &OrientedDiscMiIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &OrientedDiscMiIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = 
(VirtualCurveIntersector::Intersect8Ty)&OrientedDiscMiIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &OrientedDiscMiIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&OrientedDiscMiIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &OrientedDiscMiIntersectorK::occluded; +#endif + return intersectors; + } + + template + static VirtualCurveIntersector::Intersectors OrientedDiscNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &OrientedDiscMiMBIntersector1::intersect; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &OrientedDiscMiMBIntersector1::occluded; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &OrientedDiscMiMBIntersectorK::intersect; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &OrientedDiscMiMBIntersectorK::occluded; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&OrientedDiscMiMBIntersectorK::intersect; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &OrientedDiscMiMBIntersectorK::occluded; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&OrientedDiscMiMBIntersectorK::intersect; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &OrientedDiscMiMBIntersectorK::occluded; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors RibbonNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiIntersector1::template intersect_t, Intersect1EpilogMU >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiIntersector1::template occluded_t , Occluded1EpilogMU >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &CurveNiIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors RibbonNvIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNvIntersector1::template intersect_t, Intersect1EpilogMU >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNvIntersector1::template occluded_t , Occluded1EpilogMU >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &CurveNvIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) 
&CurveNvIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNvIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNvIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNvIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNvIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors RibbonNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiMBIntersector1::template intersect_t, Intersect1EpilogMU >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiMBIntersector1::template occluded_t , Occluded1EpilogMU >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty) &CurveNiMBIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiMBIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiMBIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiMBIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiMBIntersectorK::template intersect_t, Intersect1KEpilogMU >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiMBIntersectorK::template occluded_t , Occluded1KEpilogMU >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors CurveNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiIntersector1::template intersect_t, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiIntersector1::template occluded_t , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiIntersectorK::template intersect_t, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiIntersectorK::template occluded_t , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiIntersectorK::template intersect_t, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiIntersectorK::template occluded_t , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiIntersectorK::template intersect_t, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiIntersectorK::template occluded_t , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors CurveNvIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + 
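/* Like the other factory helpers in this file, this fills one vtbl entry with
   intersect/occluded pointers for every supported SIMD width. The Ni variants
   reference their control points through indices, the Nv variants additionally
   store vertex data inline in the leaf, and the *MB variants handle
   motion-blurred curves; the Ribbon*, Curve*, OrientedCurve* and Hermite*
   flavours correspond to the curve basis/representation combinations
   registered in curve_intersector_virtual.cpp. */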
intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNvIntersector1::template intersect_t, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNvIntersector1::template occluded_t , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNvIntersectorK::template intersect_t, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNvIntersectorK::template occluded_t , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNvIntersectorK::template intersect_t, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNvIntersectorK::template occluded_t , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNvIntersectorK::template intersect_t, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNvIntersectorK::template occluded_t , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors CurveNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiMBIntersector1::template intersect_t, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiMBIntersector1::template occluded_t , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiMBIntersectorK::template intersect_t, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiMBIntersectorK::template occluded_t , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiMBIntersectorK::template intersect_t, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiMBIntersectorK::template occluded_t , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiMBIntersectorK::template intersect_t, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiMBIntersectorK::template occluded_t , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors OrientedCurveNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiIntersector1::template intersect_n, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiIntersector1::template occluded_n , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiIntersectorK::template intersect_n, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiIntersectorK::template occluded_n , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiIntersectorK::template intersect_n, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiIntersectorK::template 
occluded_n , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiIntersectorK::template intersect_n, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiIntersectorK::template occluded_n , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors OrientedCurveNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiMBIntersector1::template intersect_n, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiMBIntersector1::template occluded_n , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiMBIntersectorK::template intersect_n, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiMBIntersectorK::template occluded_n , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiMBIntersectorK::template intersect_n, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiMBIntersectorK::template occluded_n , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiMBIntersectorK::template intersect_n, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiMBIntersectorK::template occluded_n , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors HermiteRibbonNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiIntersector1::template intersect_h, Intersect1EpilogMU >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiIntersector1::template occluded_h , Occluded1EpilogMU >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiIntersectorK::template intersect_h, Intersect1KEpilogMU >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiIntersectorK::template occluded_h , Occluded1KEpilogMU >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiIntersectorK::template intersect_h, Intersect1KEpilogMU >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiIntersectorK::template occluded_h , Occluded1KEpilogMU >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiIntersectorK::template intersect_h, Intersect1KEpilogMU >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiIntersectorK::template occluded_h , Occluded1KEpilogMU >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors HermiteRibbonNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiMBIntersector1::template intersect_h, Intersect1EpilogMU >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiMBIntersector1::template occluded_h , Occluded1EpilogMU >; + 
intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiMBIntersectorK::template intersect_h, Intersect1KEpilogMU >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiMBIntersectorK::template occluded_h , Occluded1KEpilogMU >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiMBIntersectorK::template intersect_h, Intersect1KEpilogMU >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiMBIntersectorK::template occluded_h , Occluded1KEpilogMU >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiMBIntersectorK::template intersect_h, Intersect1KEpilogMU >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiMBIntersectorK::template occluded_h , Occluded1KEpilogMU >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors HermiteCurveNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiIntersector1::template intersect_h, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiIntersector1::template occluded_h , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiIntersectorK::template intersect_h, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiIntersectorK::template occluded_h , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiIntersectorK::template intersect_h, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiIntersectorK::template occluded_h , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiIntersectorK::template intersect_h, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiIntersectorK::template occluded_h , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors HermiteCurveNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiMBIntersector1::template intersect_h, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiMBIntersector1::template occluded_h , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiMBIntersectorK::template intersect_h, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiMBIntersectorK::template occluded_h , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiMBIntersectorK::template intersect_h, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiMBIntersectorK::template occluded_h , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiMBIntersectorK::template intersect_h, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) 
&CurveNiMBIntersectorK::template occluded_h , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors HermiteOrientedCurveNiIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiIntersector1::template intersect_hn, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiIntersector1::template occluded_hn , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiIntersectorK::template intersect_hn, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiIntersectorK::template occluded_hn , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiIntersectorK::template intersect_hn, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiIntersectorK::template occluded_hn , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiIntersectorK::template intersect_hn, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiIntersectorK::template occluded_hn , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + + template class Curve, int N> + static VirtualCurveIntersector::Intersectors HermiteOrientedCurveNiMBIntersectors() + { + VirtualCurveIntersector::Intersectors intersectors; + intersectors.intersect1 = (VirtualCurveIntersector::Intersect1Ty) &CurveNiMBIntersector1::template intersect_hn, Intersect1Epilog1 >; + intersectors.occluded1 = (VirtualCurveIntersector::Occluded1Ty) &CurveNiMBIntersector1::template occluded_hn , Occluded1Epilog1 >; + intersectors.intersect4 = (VirtualCurveIntersector::Intersect4Ty)&CurveNiMBIntersectorK::template intersect_hn, Intersect1KEpilog1<4,true> >; + intersectors.occluded4 = (VirtualCurveIntersector::Occluded4Ty) &CurveNiMBIntersectorK::template occluded_hn , Occluded1KEpilog1<4,true> >; +#if defined(__AVX__) + intersectors.intersect8 = (VirtualCurveIntersector::Intersect8Ty)&CurveNiMBIntersectorK::template intersect_hn, Intersect1KEpilog1<8,true> >; + intersectors.occluded8 = (VirtualCurveIntersector::Occluded8Ty) &CurveNiMBIntersectorK::template occluded_hn , Occluded1KEpilog1<8,true> >; +#endif +#if defined(__AVX512F__) + intersectors.intersect16 = (VirtualCurveIntersector::Intersect16Ty)&CurveNiMBIntersectorK::template intersect_hn, Intersect1KEpilog1<16,true> >; + intersectors.occluded16 = (VirtualCurveIntersector::Occluded16Ty) &CurveNiMBIntersectorK::template occluded_hn , Occluded1KEpilog1<16,true> >; +#endif + return intersectors; + } + } +} diff --git a/thirdparty/embree/kernels/geometry/cylinder.h b/thirdparty/embree/kernels/geometry/cylinder.h new file mode 100644 index 000000000000..39a582864c52 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/cylinder.h @@ -0,0 +1,223 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" + +namespace embree +{ + namespace isa + { + struct Cylinder + { + const Vec3fa p0; //!< start location + const Vec3fa p1; //!< end position + const float rr; //!< squared radius of cylinder + + __forceinline Cylinder(const Vec3fa& p0, const Vec3fa& p1, const float r) + : 
p0(p0), p1(p1), rr(sqr(r)) {} + + __forceinline Cylinder(const Vec3fa& p0, const Vec3fa& p1, const float rr, bool) + : p0(p0), p1(p1), rr(rr) {} + + __forceinline bool intersect(const Vec3fa& org, + const Vec3fa& dir, + BBox1f& t_o, + float& u0_o, Vec3fa& Ng0_o, + float& u1_o, Vec3fa& Ng1_o) const + { + /* calculate quadratic equation to solve */ + const float rl = rcp_length(p1-p0); + const Vec3fa P0 = p0, dP = (p1-p0)*rl; + const Vec3fa O = org-P0, dO = dir; + + const float dOdO = dot(dO,dO); + const float OdO = dot(dO,O); + const float OO = dot(O,O); + const float dOz = dot(dP,dO); + const float Oz = dot(dP,O); + + const float A = dOdO - sqr(dOz); + const float B = 2.0f * (OdO - dOz*Oz); + const float C = OO - sqr(Oz) - rr; + + /* we miss the cylinder if determinant is smaller than zero */ + const float D = B*B - 4.0f*A*C; + if (D < 0.0f) { + t_o = BBox1f(pos_inf,neg_inf); + return false; + } + + /* special case for rays that are parallel to the cylinder */ + const float eps = 16.0f*float(ulp)*max(abs(dOdO),abs(sqr(dOz))); + if (abs(A) < eps) + { + if (C <= 0.0f) { + t_o = BBox1f(neg_inf,pos_inf); + return true; + } else { + t_o = BBox1f(pos_inf,neg_inf); + return false; + } + } + + /* standard case for rays that are not parallel to the cylinder */ + const float Q = sqrt(D); + const float rcp_2A = rcp(2.0f*A); + const float t0 = (-B-Q)*rcp_2A; + const float t1 = (-B+Q)*rcp_2A; + + /* calculates u and Ng for near hit */ + { + u0_o = madd(t0,dOz,Oz)*rl; + const Vec3fa Pr = t0*dir; + const Vec3fa Pl = madd(u0_o,p1-p0,p0); + Ng0_o = Pr-Pl; + } + + /* calculates u and Ng for far hit */ + { + u1_o = madd(t1,dOz,Oz)*rl; + const Vec3fa Pr = t1*dir; + const Vec3fa Pl = madd(u1_o,p1-p0,p0); + Ng1_o = Pr-Pl; + } + + t_o.lower = t0; + t_o.upper = t1; + return true; + } + + __forceinline bool intersect(const Vec3fa& org_i, const Vec3fa& dir, BBox1f& t_o) const + { + float u0_o; Vec3fa Ng0_o; + float u1_o; Vec3fa Ng1_o; + return intersect(org_i,dir,t_o,u0_o,Ng0_o,u1_o,Ng1_o); + } + + static bool verify(const size_t id, const Cylinder& cylinder, const RayHit& ray, bool shouldhit, const float t0, const float t1) + { + float eps = 0.001f; + BBox1f t; bool hit; + hit = cylinder.intersect(ray.org,ray.dir,t); + + bool failed = hit != shouldhit; + if (shouldhit) failed |= std::isinf(t0) ? t0 != t.lower : abs(t0-t.lower) > eps; + if (shouldhit) failed |= std::isinf(t1) ? 
t1 != t.upper : abs(t1-t.upper) > eps; + if (!failed) return true; + embree_cout << "Cylinder test " << id << " failed: cylinder = " << cylinder << ", ray = " << ray << ", hit = " << hit << ", t = " << t << embree_endl; + return false; + } + + /* verify cylinder class */ + static bool verify() + { + bool passed = true; + const Cylinder cylinder(Vec3fa(0.0f,0.0f,0.0f),Vec3fa(1.0f,0.0f,0.0f),1.0f); + passed &= verify(0,cylinder,RayHit(Vec3fa(-2.0f,1.0f,0.0f),Vec3fa( 0.0f,-1.0f,+0.0f),0.0f,float(inf)),true,0.0f,2.0f); + passed &= verify(1,cylinder,RayHit(Vec3fa(+2.0f,1.0f,0.0f),Vec3fa( 0.0f,-1.0f,+0.0f),0.0f,float(inf)),true,0.0f,2.0f); + passed &= verify(2,cylinder,RayHit(Vec3fa(+2.0f,1.0f,2.0f),Vec3fa( 0.0f,-1.0f,+0.0f),0.0f,float(inf)),false,0.0f,0.0f); + passed &= verify(3,cylinder,RayHit(Vec3fa(+0.0f,0.0f,0.0f),Vec3fa( 1.0f, 0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,pos_inf); + passed &= verify(4,cylinder,RayHit(Vec3fa(+0.0f,0.0f,0.0f),Vec3fa(-1.0f, 0.0f,+0.0f),0.0f,float(inf)),true,neg_inf,pos_inf); + passed &= verify(5,cylinder,RayHit(Vec3fa(+0.0f,2.0f,0.0f),Vec3fa( 1.0f, 0.0f,+0.0f),0.0f,float(inf)),false,pos_inf,neg_inf); + passed &= verify(6,cylinder,RayHit(Vec3fa(+0.0f,2.0f,0.0f),Vec3fa(-1.0f, 0.0f,+0.0f),0.0f,float(inf)),false,pos_inf,neg_inf); + return passed; + } + + /*! output operator */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const Cylinder& c) { + return cout << "Cylinder { p0 = " << c.p0 << ", p1 = " << c.p1 << ", r = " << sqrtf(c.rr) << "}"; + } + }; + + template + struct CylinderN + { + const Vec3vf p0; //!< start location + const Vec3vf p1; //!< end position + const vfloat rr; //!< squared radius of cylinder + + __forceinline CylinderN(const Vec3vf& p0, const Vec3vf& p1, const vfloat& r) + : p0(p0), p1(p1), rr(sqr(r)) {} + + __forceinline CylinderN(const Vec3vf& p0, const Vec3vf& p1, const vfloat& rr, bool) + : p0(p0), p1(p1), rr(rr) {} + + + __forceinline vbool intersect(const Vec3fa& org, const Vec3fa& dir, + BBox>& t_o, + vfloat& u0_o, Vec3vf& Ng0_o, + vfloat& u1_o, Vec3vf& Ng1_o) const + { + /* calculate quadratic equation to solve */ + const vfloat rl = rcp_length(p1-p0); + const Vec3vf P0 = p0, dP = (p1-p0)*rl; + const Vec3vf O = Vec3vf(org)-P0, dO = dir; + + const vfloat dOdO = dot(dO,dO); + const vfloat OdO = dot(dO,O); + const vfloat OO = dot(O,O); + const vfloat dOz = dot(dP,dO); + const vfloat Oz = dot(dP,O); + + const vfloat A = dOdO - sqr(dOz); + const vfloat B = 2.0f * (OdO - dOz*Oz); + const vfloat C = OO - sqr(Oz) - rr; + + /* we miss the cylinder if determinant is smaller than zero */ + const vfloat D = B*B - 4.0f*A*C; + vbool valid = D >= 0.0f; + if (none(valid)) { + t_o = BBox>(empty); + return valid; + } + + /* standard case for rays that are not parallel to the cylinder */ + const vfloat Q = sqrt(D); + const vfloat rcp_2A = rcp(2.0f*A); + const vfloat t0 = (-B-Q)*rcp_2A; + const vfloat t1 = (-B+Q)*rcp_2A; + + /* calculates u and Ng for near hit */ + { + u0_o = madd(t0,dOz,Oz)*rl; + const Vec3vf Pr = t0*Vec3vf(dir); + const Vec3vf Pl = madd(u0_o,p1-p0,p0); + Ng0_o = Pr-Pl; + } + + /* calculates u and Ng for far hit */ + { + u1_o = madd(t1,dOz,Oz)*rl; + const Vec3vf Pr = t1*Vec3vf(dir); + const Vec3vf Pl = madd(u1_o,p1-p0,p0); + Ng1_o = Pr-Pl; + } + + t_o.lower = select(valid, t0, vfloat(pos_inf)); + t_o.upper = select(valid, t1, vfloat(neg_inf)); + + /* special case for rays that are parallel to the cylinder */ + const vfloat eps = 16.0f*float(ulp)*max(abs(dOdO),abs(sqr(dOz))); + vbool validt = valid & (abs(A) < eps); + 
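The quadratic solved in Cylinder::intersect above (coefficients A, B, C, the discriminant test, and the parallel-ray special case) is easier to read in scalar form. The sketch below restates the same derivation for one ray against an infinite cylinder; it is an illustrative standalone snippet, not code from the Embree sources, and the minimal V3 helper is made up for the example.

#include <cmath>
#include <limits>

struct V3 { double x, y, z; };
static double dot(const V3& a, const V3& b) { return a.x*b.x + a.y*b.y + a.z*b.z; }
static V3 sub(const V3& a, const V3& b) { return {a.x - b.x, a.y - b.y, a.z - b.z}; }
static V3 scale(const V3& a, double s) { return {a.x*s, a.y*s, a.z*s}; }

// Intersect the ray org + t*dir with the infinite cylinder of radius r around the
// axis through p0 and p1. Returns false on a miss; otherwise [t0,t1] is the parameter
// interval spent inside the cylinder (unbounded for rays parallel to the axis).
bool intersect_infinite_cylinder(const V3& org, const V3& dir,
                                 const V3& p0, const V3& p1, double r,
                                 double& t0, double& t1)
{
  const V3 d01  = sub(p1, p0);
  const V3 axis = scale(d01, 1.0 / std::sqrt(dot(d01, d01)));   // unit cylinder axis
  const V3 O    = sub(org, p0);

  const double dOz = dot(dir, axis);   // ray direction projected onto the axis
  const double Oz  = dot(O, axis);     // ray origin projected onto the axis

  // Same coefficients as in the kernel: only the components perpendicular to the
  // axis contribute to the squared distance from the axis.
  const double A = dot(dir, dir) - dOz*dOz;
  const double B = 2.0 * (dot(dir, O) - dOz*Oz);
  const double C = dot(O, O) - Oz*Oz - r*r;

  const double D = B*B - 4.0*A*C;
  if (D < 0.0) return false;                       // ray misses the cylinder

  if (std::abs(A) < 1e-12) {                       // ray (nearly) parallel to the axis
    if (C > 0.0) return false;                     // outside the cylinder: never enters
    t0 = -std::numeric_limits<double>::infinity(); // inside: always inside
    t1 = +std::numeric_limits<double>::infinity();
    return true;
  }

  const double Q = std::sqrt(D);
  t0 = (-B - Q) / (2.0*A);
  t1 = (-B + Q) / (2.0*A);
  return true;
}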
if (unlikely(any(validt))) + { + vbool inside = C <= 0.0f; + t_o.lower = select(validt,select(inside,vfloat(neg_inf),vfloat(pos_inf)),t_o.lower); + t_o.upper = select(validt,select(inside,vfloat(pos_inf),vfloat(neg_inf)),t_o.upper); + valid &= !validt | inside; + } + return valid; + } + + __forceinline vbool intersect(const Vec3fa& org_i, const Vec3fa& dir, BBox>& t_o) const + { + vfloat u0_o; Vec3vf Ng0_o; + vfloat u1_o; Vec3vf Ng1_o; + return intersect(org_i,dir,t_o,u0_o,Ng0_o,u1_o,Ng1_o); + } + }; + } +} + diff --git a/thirdparty/embree/kernels/geometry/disc_intersector.h b/thirdparty/embree/kernels/geometry/disc_intersector.h new file mode 100644 index 000000000000..e8305780e58e --- /dev/null +++ b/thirdparty/embree/kernels/geometry/disc_intersector.h @@ -0,0 +1,216 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "../common/scene_points.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + namespace isa + { + template + struct DiscIntersectorHitM + { + __forceinline DiscIntersectorHitM() {} + + __forceinline DiscIntersectorHitM(const vfloat& u, const vfloat& v, const vfloat& t, const Vec3vf& Ng) + : vu(u), vv(v), vt(t), vNg(Ng) + { + } + + __forceinline void finalize() {} + + __forceinline Vec2f uv(const size_t i) const + { + return Vec2f(vu[i], vv[i]); + } + __forceinline float t(const size_t i) const + { + return vt[i]; + } + __forceinline Vec3fa Ng(const size_t i) const + { + return Vec3fa(vNg.x[i], vNg.y[i], vNg.z[i]); + } + + public: + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + struct DiscIntersector1 + { + typedef CurvePrecalculations1 Precalculations; + + template + static __forceinline bool intersect( + const vbool& valid_i, + Ray& ray, + IntersectContext* context, + const Points* geom, + const Precalculations& pre, + const Vec4vf& v0i, + const Epilog& epilog) + { + vbool valid = valid_i; + + const Vec3vf ray_org(ray.org.x, ray.org.y, ray.org.z); + const Vec3vf ray_dir(ray.dir.x, ray.dir.y, ray.dir.z); + const vfloat rd2 = rcp(dot(ray_dir, ray_dir)); + + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec3vf center = v0.xyz(); + const vfloat radius = v0.w; + + const Vec3vf c0 = center - ray_org; + const vfloat projC0 = dot(c0, ray_dir) * rd2; + + valid &= (vfloat(ray.tnear()) <= projC0) & (projC0 <= vfloat(ray.tfar)); + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) + valid &= projC0 > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR) * radius * pre.depth_scale; // ignore self intersections + if (unlikely(none(valid))) + return false; + + const Vec3vf perp = c0 - projC0 * ray_dir; + const vfloat l2 = dot(perp, perp); + const vfloat r2 = radius * radius; + valid &= (l2 <= r2); + if (unlikely(none(valid))) + return false; + + DiscIntersectorHitM hit(zero, zero, projC0, -ray_dir); + return epilog(valid, hit); + } + + template + static __forceinline bool intersect(const vbool& valid_i, + Ray& ray, + IntersectContext* context, + const Points* geom, + const Precalculations& pre, + const Vec4vf& v0i, + const Vec3vf& normal, + const Epilog& epilog) + { + vbool valid = valid_i; + const Vec3vf ray_org(ray.org.x, ray.org.y, ray.org.z); + + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec3vf center = v0.xyz(); + const vfloat radius = v0.w; + + vfloat divisor = dot(Vec3vf((Vec3fa)ray.dir), normal); + const vbool parallel = divisor == vfloat(0.f); + valid &= !parallel; + 
divisor = select(parallel, 1.f, divisor); // prevent divide by zero + + vfloat t = dot(center - Vec3vf((Vec3fa)ray.org), Vec3vf(normal)) / divisor; + + valid &= (vfloat(ray.tnear()) <= t) & (t <= vfloat(ray.tfar)); + if (unlikely(none(valid))) + return false; + + Vec3vf intersection = Vec3vf((Vec3fa)ray.org) + Vec3vf((Vec3fa)ray.dir) * t; + vfloat dist2 = dot(intersection - center, intersection - center); + valid &= dist2 < radius * radius; + if (unlikely(none(valid))) + return false; + + DiscIntersectorHitM hit(zero, zero, t, normal); + return epilog(valid, hit); + } + }; + + template + struct DiscIntersectorK + { + typedef CurvePrecalculationsK Precalculations; + + template + static __forceinline bool intersect(const vbool& valid_i, + RayK& ray, + size_t k, + IntersectContext* context, + const Points* geom, + const Precalculations& pre, + const Vec4vf& v0i, + const Epilog& epilog) + { + vbool valid = valid_i; + + const Vec3vf ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const Vec3vf ray_dir(ray.dir.x[k], ray.dir.y[k], ray.dir.z[k]); + const vfloat rd2 = rcp(dot(ray_dir, ray_dir)); + + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec3vf center = v0.xyz(); + const vfloat radius = v0.w; + + const Vec3vf c0 = center - ray_org; + const vfloat projC0 = dot(c0, ray_dir) * rd2; + + valid &= (vfloat(ray.tnear()[k]) <= projC0) & (projC0 <= vfloat(ray.tfar[k])); + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) + valid &= projC0 > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR) * radius * pre.depth_scale[k]; // ignore self intersections + if (unlikely(none(valid))) + return false; + + const Vec3vf perp = c0 - projC0 * ray_dir; + const vfloat l2 = dot(perp, perp); + const vfloat r2 = radius * radius; + valid &= (l2 <= r2); + if (unlikely(none(valid))) + return false; + + DiscIntersectorHitM hit(zero, zero, projC0, -ray_dir); + return epilog(valid, hit); + } + + template + static __forceinline bool intersect(const vbool& valid_i, + RayK& ray, + size_t k, + IntersectContext* context, + const Points* geom, + const Precalculations& pre, + const Vec4vf& v0i, + const Vec3vf& normal, + const Epilog& epilog) + { + vbool valid = valid_i; + const Vec3vf ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const Vec3vf ray_dir(ray.dir.x[k], ray.dir.y[k], ray.dir.z[k]); + + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec3vf center = v0.xyz(); + const vfloat radius = v0.w; + + vfloat divisor = dot(Vec3vf(ray_dir), normal); + const vbool parallel = divisor == vfloat(0.f); + valid &= !parallel; + divisor = select(parallel, 1.f, divisor); // prevent divide by zero + + vfloat t = dot(center - Vec3vf(ray_org), Vec3vf(normal)) / divisor; + + valid &= (vfloat(ray.tnear()[k]) <= t) & (t <= vfloat(ray.tfar[k])); + if (unlikely(none(valid))) + return false; + + Vec3vf intersection = Vec3vf(ray_org) + Vec3vf(ray_dir) * t; + vfloat dist2 = dot(intersection - center, intersection - center); + valid &= dist2 < radius * radius; + if (unlikely(none(valid))) + return false; + + DiscIntersectorHitM hit(zero, zero, t, normal); + return epilog(valid, hit); + } + }; + } // namespace isa +} // namespace embree diff --git a/thirdparty/embree/kernels/geometry/disci_intersector.h b/thirdparty/embree/kernels/geometry/disci_intersector.h new file mode 100644 index 000000000000..e1dc3aa98ec9 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/disci_intersector.h @@ -0,0 +1,277 @@ +// Copyright 2009-2020 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "disc_intersector.h" +#include "intersector_epilog.h" +#include "pointi.h" + +namespace embree +{ + namespace isa + { + template + struct DiscMiIntersector1 + { + typedef PointMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, + RayHit& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom); + const vbool valid = Disc.template valid(); + DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, Intersect1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, + Ray& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom); + const vbool valid = Disc.template valid(); + return DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, Occluded1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct DiscMiMBIntersector1 + { + typedef PointMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, + RayHit& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom, ray.time()); + const vbool valid = Disc.template valid(); + DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, Intersect1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, + Ray& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom, ray.time()); + const vbool valid = Disc.template valid(); + return DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, Occluded1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct DiscMiIntersectorK + { + typedef PointMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect( + const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom); + const vbool valid = Disc.template valid(); + DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Intersect1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded( + const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom); + const vbool valid = Disc.template valid(); + return DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Occluded1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct DiscMiMBIntersectorK + { + typedef PointMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect( + const Precalculations& 
pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom, ray.time()[k]); + const vbool valid = Disc.template valid(); + DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Intersect1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded( + const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Disc.gather(v0, geom, ray.time()[k]); + const vbool valid = Disc.template valid(); + return DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, Occluded1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct OrientedDiscMiIntersector1 + { + typedef PointMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, + RayHit& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom); + const vbool valid = Disc.template valid(); + DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, n0, Intersect1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, + Ray& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom); + const vbool valid = Disc.template valid(); + return DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, n0, Occluded1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct OrientedDiscMiMBIntersector1 + { + typedef PointMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, + RayHit& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom, ray.time()); + const vbool valid = Disc.template valid(); + DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, n0, Intersect1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, + Ray& ray, + IntersectContext* context, + const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom, ray.time()); + const vbool valid = Disc.template valid(); + return DiscIntersector1::intersect( + valid, ray, context, geom, pre, v0, n0, Occluded1EpilogM(ray, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct OrientedDiscMiIntersectorK + { + typedef PointMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect( + const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom); + 
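Both disc variants used by the wrappers above reduce to very little geometry: the ray-facing discs (the intersect overloads without a normal) test the squared distance from the point on the ray closest to the disc centre, while the oriented discs intersect the plane through the centre with the stored normal and then apply the same radius check. A scalar restatement of the oriented case, as plain illustrative C++ with a made-up V3 helper rather than Embree code:

struct V3 { double x, y, z; };
static double dot(const V3& a, const V3& b) { return a.x*b.x + a.y*b.y + a.z*b.z; }
static V3 sub(const V3& a, const V3& b) { return {a.x - b.x, a.y - b.y, a.z - b.z}; }
static V3 madd(double t, const V3& a, const V3& b) { return {t*a.x + b.x, t*a.y + b.y, t*a.z + b.z}; }

// Oriented disc: plane through 'center' with unit 'normal', radius 'r'.
// Returns true and the ray parameter t when the hit lies within [tnear, tfar].
bool intersect_oriented_disc(const V3& org, const V3& dir,
                             const V3& center, const V3& normal, double r,
                             double tnear, double tfar, double& t)
{
  const double denom = dot(dir, normal);
  if (denom == 0.0) return false;              // ray parallel to the disc plane

  t = dot(sub(center, org), normal) / denom;   // plane intersection
  if (t < tnear || t > tfar) return false;

  const V3 p = madd(t, dir, org);              // hit point on the plane
  const V3 d = sub(p, center);
  return dot(d, d) <= r*r;                     // accept only inside the disc radius
}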
const vbool valid = Disc.template valid(); + DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, n0, + Intersect1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded( + const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom); + const vbool valid = Disc.template valid(); + return DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, n0, + Occluded1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + }; + + template + struct OrientedDiscMiMBIntersectorK + { + typedef PointMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect( + const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom, ray.time()[k]); + const vbool valid = Disc.template valid(); + DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, n0, + Intersect1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + + static __forceinline bool occluded( + const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& Disc) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(Disc.geomID()); + Vec4vf v0; Vec3vf n0; + Disc.gather(v0, n0, geom, ray.time()[k]); + const vbool valid = Disc.template valid(); + return DiscIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, n0, + Occluded1KEpilogM(ray, k, context, Disc.geomID(), Disc.primID())); + } + }; + } // namespace isa +} // namespace embree diff --git a/thirdparty/embree/kernels/geometry/filter.h b/thirdparty/embree/kernels/geometry/filter.h new file mode 100644 index 000000000000..4cdf7a395a49 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/filter.h @@ -0,0 +1,204 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/geometry.h" +#include "../common/ray.h" +#include "../common/hit.h" +#include "../common/context.h" + +namespace embree +{ + namespace isa + { + __forceinline bool runIntersectionFilter1Helper(RTCFilterFunctionNArguments* args, const Geometry* const geometry, IntersectContext* context) + { + if (geometry->intersectionFilterN) + { + assert(context->scene->hasGeometryFilterFunction()); + geometry->intersectionFilterN(args); + + if (args->valid[0] == 0) + return false; + } + + if (context->user->filter) { + assert(context->scene->hasContextFilterFunction()); + context->user->filter(args); + + if (args->valid[0] == 0) + return false; + } + + copyHitToRay(*(RayHit*)args->ray,*(Hit*)args->hit); + return true; + } + + __forceinline bool runIntersectionFilter1(const Geometry* const geometry, RayHit& ray, IntersectContext* context, Hit& hit) + { + RTCFilterFunctionNArguments args; + int mask = -1; + args.valid = &mask; + args.geometryUserPtr = geometry->userPtr; + args.context = context->user; + args.ray = (RTCRayN*)&ray; + args.hit = (RTCHitN*)&hit; + args.N = 1; + return runIntersectionFilter1Helper(&args,geometry,context); + } + + __forceinline void reportIntersection1(IntersectFunctionNArguments* args, const RTCFilterFunctionNArguments* filter_args) + { +#if 
defined(EMBREE_FILTER_FUNCTION) + IntersectContext* MAYBE_UNUSED context = args->internal_context; + const Geometry* const geometry = args->geometry; + if (geometry->intersectionFilterN) { + assert(context->scene->hasGeometryFilterFunction()); + geometry->intersectionFilterN(filter_args); + } + + //if (args->valid[0] == 0) + // return; + + if (context->user->filter) { + assert(context->scene->hasContextFilterFunction()); + context->user->filter(filter_args); + } +#endif + } + + __forceinline bool runOcclusionFilter1Helper(RTCFilterFunctionNArguments* args, const Geometry* const geometry, IntersectContext* context) + { + if (geometry->occlusionFilterN) + { + assert(context->scene->hasGeometryFilterFunction()); + geometry->occlusionFilterN(args); + + if (args->valid[0] == 0) + return false; + } + + if (context->user->filter) { + assert(context->scene->hasContextFilterFunction()); + context->user->filter(args); + + if (args->valid[0] == 0) + return false; + } + return true; + } + + __forceinline bool runOcclusionFilter1(const Geometry* const geometry, Ray& ray, IntersectContext* context, Hit& hit) + { + RTCFilterFunctionNArguments args; + int mask = -1; + args.valid = &mask; + args.geometryUserPtr = geometry->userPtr; + args.context = context->user; + args.ray = (RTCRayN*)&ray; + args.hit = (RTCHitN*)&hit; + args.N = 1; + return runOcclusionFilter1Helper(&args,geometry,context); + } + + __forceinline void reportOcclusion1(OccludedFunctionNArguments* args, const RTCFilterFunctionNArguments* filter_args) + { +#if defined(EMBREE_FILTER_FUNCTION) + IntersectContext* MAYBE_UNUSED context = args->internal_context; + const Geometry* const geometry = args->geometry; + if (geometry->occlusionFilterN) { + assert(context->scene->hasGeometryFilterFunction()); + geometry->occlusionFilterN(filter_args); + } + + //if (args->valid[0] == 0) + // return false; + + if (context->user->filter) { + assert(context->scene->hasContextFilterFunction()); + context->user->filter(filter_args); + } +#endif + } + + template + __forceinline vbool runIntersectionFilterHelper(RTCFilterFunctionNArguments* args, const Geometry* const geometry, IntersectContext* context) + { + vint* mask = (vint*) args->valid; + if (geometry->intersectionFilterN) + { + assert(context->scene->hasGeometryFilterFunction()); + geometry->intersectionFilterN(args); + } + + vbool valid_o = *mask != vint(zero); + if (none(valid_o)) return valid_o; + + if (context->user->filter) { + assert(context->scene->hasContextFilterFunction()); + context->user->filter(args); + } + + valid_o = *mask != vint(zero); + if (none(valid_o)) return valid_o; + + copyHitToRay(valid_o,*(RayHitK*)args->ray,*(HitK*)args->hit); + return valid_o; + } + + template + __forceinline vbool runIntersectionFilter(const vbool& valid, const Geometry* const geometry, RayHitK& ray, IntersectContext* context, HitK& hit) + { + RTCFilterFunctionNArguments args; + vint mask = valid.mask32(); + args.valid = (int*)&mask; + args.geometryUserPtr = geometry->userPtr; + args.context = context->user; + args.ray = (RTCRayN*)&ray; + args.hit = (RTCHitN*)&hit; + args.N = K; + return runIntersectionFilterHelper(&args,geometry,context); + } + + template + __forceinline vbool runOcclusionFilterHelper(RTCFilterFunctionNArguments* args, const Geometry* const geometry, IntersectContext* context) + { + vint* mask = (vint*) args->valid; + if (geometry->occlusionFilterN) + { + assert(context->scene->hasGeometryFilterFunction()); + geometry->occlusionFilterN(args); + } + + vbool valid_o = *mask != vint(zero); + + 
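The dispatch helpers in filter.h only run callbacks that the application registered through the public Embree API. For orientation, a typical registration on the user side looks roughly like the sketch below; the callback body (rejecting hits closer than a cutoff) is a made-up example and the include path assumes the usual public embree3/rtcore.h header, but rtcSetGeometryIntersectFilterFunction and RTCFilterFunctionNArguments are the real Embree 3 entry points that the code above serves.

#include <embree3/rtcore.h>
#include <cassert>

// Example filter: reject any hit closer than a cutoff distance so traversal keeps
// searching for the next hit. Setting args->valid[i] = 0 rejects hit i; leaving it
// untouched accepts it.
static void cutoff_filter(const RTCFilterFunctionNArguments* args)
{
  assert(args->N == 1);                  // this example only handles single-ray queries
  RTCRay* ray = (RTCRay*)args->ray;      // for N == 1 the packet type aliases RTCRay;
                                         // ray->tfar holds the candidate hit distance
  const float cutoff = 0.01f;            // hypothetical application constant
  if (ray->tfar < cutoff)
    args->valid[0] = 0;                  // reject this hit
}

void register_filter(RTCGeometry geom)
{
  rtcSetGeometryIntersectFilterFunction(geom, cutoff_filter);
  rtcCommitGeometry(geom);               // commit so the new filter takes effect
}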
if (none(valid_o)) return valid_o; + + if (context->user->filter) { + assert(context->scene->hasContextFilterFunction()); + context->user->filter(args); + } + + valid_o = *mask != vint(zero); + + RayK* ray = (RayK*) args->ray; + ray->tfar = select(valid_o, vfloat(neg_inf), ray->tfar); + return valid_o; + } + + template + __forceinline vbool runOcclusionFilter(const vbool& valid, const Geometry* const geometry, RayK& ray, IntersectContext* context, HitK& hit) + { + RTCFilterFunctionNArguments args; + vint mask = valid.mask32(); + args.valid = (int*)&mask; + args.geometryUserPtr = geometry->userPtr; + args.context = context->user; + args.ray = (RTCRayN*)&ray; + args.hit = (RTCHitN*)&hit; + args.N = K; + return runOcclusionFilterHelper(&args,geometry,context); + } + } +} diff --git a/thirdparty/embree/kernels/geometry/grid_intersector.h b/thirdparty/embree/kernels/geometry/grid_intersector.h new file mode 100644 index 000000000000..46a0af08279c --- /dev/null +++ b/thirdparty/embree/kernels/geometry/grid_intersector.h @@ -0,0 +1,99 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "grid_soa.h" +#include "grid_soa_intersector1.h" +#include "grid_soa_intersector_packet.h" +#include "../common/ray.h" + +namespace embree +{ + namespace isa + { + template + class SubdivPatch1Precalculations : public T + { + public: + __forceinline SubdivPatch1Precalculations (const Ray& ray, const void* ptr) + : T(ray,ptr) {} + }; + + template + class SubdivPatch1PrecalculationsK : public T + { + public: + __forceinline SubdivPatch1PrecalculationsK (const vbool& valid, RayK& ray) + : T(valid,ray) {} + }; + + class Grid1Intersector1 + { + public: + typedef GridSOA Primitive; + typedef Grid1Precalculations Precalculations; + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) + { + GridSOAIntersector1::intersect(pre,ray,context,prim,lazy_node); + } + static __forceinline void intersect(Precalculations& pre, RayHit& ray, IntersectContext* context, size_t ty0, const Primitive* prim, size_t ty, size_t& lazy_node) { + intersect(pre,ray,context,prim,ty,lazy_node); + } + + /*! 
Test if the ray is occluded by the primitive */ + static __forceinline bool occluded(Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) + { + GridSOAIntersector1::occluded(pre,ray,context,prim,lazy_node); + } + static __forceinline bool occluded(Precalculations& pre, Ray& ray, IntersectContext* context, size_t ty0, const Primitive* prim, size_t ty, size_t& lazy_node) { + return occluded(pre,ray,context,prim,ty,lazy_node); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) { + assert(false && "not implemented"); + return false; + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, size_t ty0, const Primitive* prim, size_t ty, size_t& lazy_node) { + assert(false && "not implemented"); + return false; + } + }; + + template + struct GridIntersectorK + { + typedef GridSOA Primitive; + typedef SubdivPatch1PrecalculationsK::Precalculations> Precalculations; + + + static __forceinline void intersect(const vbool& valid, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) + { + GridSOAIntersectorK::intersect(valid,pre,ray,context,prim,lazy_node); + } + + static __forceinline vbool occluded(const vbool& valid, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) + { + GridSOAIntersectorK::occluded(valid,pre,ray,context,prim,lazy_node); + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) + { + GridSOAIntersectorK::intersect(pre,ray,k,context,prim,lazy_node); + } + + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t ty, size_t& lazy_node) + { + GridSOAIntersectorK::occluded(pre,ray,k,context,prim,lazy_node); + } + }; + + typedef Grid1IntersectorK<4> SubdivPatch1Intersector4; + typedef Grid1IntersectorK<8> SubdivPatch1Intersector8; + typedef Grid1IntersectorK<16> SubdivPatch1Intersector16; + + } +} diff --git a/thirdparty/embree/kernels/geometry/grid_soa.h b/thirdparty/embree/kernels/geometry/grid_soa.h new file mode 100644 index 000000000000..02edbbed5e1d --- /dev/null +++ b/thirdparty/embree/kernels/geometry/grid_soa.h @@ -0,0 +1,275 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "../common/scene_subdiv_mesh.h" +#include "../bvh/bvh.h" +#include "../subdiv/tessellation.h" +#include "../subdiv/tessellation_cache.h" +#include "subdivpatch1.h" + +namespace embree +{ + namespace isa + { + class GridSOA + { + public: + + /*! GridSOA constructor */ + GridSOA(const SubdivPatch1Base* patches, const unsigned time_steps, + const unsigned x0, const unsigned x1, const unsigned y0, const unsigned y1, const unsigned swidth, const unsigned sheight, + const SubdivMesh* const geom, const size_t totalBvhBytes, const size_t gridBytes, BBox3fa* bounds_o = nullptr); + + /*! 
Subgrid creation */ + template + static GridSOA* create(const SubdivPatch1Base* patches, const unsigned time_steps, + unsigned x0, unsigned x1, unsigned y0, unsigned y1, + const Scene* scene, Allocator& alloc, BBox3fa* bounds_o = nullptr) + { + const unsigned width = x1-x0+1; + const unsigned height = y1-y0+1; + const GridRange range(0,width-1,0,height-1); + size_t bvhBytes = 0; + if (time_steps == 1) + bvhBytes = getBVHBytes(range,sizeof(BVH4::AABBNode),0); + else { + bvhBytes = (time_steps-1)*getBVHBytes(range,sizeof(BVH4::AABBNodeMB),0); + bvhBytes += getTemporalBVHBytes(make_range(0,int(time_steps-1)),sizeof(BVH4::AABBNodeMB4D)); + } + const size_t gridBytes = 4*size_t(width)*size_t(height)*sizeof(float); + size_t rootBytes = time_steps*sizeof(BVH4::NodeRef); +#if !defined(__X86_64__) + rootBytes += 4; // We read 2 elements behind the grid. As we store at least 8 root bytes after the grid we are fine in 64 bit mode. But in 32 bit mode we have to do additional padding. +#endif + void* data = alloc(offsetof(GridSOA,data)+bvhBytes+time_steps*gridBytes+rootBytes); + assert(data); + return new (data) GridSOA(patches,time_steps,x0,x1,y0,y1,patches->grid_u_res,patches->grid_v_res,scene->get(patches->geomID()),bvhBytes,gridBytes,bounds_o); + } + + /*! Grid creation */ + template + static GridSOA* create(const SubdivPatch1Base* const patches, const unsigned time_steps, + const Scene* scene, const Allocator& alloc, BBox3fa* bounds_o = nullptr) + { + return create(patches,time_steps,0,patches->grid_u_res-1,0,patches->grid_v_res-1,scene,alloc,bounds_o); + } + + /*! returns reference to root */ + __forceinline BVH4::NodeRef& root(size_t t = 0) { return (BVH4::NodeRef&)data[rootOffset + t*sizeof(BVH4::NodeRef)]; } + __forceinline const BVH4::NodeRef& root(size_t t = 0) const { return (BVH4::NodeRef&)data[rootOffset + t*sizeof(BVH4::NodeRef)]; } + + /*! returns pointer to BVH array */ + __forceinline char* bvhData() { return &data[0]; } + __forceinline const char* bvhData() const { return &data[0]; } + + /*! returns pointer to Grid array */ + __forceinline float* gridData(size_t t = 0) { return (float*) &data[gridOffset + t*gridBytes]; } + __forceinline const float* gridData(size_t t = 0) const { return (float*) &data[gridOffset + t*gridBytes]; } + + __forceinline void* encodeLeaf(size_t u, size_t v) { + return (void*) (16*(v * width + u + 1)); // +1 to not create empty leaf + } + __forceinline float* decodeLeaf(size_t t, const void* ptr) { + return gridData(t) + (((size_t) (ptr) >> 4) - 1); + } + + /*! returns the size of the BVH over the grid in bytes */ + static size_t getBVHBytes(const GridRange& range, const size_t nodeBytes, const size_t leafBytes); + + /*! returns the size of the temporal BVH over the time range BVHs */ + static size_t getTemporalBVHBytes(const range time_range, const size_t nodeBytes); + + /*! calculates bounding box of grid range */ + __forceinline BBox3fa calculateBounds(size_t time, const GridRange& range) const + { + const float* const grid_array = gridData(time); + const float* const grid_x_array = grid_array + 0 * dim_offset; + const float* const grid_y_array = grid_array + 1 * dim_offset; + const float* const grid_z_array = grid_array + 2 * dim_offset; + + /* compute the bounds just for the range! 
*/ + BBox3fa bounds( empty ); + for (unsigned v = range.v_start; v<=range.v_end; v++) + { + for (unsigned u = range.u_start; u<=range.u_end; u++) + { + const float x = grid_x_array[ v * width + u]; + const float y = grid_y_array[ v * width + u]; + const float z = grid_z_array[ v * width + u]; + bounds.extend( Vec3fa(x,y,z) ); + } + } + assert(is_finite(bounds)); + return bounds; + } + + /*! Evaluates grid over patch and builds BVH4 tree over the grid. */ + std::pair buildBVH(BBox3fa* bounds_o); + + /*! Create BVH4 tree over grid. */ + std::pair buildBVH(const GridRange& range, size_t& allocator); + + /*! Evaluates grid over patch and builds MSMBlur BVH4 tree over the grid. */ + std::pair buildMSMBlurBVH(const range time_range, BBox3fa* bounds_o); + + /*! Create MBlur BVH4 tree over grid. */ + std::pair buildMBlurBVH(size_t time, const GridRange& range, size_t& allocator); + + /*! Create MSMBlur BVH4 tree over grid. */ + std::pair buildMSMBlurBVH(const range time_range, size_t& allocator, BBox3fa* bounds_o); + + template + struct MapUV + { + typedef typename Loader::vfloat vfloat; + const float* const grid_uv; + size_t line_offset; + size_t lines; + + __forceinline MapUV(const float* const grid_uv, size_t line_offset, const size_t lines) + : grid_uv(grid_uv), line_offset(line_offset), lines(lines) {} + + __forceinline void operator() (vfloat& u, vfloat& v) const { + const Vec3 tri_v012_uv = Loader::gather(grid_uv,line_offset,lines); + const Vec2 uv0 = GridSOA::decodeUV(tri_v012_uv[0]); + const Vec2 uv1 = GridSOA::decodeUV(tri_v012_uv[1]); + const Vec2 uv2 = GridSOA::decodeUV(tri_v012_uv[2]); + const Vec2 uv = u * uv1 + v * uv2 + (1.0f-u-v) * uv0; + u = uv[0];v = uv[1]; + } + }; + + struct Gather2x3 + { + enum { M = 4 }; + typedef vbool4 vbool; + typedef vint4 vint; + typedef vfloat4 vfloat; + + static __forceinline const Vec3vf4 gather(const float* const grid, const size_t line_offset, const size_t lines) + { + vfloat4 r0 = vfloat4::loadu(grid + 0*line_offset); + vfloat4 r1 = vfloat4::loadu(grid + 1*line_offset); // this accesses 2 elements too much in case of 2x2 grid, but this is ok as we ensure enough padding after the grid + if (unlikely(line_offset == 2)) + { + r0 = shuffle<0,1,1,1>(r0); + r1 = shuffle<0,1,1,1>(r1); + } + return Vec3vf4(unpacklo(r0,r1), // r00, r10, r01, r11 + shuffle<1,1,2,2>(r0), // r01, r01, r02, r02 + shuffle<0,1,1,2>(r1)); // r10, r11, r11, r12 + } + + static __forceinline void gather(const float* const grid_x, + const float* const grid_y, + const float* const grid_z, + const size_t line_offset, + const size_t lines, + Vec3vf4& v0_o, + Vec3vf4& v1_o, + Vec3vf4& v2_o) + { + const Vec3vf4 tri_v012_x = gather(grid_x,line_offset,lines); + const Vec3vf4 tri_v012_y = gather(grid_y,line_offset,lines); + const Vec3vf4 tri_v012_z = gather(grid_z,line_offset,lines); + v0_o = Vec3vf4(tri_v012_x[0],tri_v012_y[0],tri_v012_z[0]); + v1_o = Vec3vf4(tri_v012_x[1],tri_v012_y[1],tri_v012_z[1]); + v2_o = Vec3vf4(tri_v012_x[2],tri_v012_y[2],tri_v012_z[2]); + } + }; + +#if defined (__AVX__) + struct Gather3x3 + { + enum { M = 8 }; + typedef vbool8 vbool; + typedef vint8 vint; + typedef vfloat8 vfloat; + + static __forceinline const Vec3vf8 gather(const float* const grid, const size_t line_offset, const size_t lines) + { + vfloat4 ra = vfloat4::loadu(grid + 0*line_offset); + vfloat4 rb = vfloat4::loadu(grid + 1*line_offset); // this accesses 2 elements too much in case of 2x2 grid, but this is ok as we ensure enough padding after the grid + vfloat4 rc; + if (likely(lines > 2)) + rc = 
vfloat4::loadu(grid + 2*line_offset); + else + rc = rb; + + if (unlikely(line_offset == 2)) + { + ra = shuffle<0,1,1,1>(ra); + rb = shuffle<0,1,1,1>(rb); + rc = shuffle<0,1,1,1>(rc); + } + + const vfloat8 r0 = vfloat8(ra,rb); + const vfloat8 r1 = vfloat8(rb,rc); + return Vec3vf8(unpacklo(r0,r1), // r00, r10, r01, r11, r10, r20, r11, r21 + shuffle<1,1,2,2>(r0), // r01, r01, r02, r02, r11, r11, r12, r12 + shuffle<0,1,1,2>(r1)); // r10, r11, r11, r12, r20, r21, r21, r22 + } + + static __forceinline void gather(const float* const grid_x, + const float* const grid_y, + const float* const grid_z, + const size_t line_offset, + const size_t lines, + Vec3vf8& v0_o, + Vec3vf8& v1_o, + Vec3vf8& v2_o) + { + const Vec3vf8 tri_v012_x = gather(grid_x,line_offset,lines); + const Vec3vf8 tri_v012_y = gather(grid_y,line_offset,lines); + const Vec3vf8 tri_v012_z = gather(grid_z,line_offset,lines); + v0_o = Vec3vf8(tri_v012_x[0],tri_v012_y[0],tri_v012_z[0]); + v1_o = Vec3vf8(tri_v012_x[1],tri_v012_y[1],tri_v012_z[1]); + v2_o = Vec3vf8(tri_v012_x[2],tri_v012_y[2],tri_v012_z[2]); + } + }; +#endif + + template + static __forceinline Vec2 decodeUV(const vfloat& uv) + { + typedef typename vfloat::Int vint; + const vint iu = asInt(uv) & 0xffff; + const vint iv = srl(asInt(uv),16); + const vfloat u = (vfloat)iu * vfloat(8.0f/0x10000); + const vfloat v = (vfloat)iv * vfloat(8.0f/0x10000); + return Vec2(u,v); + } + + __forceinline unsigned int geomID() const { + return _geomID; + } + + __forceinline unsigned int primID() const { + return _primID; + } + + public: + BVH4::NodeRef troot; +#if !defined(__X86_64__) + unsigned align1; +#endif + unsigned time_steps; + unsigned width; + + unsigned height; + unsigned dim_offset; + unsigned _geomID; + unsigned _primID; + + unsigned align2; + unsigned gridOffset; + unsigned gridBytes; + unsigned rootOffset; + + char data[1]; //!< after the struct we first store the BVH, then the grid, and finally the roots + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/grid_soa_intersector1.h b/thirdparty/embree/kernels/geometry/grid_soa_intersector1.h new file mode 100644 index 000000000000..2ed922a5aeb3 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/grid_soa_intersector1.h @@ -0,0 +1,207 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "grid_soa.h" +#include "../common/ray.h" +#include "triangle_intersector_pluecker.h" + +namespace embree +{ + namespace isa + { + class GridSOAIntersector1 + { + public: + typedef void Primitive; + + class Precalculations + { + public: + __forceinline Precalculations (const Ray& ray, const void* ptr) + : grid(nullptr) {} + + public: + GridSOA* grid; + int itime; + float ftime; + }; + + template + static __forceinline void intersect(RayHit& ray, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t dim_offset = pre.grid->dim_offset; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + Vec3 v0, v1, v2; + Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,v0,v1,v2); + GridSOA::MapUV mapUV(grid_uv,line_offset,lines); + PlueckerIntersector1 intersector(ray,nullptr); + intersector.intersect(ray,v0,v1,v2,mapUV,Intersect1EpilogMU(ray,context,pre.grid->geomID(),pre.grid->primID())); + }; + + template + static __forceinline bool 
occluded(Ray& ray, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t dim_offset = pre.grid->dim_offset; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + Vec3 v0, v1, v2; + Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,v0,v1,v2); + + GridSOA::MapUV mapUV(grid_uv,line_offset,lines); + PlueckerIntersector1 intersector(ray,nullptr); + return intersector.intersect(ray,v0,v1,v2,mapUV,Occluded1EpilogMU(ray,context,pre.grid->geomID(),pre.grid->primID())); + } + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(0,prim); + +#if defined(__AVX__) + intersect( ray, context, grid_x, line_offset, lines, pre); +#else + intersect(ray, context, grid_x , line_offset, lines, pre); + if (likely(lines > 2)) + intersect(ray, context, grid_x+line_offset, line_offset, lines, pre); +#endif + } + + /*! Test if the ray is occluded by the primitive */ + static __forceinline bool occluded(Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(0,prim); + +#if defined(__AVX__) + return occluded( ray, context, grid_x, line_offset, lines, pre); +#else + if (occluded(ray, context, grid_x , line_offset, lines, pre)) return true; + if (likely(lines > 2)) + if (occluded(ray, context, grid_x+line_offset, line_offset, lines, pre)) return true; +#endif + return false; + } + }; + + class GridSOAMBIntersector1 + { + public: + typedef void Primitive; + typedef GridSOAIntersector1::Precalculations Precalculations; + + template + static __forceinline void intersect(RayHit& ray, const float ftime, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t dim_offset = pre.grid->dim_offset; + const size_t grid_offset = pre.grid->gridBytes >> 2; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + Vec3 a0, a1, a2; + Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,a0,a1,a2); + + Vec3 b0, b1, b2; + Loader::gather(grid_x+grid_offset,grid_y+grid_offset,grid_z+grid_offset,line_offset,lines,b0,b1,b2); + + Vec3 v0 = lerp(a0,b0,vfloat(ftime)); + Vec3 v1 = lerp(a1,b1,vfloat(ftime)); + Vec3 v2 = lerp(a2,b2,vfloat(ftime)); + + GridSOA::MapUV mapUV(grid_uv,line_offset,lines); + PlueckerIntersector1 intersector(ray,nullptr); + intersector.intersect(ray,v0,v1,v2,mapUV,Intersect1EpilogMU(ray,context,pre.grid->geomID(),pre.grid->primID())); + }; + + template + static __forceinline bool occluded(Ray& ray, const float ftime, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t dim_offset = pre.grid->dim_offset; + const size_t grid_offset = 
pre.grid->gridBytes >> 2; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + Vec3 a0, a1, a2; + Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,a0,a1,a2); + + Vec3 b0, b1, b2; + Loader::gather(grid_x+grid_offset,grid_y+grid_offset,grid_z+grid_offset,line_offset,lines,b0,b1,b2); + + Vec3 v0 = lerp(a0,b0,vfloat(ftime)); + Vec3 v1 = lerp(a1,b1,vfloat(ftime)); + Vec3 v2 = lerp(a2,b2,vfloat(ftime)); + + GridSOA::MapUV mapUV(grid_uv,line_offset,lines); + PlueckerIntersector1 intersector(ray,nullptr); + return intersector.intersect(ray,v0,v1,v2,mapUV,Occluded1EpilogMU(ray,context,pre.grid->geomID(),pre.grid->primID())); + } + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(pre.itime,prim); + +#if defined(__AVX__) + intersect( ray, pre.ftime, context, grid_x, line_offset, lines, pre); +#else + intersect(ray, pre.ftime, context, grid_x, line_offset, lines, pre); + if (likely(lines > 2)) + intersect(ray, pre.ftime, context, grid_x+line_offset, line_offset, lines, pre); +#endif + } + + /*! Test if the ray is occluded by the primitive */ + static __forceinline bool occluded(Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(pre.itime,prim); + +#if defined(__AVX__) + return occluded( ray, pre.ftime, context, grid_x, line_offset, lines, pre); +#else + if (occluded(ray, pre.ftime, context, grid_x , line_offset, lines, pre)) return true; + if (likely(lines > 2)) + if (occluded(ray, pre.ftime, context, grid_x+line_offset, line_offset, lines, pre)) return true; +#endif + return false; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/grid_soa_intersector_packet.h b/thirdparty/embree/kernels/geometry/grid_soa_intersector_packet.h new file mode 100644 index 000000000000..41d66e1e285b --- /dev/null +++ b/thirdparty/embree/kernels/geometry/grid_soa_intersector_packet.h @@ -0,0 +1,445 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "grid_soa.h" +#include "../common/ray.h" +#include "triangle_intersector_pluecker.h" + +namespace embree +{ + namespace isa + { + template + struct MapUV0 + { + const float* const grid_uv; + size_t ofs00, ofs01, ofs10, ofs11; + + __forceinline MapUV0(const float* const grid_uv, size_t ofs00, size_t ofs01, size_t ofs10, size_t ofs11) + : grid_uv(grid_uv), ofs00(ofs00), ofs01(ofs01), ofs10(ofs10), ofs11(ofs11) {} + + __forceinline void operator() (vfloat& u, vfloat& v) const { + const vfloat uv00(grid_uv[ofs00]); + const vfloat uv01(grid_uv[ofs01]); + const vfloat uv10(grid_uv[ofs10]); + const vfloat uv11(grid_uv[ofs11]); + const Vec2vf uv0 = GridSOA::decodeUV(uv00); + const Vec2vf uv1 = GridSOA::decodeUV(uv01); + const Vec2vf uv2 = GridSOA::decodeUV(uv10); + const Vec2vf uv = madd(u,uv1,madd(v,uv2,(1.0f-u-v)*uv0)); + u = uv[0]; v = uv[1]; + } + }; + + template + struct MapUV1 + { + const float* const grid_uv; + size_t ofs00, ofs01, ofs10, ofs11; + + __forceinline MapUV1(const float* const grid_uv, size_t ofs00, size_t 
ofs01, size_t ofs10, size_t ofs11) + : grid_uv(grid_uv), ofs00(ofs00), ofs01(ofs01), ofs10(ofs10), ofs11(ofs11) {} + + __forceinline void operator() (vfloat& u, vfloat& v) const { + const vfloat uv00(grid_uv[ofs00]); + const vfloat uv01(grid_uv[ofs01]); + const vfloat uv10(grid_uv[ofs10]); + const vfloat uv11(grid_uv[ofs11]); + const Vec2vf uv0 = GridSOA::decodeUV(uv10); + const Vec2vf uv1 = GridSOA::decodeUV(uv01); + const Vec2vf uv2 = GridSOA::decodeUV(uv11); + const Vec2vf uv = madd(u,uv1,madd(v,uv2,(1.0f-u-v)*uv0)); + u = uv[0]; v = uv[1]; + } + }; + + template + class GridSOAIntersectorK + { + public: + typedef void Primitive; + + class Precalculations + { +#if defined(__AVX__) + static const int M = 8; +#else + static const int M = 4; +#endif + + public: + __forceinline Precalculations (const vbool& valid, const RayK& ray) + : grid(nullptr), intersector(valid,ray) {} + + public: + GridSOA* grid; + PlueckerIntersectorK intersector; // FIXME: use quad intersector + }; + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t dim_offset = pre.grid->dim_offset; + const size_t line_offset = pre.grid->width; + const float* const grid_x = pre.grid->decodeLeaf(0,prim); + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + const size_t max_x = pre.grid->width == 2 ? 1 : 2; + const size_t max_y = pre.grid->height == 2 ? 1 : 2; + for (size_t y=0; y p00(grid_x[ofs00],grid_y[ofs00],grid_z[ofs00]); + const Vec3vf p01(grid_x[ofs01],grid_y[ofs01],grid_z[ofs01]); + const Vec3vf p10(grid_x[ofs10],grid_y[ofs10],grid_z[ofs10]); + const Vec3vf p11(grid_x[ofs11],grid_y[ofs11],grid_z[ofs11]); + + pre.intersector.intersectK(valid_i,ray,p00,p01,p10,MapUV0(grid_uv,ofs00,ofs01,ofs10,ofs11),IntersectKEpilogMU<1,K,true>(ray,context,pre.grid->geomID(),pre.grid->primID())); + pre.intersector.intersectK(valid_i,ray,p10,p01,p11,MapUV1(grid_uv,ofs00,ofs01,ofs10,ofs11),IntersectKEpilogMU<1,K,true>(ray,context,pre.grid->geomID(),pre.grid->primID())); + } + } + } + + /*! Test if the ray is occluded by the primitive */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t dim_offset = pre.grid->dim_offset; + const size_t line_offset = pre.grid->width; + const float* const grid_x = pre.grid->decodeLeaf(0,prim); + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + vbool valid = valid_i; + const size_t max_x = pre.grid->width == 2 ? 1 : 2; + const size_t max_y = pre.grid->height == 2 ? 
1 : 2; + for (size_t y=0; y p00(grid_x[ofs00],grid_y[ofs00],grid_z[ofs00]); + const Vec3vf p01(grid_x[ofs01],grid_y[ofs01],grid_z[ofs01]); + const Vec3vf p10(grid_x[ofs10],grid_y[ofs10],grid_z[ofs10]); + const Vec3vf p11(grid_x[ofs11],grid_y[ofs11],grid_z[ofs11]); + + pre.intersector.intersectK(valid,ray,p00,p01,p10,MapUV0(grid_uv,ofs00,ofs01,ofs10,ofs11),OccludedKEpilogMU<1,K,true>(valid,ray,context,pre.grid->geomID(),pre.grid->primID())); + if (none(valid)) break; + pre.intersector.intersectK(valid,ray,p10,p01,p11,MapUV1(grid_uv,ofs00,ofs01,ofs10,ofs11),OccludedKEpilogMU<1,K,true>(valid,ray,context,pre.grid->geomID(),pre.grid->primID())); + if (none(valid)) break; + } + } + return !valid; + } + + template + static __forceinline void intersect(RayHitK& ray, size_t k, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t dim_offset = pre.grid->dim_offset; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + Vec3 v0, v1, v2; Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,v0,v1,v2); + pre.intersector.intersect(ray,k,v0,v1,v2,GridSOA::MapUV(grid_uv,line_offset,lines),Intersect1KEpilogMU(ray,k,context,pre.grid->geomID(),pre.grid->primID())); + }; + + template + static __forceinline bool occluded(RayK& ray, size_t k, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t dim_offset = pre.grid->dim_offset; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + Vec3 v0, v1, v2; Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,v0,v1,v2); + return pre.intersector.intersect(ray,k,v0,v1,v2,GridSOA::MapUV(grid_uv,line_offset,lines),Occluded1KEpilogMU(ray,k,context,pre.grid->geomID(),pre.grid->primID())); + } + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(0,prim); +#if defined(__AVX__) + intersect( ray, k, context, grid_x, line_offset, lines, pre); +#else + intersect(ray, k, context, grid_x , line_offset, lines, pre); + if (likely(lines > 2)) + intersect(ray, k, context, grid_x+line_offset, line_offset, lines, pre); +#endif + } + + /*! 
Test if the ray is occluded by the primitive */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(0,prim); + +#if defined(__AVX__) + return occluded( ray, k, context, grid_x, line_offset, lines, pre); +#else + if (occluded(ray, k, context, grid_x , line_offset, lines, pre)) return true; + if (likely(lines > 2)) + if (occluded(ray, k, context, grid_x+line_offset, line_offset, lines, pre)) return true; +#endif + return false; + } + }; + + template + class GridSOAMBIntersectorK + { + public: + typedef void Primitive; + typedef typename GridSOAIntersectorK::Precalculations Precalculations; + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + vfloat vftime; + vint vitime = getTimeSegment(ray.time(), vfloat((float)(pre.grid->time_steps-1)), vftime); + + vbool valid1 = valid_i; + while (any(valid1)) { + const size_t j = bsf(movemask(valid1)); + const int itime = vitime[j]; + const vbool valid2 = valid1 & (itime == vitime); + valid1 = valid1 & !valid2; + intersect(valid2,pre,ray,vftime,itime,context,prim,lazy_node); + } + } + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, const vfloat& ftime, int itime, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t grid_offset = pre.grid->gridBytes >> 2; + const size_t dim_offset = pre.grid->dim_offset; + const size_t line_offset = pre.grid->width; + const float* const grid_x = pre.grid->decodeLeaf(itime,prim); + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + const size_t max_x = pre.grid->width == 2 ? 1 : 2; + const size_t max_y = pre.grid->height == 2 ? 1 : 2; + for (size_t y=0; y a00(grid_x[ofs00],grid_y[ofs00],grid_z[ofs00]); + const Vec3vf a01(grid_x[ofs01],grid_y[ofs01],grid_z[ofs01]); + const Vec3vf a10(grid_x[ofs10],grid_y[ofs10],grid_z[ofs10]); + const Vec3vf a11(grid_x[ofs11],grid_y[ofs11],grid_z[ofs11]); + ofs00 += grid_offset; + ofs01 += grid_offset; + ofs10 += grid_offset; + ofs11 += grid_offset; + const Vec3vf b00(grid_x[ofs00],grid_y[ofs00],grid_z[ofs00]); + const Vec3vf b01(grid_x[ofs01],grid_y[ofs01],grid_z[ofs01]); + const Vec3vf b10(grid_x[ofs10],grid_y[ofs10],grid_z[ofs10]); + const Vec3vf b11(grid_x[ofs11],grid_y[ofs11],grid_z[ofs11]); + const Vec3vf p00 = lerp(a00,b00,ftime); + const Vec3vf p01 = lerp(a01,b01,ftime); + const Vec3vf p10 = lerp(a10,b10,ftime); + const Vec3vf p11 = lerp(a11,b11,ftime); + + pre.intersector.intersectK(valid_i,ray,p00,p01,p10,MapUV0(grid_uv,ofs00,ofs01,ofs10,ofs11),IntersectKEpilogMU<1,K,true>(ray,context,pre.grid->geomID(),pre.grid->primID())); + pre.intersector.intersectK(valid_i,ray,p10,p01,p11,MapUV1(grid_uv,ofs00,ofs01,ofs10,ofs11),IntersectKEpilogMU<1,K,true>(ray,context,pre.grid->geomID(),pre.grid->primID())); + } + } + } + + /*! 
Test if the ray is occluded by the primitive */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + vfloat vftime; + vint vitime = getTimeSegment(ray.time(), vfloat((float)(pre.grid->time_steps-1)), vftime); + + vbool valid_o = valid_i; + vbool valid1 = valid_i; + while (any(valid1)) { + const int j = int(bsf(movemask(valid1))); + const int itime = vitime[j]; + const vbool valid2 = valid1 & (itime == vitime); + valid1 = valid1 & !valid2; + valid_o &= !valid2 | occluded(valid2,pre,ray,vftime,itime,context,prim,lazy_node); + } + return !valid_o; + } + + /*! Test if the ray is occluded by the primitive */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, const vfloat& ftime, int itime, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + const size_t grid_offset = pre.grid->gridBytes >> 2; + const size_t dim_offset = pre.grid->dim_offset; + const size_t line_offset = pre.grid->width; + const float* const grid_x = pre.grid->decodeLeaf(itime,prim); + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + vbool valid = valid_i; + const size_t max_x = pre.grid->width == 2 ? 1 : 2; + const size_t max_y = pre.grid->height == 2 ? 1 : 2; + for (size_t y=0; y a00(grid_x[ofs00],grid_y[ofs00],grid_z[ofs00]); + const Vec3vf a01(grid_x[ofs01],grid_y[ofs01],grid_z[ofs01]); + const Vec3vf a10(grid_x[ofs10],grid_y[ofs10],grid_z[ofs10]); + const Vec3vf a11(grid_x[ofs11],grid_y[ofs11],grid_z[ofs11]); + ofs00 += grid_offset; + ofs01 += grid_offset; + ofs10 += grid_offset; + ofs11 += grid_offset; + const Vec3vf b00(grid_x[ofs00],grid_y[ofs00],grid_z[ofs00]); + const Vec3vf b01(grid_x[ofs01],grid_y[ofs01],grid_z[ofs01]); + const Vec3vf b10(grid_x[ofs10],grid_y[ofs10],grid_z[ofs10]); + const Vec3vf b11(grid_x[ofs11],grid_y[ofs11],grid_z[ofs11]); + const Vec3vf p00 = lerp(a00,b00,ftime); + const Vec3vf p01 = lerp(a01,b01,ftime); + const Vec3vf p10 = lerp(a10,b10,ftime); + const Vec3vf p11 = lerp(a11,b11,ftime); + + pre.intersector.intersectK(valid,ray,p00,p01,p10,MapUV0(grid_uv,ofs00,ofs01,ofs10,ofs11),OccludedKEpilogMU<1,K,true>(valid,ray,context,pre.grid->geomID(),pre.grid->primID())); + if (none(valid)) break; + pre.intersector.intersectK(valid,ray,p10,p01,p11,MapUV1(grid_uv,ofs00,ofs01,ofs10,ofs11),OccludedKEpilogMU<1,K,true>(valid,ray,context,pre.grid->geomID(),pre.grid->primID())); + if (none(valid)) break; + } + } + return valid; + } + + template + static __forceinline void intersect(RayHitK& ray, size_t k, + const float ftime, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t grid_offset = pre.grid->gridBytes >> 2; + const size_t dim_offset = pre.grid->dim_offset; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + Vec3 a0, a1, a2; + Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,a0,a1,a2); + + Vec3 b0, b1, b2; + Loader::gather(grid_x+grid_offset,grid_y+grid_offset,grid_z+grid_offset,line_offset,lines,b0,b1,b2); + + Vec3 v0 = lerp(a0,b0,vfloat(ftime)); + Vec3 v1 = lerp(a1,b1,vfloat(ftime)); + Vec3 v2 = lerp(a2,b2,vfloat(ftime)); + + 
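The motion-blur paths above all follow the same recipe: getTimeSegment maps the ray's time in [0,1] to a segment index itime plus a local fraction ftime, and each grid vertex is then linearly interpolated between the two bracketing time steps before the regular triangle test runs. A scalar restatement of that recipe, illustrative only and with made-up types rather than the actual Embree helpers:

#include <algorithm>
#include <vector>

struct V3 { float x, y, z; };

// Map a global time in [0,1] onto one of (time_steps-1) segments: 'itime' is the
// lower bracketing time step, 'ftime' the fraction within that segment.
static int time_segment(float time, int time_steps, float& ftime)
{
  const float f = time * float(time_steps - 1);
  int itime = std::min(int(f), time_steps - 2);
  itime = std::max(itime, 0);
  ftime = f - float(itime);
  return itime;
}

// Interpolate one grid vertex between the two stored time steps that bracket 'time'.
static V3 lerp_vertex(const std::vector<std::vector<V3>>& grid_per_step,
                      int vertex, float time, int time_steps)
{
  float ftime;
  const int itime = time_segment(time, time_steps, ftime);
  const V3& a = grid_per_step[itime][vertex];
  const V3& b = grid_per_step[itime + 1][vertex];
  return { (1.0f - ftime)*a.x + ftime*b.x,
           (1.0f - ftime)*a.y + ftime*b.y,
           (1.0f - ftime)*a.z + ftime*b.z };
}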
pre.intersector.intersect(ray,k,v0,v1,v2,GridSOA::MapUV(grid_uv,line_offset,lines),Intersect1KEpilogMU(ray,k,context,pre.grid->geomID(),pre.grid->primID())); + }; + + template + static __forceinline bool occluded(RayK& ray, size_t k, + const float ftime, + IntersectContext* context, + const float* const grid_x, + const size_t line_offset, + const size_t lines, + Precalculations& pre) + { + typedef typename Loader::vfloat vfloat; + const size_t grid_offset = pre.grid->gridBytes >> 2; + const size_t dim_offset = pre.grid->dim_offset; + const float* const grid_y = grid_x + 1 * dim_offset; + const float* const grid_z = grid_x + 2 * dim_offset; + const float* const grid_uv = grid_x + 3 * dim_offset; + + Vec3 a0, a1, a2; + Loader::gather(grid_x,grid_y,grid_z,line_offset,lines,a0,a1,a2); + + Vec3 b0, b1, b2; + Loader::gather(grid_x+grid_offset,grid_y+grid_offset,grid_z+grid_offset,line_offset,lines,b0,b1,b2); + + Vec3 v0 = lerp(a0,b0,vfloat(ftime)); + Vec3 v1 = lerp(a1,b1,vfloat(ftime)); + Vec3 v2 = lerp(a2,b2,vfloat(ftime)); + + return pre.intersector.intersect(ray,k,v0,v1,v2,GridSOA::MapUV(grid_uv,line_offset,lines),Occluded1KEpilogMU(ray,k,context,pre.grid->geomID(),pre.grid->primID())); + } + + /*! Intersect a ray with the primitive. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + float ftime; + int itime = getTimeSegment(ray.time()[k], float(pre.grid->time_steps-1), ftime); + + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(itime,prim); + +#if defined(__AVX__) + intersect( ray, k, ftime, context, grid_x, line_offset, lines, pre); +#else + intersect(ray, k, ftime, context, grid_x, line_offset, lines, pre); + if (likely(lines > 2)) + intersect(ray, k, ftime, context, grid_x+line_offset, line_offset, lines, pre); +#endif + } + + /*! 
Test if the ray is occluded by the primitive */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + float ftime; + int itime = getTimeSegment(ray.time()[k], float(pre.grid->time_steps-1), ftime); + + const size_t line_offset = pre.grid->width; + const size_t lines = pre.grid->height; + const float* const grid_x = pre.grid->decodeLeaf(itime,prim); + +#if defined(__AVX__) + return occluded( ray, k, ftime, context, grid_x, line_offset, lines, pre); +#else + if (occluded(ray, k, ftime, context, grid_x, line_offset, lines, pre)) return true; + if (likely(lines > 2)) + if (occluded(ray, k, ftime, context, grid_x+line_offset, line_offset, lines, pre)) return true; +#endif + return false; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/instance.h b/thirdparty/embree/kernels/geometry/instance.h new file mode 100644 index 000000000000..66893d581f0b --- /dev/null +++ b/thirdparty/embree/kernels/geometry/instance.h @@ -0,0 +1,78 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "../common/scene_instance.h" + +namespace embree +{ + struct InstancePrimitive + { + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored primitives */ + static __forceinline size_t max_size() { return 1; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return N; } + + public: + + InstancePrimitive (const Instance* instance, unsigned int instID) + : instance(instance) + , instID_(instID) + {} + + __forceinline void fill(const PrimRef* prims, size_t& i, size_t end, Scene* scene) + { + assert(end-i == 1); + const PrimRef& prim = prims[i]; i++; + const unsigned int geomID = prim.geomID(); + const Instance* instance = scene->get(geomID); + new (this) InstancePrimitive(instance, geomID); + } + + __forceinline LBBox3fa fillMB(const PrimRef* prims, size_t& i, size_t end, Scene* scene, size_t itime) + { + assert(end-i == 1); + const PrimRef& prim = prims[i]; i++; + const unsigned int geomID = prim.geomID(); + const Instance* instance = scene->get(geomID); + new (this) InstancePrimitive(instance,geomID); + return instance->linearBounds(0,itime); + } + + __forceinline LBBox3fa fillMB(const PrimRefMB* prims, size_t& i, size_t end, Scene* scene, const BBox1f time_range) + { + assert(end-i == 1); + const PrimRefMB& prim = prims[i]; i++; + const unsigned int geomID = prim.geomID(); + const Instance* instance = scene->get(geomID); + new (this) InstancePrimitive(instance,geomID); + return instance->linearBounds(0,time_range); + } + + /* Updates the primitive */ + __forceinline BBox3fa update(Instance* instance) { + return instance->bounds(0); + } + + public: + const Instance* instance; + const unsigned int instID_ = std::numeric_limits::max (); + }; +} diff --git a/thirdparty/embree/kernels/geometry/instance_intersector.h b/thirdparty/embree/kernels/geometry/instance_intersector.h new file mode 100644 index 000000000000..91731a39c5a8 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/instance_intersector.h @@ -0,0 +1,84 @@ +// Copyright 2009-2020 
Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "instance.h" +#include "../common/ray.h" +#include "../common/point_query.h" + +namespace embree +{ + namespace isa + { + struct InstanceIntersector1 + { + typedef InstancePrimitive Primitive; + + struct Precalculations { + __forceinline Precalculations (const Ray& ray, const void *ptr) {} + }; + + static void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim); + static bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim); + static bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& prim); + }; + + struct InstanceIntersector1MB + { + typedef InstancePrimitive Primitive; + + struct Precalculations { + __forceinline Precalculations (const Ray& ray, const void *ptr) {} + }; + + static void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim); + static bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim); + static bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& prim); + }; + + template + struct InstanceIntersectorK + { + typedef InstancePrimitive Primitive; + + struct Precalculations { + __forceinline Precalculations (const vbool& valid, const RayK& ray) {} + }; + + static void intersect(const vbool& valid_i, const Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& prim); + static vbool occluded(const vbool& valid_i, const Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& prim); + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& prim) { + intersect(vbool(1<& ray, size_t k, IntersectContext* context, const Primitive& prim) { + occluded(vbool(1< + struct InstanceIntersectorKMB + { + typedef InstancePrimitive Primitive; + + struct Precalculations { + __forceinline Precalculations (const vbool& valid, const RayK& ray) {} + }; + + static void intersect(const vbool& valid_i, const Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& prim); + static vbool occluded(const vbool& valid_i, const Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& prim); + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& prim) { + intersect(vbool(1<& ray, size_t k, IntersectContext* context, const Primitive& prim) { + occluded(vbool(1< + struct UVIdentity { + __forceinline void operator() (vfloat& u, vfloat& v) const {} + }; + + + template + struct Intersect1Epilog1 + { + RayHit& ray; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Intersect1Epilog1(RayHit& ray, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (Hit& hit) const + { + /* ray mask test */ + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + if ((geometry->mask & ray.mask) == 0) return false; +#endif + hit.finalize(); + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || 
geometry->hasIntersectionFilter())) { + HitK<1> h(context->user,geomID,primID,hit.u,hit.v,hit.Ng); + const float old_t = ray.tfar; + ray.tfar = hit.t; + bool found = runIntersectionFilter1(geometry,ray,context,h); + if (!found) ray.tfar = old_t; + return found; + } + } +#endif + + /* update hit information */ + ray.tfar = hit.t; + ray.Ng = hit.Ng; + ray.u = hit.u; + ray.v = hit.v; + ray.primID = primID; + ray.geomID = geomID; + instance_id_stack::copy(context->user->instID, ray.instID); + return true; + } + }; + + template + struct Occluded1Epilog1 + { + Ray& ray; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Occluded1Epilog1(Ray& ray, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (Hit& hit) const + { + /* ray mask test */ + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); + + +#if defined(EMBREE_RAY_MASK) + if ((geometry->mask & ray.mask) == 0) return false; +#endif + hit.finalize(); + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) { + HitK<1> h(context->user,geomID,primID,hit.u,hit.v,hit.Ng); + const float old_t = ray.tfar; + ray.tfar = hit.t; + const bool found = runOcclusionFilter1(geometry,ray,context,h); + if (!found) ray.tfar = old_t; + return found; + } + } +#endif + return true; + } + }; + + template + struct Intersect1KEpilog1 + { + RayHitK& ray; + size_t k; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Intersect1KEpilog1(RayHitK& ray, size_t k, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), k(k), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (Hit& hit) const + { + /* ray mask test */ + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + if ((geometry->mask & ray.mask[k]) == 0) + return false; +#endif + hit.finalize(); + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) { + HitK h(context->user,geomID,primID,hit.u,hit.v,hit.Ng); + const float old_t = ray.tfar[k]; + ray.tfar[k] = hit.t; + const bool found = any(runIntersectionFilter(vbool(1<*, const size_t&>(context->user->instID, ray.instID, k); + return true; + } + }; + + template + struct Occluded1KEpilog1 + { + RayK& ray; + size_t k; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Occluded1KEpilog1(RayK& ray, size_t k, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), k(k), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (Hit& hit) const + { + /* ray mask test */ + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + if ((geometry->mask & ray.mask[k]) == 0) + return false; +#endif + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) { + hit.finalize(); + 
HitK h(context->user,geomID,primID,hit.u,hit.v,hit.Ng); + const float old_t = ray.tfar[k]; + ray.tfar[k] = hit.t; + const bool found = any(runOcclusionFilter(vbool(1< + struct Intersect1EpilogM + { + RayHit& ray; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + + __forceinline Intersect1EpilogM(RayHit& ray, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs) + : ray(ray), context(context), geomIDs(geomIDs), primIDs(primIDs) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + vbool valid = valid_i; + if (Mx > M) valid &= (1<get(geomID); + +#if defined(EMBREE_RAY_MASK) + /* goto next hit if mask test fails */ + if ((geometry->mask & ray.mask) == 0) { + clear(valid,i); + continue; + } +#endif + +#if defined(EMBREE_FILTER_FUNCTION) + /* call intersection filter function */ + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) { + const Vec2f uv = hit.uv(i); + HitK<1> h(context->user,geomID,primIDs[i],uv.x,uv.y,hit.Ng(i)); + const float old_t = ray.tfar; + ray.tfar = hit.t(i); + const bool found = runIntersectionFilter1(geometry,ray,context,h); + if (!found) ray.tfar = old_t; + foundhit |= found; + clear(valid,i); + valid &= hit.vt <= ray.tfar; // intersection filters may modify tfar value + continue; + } + } +#endif + break; + } +#endif + + /* update hit information */ + const Vec2f uv = hit.uv(i); + ray.tfar = hit.vt[i]; + ray.Ng.x = hit.vNg.x[i]; + ray.Ng.y = hit.vNg.y[i]; + ray.Ng.z = hit.vNg.z[i]; + ray.u = uv.x; + ray.v = uv.y; + ray.primID = primIDs[i]; + ray.geomID = geomID; + instance_id_stack::copy(context->user->instID, ray.instID); + return true; + + } + }; + +#if 0 && defined(__AVX512F__) // do not enable, this reduced frequency for BVH4 + template + struct Intersect1EpilogM + { + static const size_t Mx = 16; + RayHit& ray; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + + __forceinline Intersect1EpilogM(RayHit& ray, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs) + : ray(ray), context(context), geomIDs(geomIDs), primIDs(primIDs) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* MAYBE_UNUSED scene = context->scene; + vbool valid = valid_i; + if (Mx > M) valid &= (1<get(geomID); + +#if defined(EMBREE_RAY_MASK) + /* goto next hit if mask test fails */ + if ((geometry->mask & ray.mask) == 0) { + clear(valid,i); + continue; + } +#endif + +#if defined(EMBREE_FILTER_FUNCTION) + /* call intersection filter function */ + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) { + const Vec2f uv = hit.uv(i); + HitK<1> h(context->user,geomID,primIDs[i],uv.x,uv.y,hit.Ng(i)); + const float old_t = ray.tfar; + ray.tfar = hit.t(i); + const bool found = runIntersectionFilter1(geometry,ray,context,h); + if (!found) ray.tfar = old_t; + foundhit |= found; + clear(valid,i); + valid &= hit.vt <= ray.tfar; // intersection filters may modify tfar value + continue; + } + } +#endif + break; + } +#endif + + vbool finalMask(((unsigned int)1 << i)); + ray.update(finalMask,hit.vt,hit.vu,hit.vv,hit.vNg.x,hit.vNg.y,hit.vNg.z,geomID,primIDs); + instance_id_stack::foreach([&](unsigned level) + { + ray.instID[level] = context->user->instID[level]; + return (context->user->instID[level] != RTC_INVALID_GEOMETRY_ID); + }); + return true; + + } + }; +#endif + + template + struct 
Occluded1EpilogM + { + Ray& ray; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + + __forceinline Occluded1EpilogM(Ray& ray, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs) + : ray(ray), context(context), geomIDs(geomIDs), primIDs(primIDs) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) || defined(EMBREE_RAY_MASK) + if (unlikely(filter)) + hit.finalize(); /* called only once */ + + vbool valid = valid_i; + if (Mx > M) valid &= (1<get(geomID); + +#if defined(EMBREE_RAY_MASK) + /* goto next hit if mask test fails */ + if ((geometry->mask & ray.mask) == 0) { + m=btc(m,i); + continue; + } +#endif + +#if defined(EMBREE_FILTER_FUNCTION) + /* if we have no filter then the test passed */ + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) + { + const Vec2f uv = hit.uv(i); + HitK<1> h(context->user,geomID,primIDs[i],uv.x,uv.y,hit.Ng(i)); + const float old_t = ray.tfar; + ray.tfar = hit.t(i); + if (runOcclusionFilter1(geometry,ray,context,h)) return true; + ray.tfar = old_t; + m=btc(m,i); + continue; + } + } +#endif + break; + } +#endif + + return true; + } + }; + + template + struct Intersect1EpilogMU + { + RayHit& ray; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Intersect1EpilogMU(RayHit& ray, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + /* ray mask test */ + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + if ((geometry->mask & ray.mask) == 0) return false; +#endif + + vbool valid = valid_i; + hit.finalize(); + + size_t i = select_min(valid,hit.vt); + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) + { + bool foundhit = false; + while (true) + { + /* call intersection filter function */ + Vec2f uv = hit.uv(i); + const float old_t = ray.tfar; + ray.tfar = hit.t(i); + HitK<1> h(context->user,geomID,primID,uv.x,uv.y,hit.Ng(i)); + const bool found = runIntersectionFilter1(geometry,ray,context,h); + if (!found) ray.tfar = old_t; + foundhit |= found; + clear(valid,i); + valid &= hit.vt <= ray.tfar; // intersection filters may modify tfar value + if (unlikely(none(valid))) break; + i = select_min(valid,hit.vt); + } + return foundhit; + } +#endif + + /* update hit information */ + const Vec2f uv = hit.uv(i); + const Vec3fa Ng = hit.Ng(i); + ray.tfar = hit.t(i); + ray.Ng.x = Ng.x; + ray.Ng.y = Ng.y; + ray.Ng.z = Ng.z; + ray.u = uv.x; + ray.v = uv.y; + ray.primID = primID; + ray.geomID = geomID; + instance_id_stack::copy(context->user->instID, ray.instID); + return true; + } + }; + + template + struct Occluded1EpilogMU + { + Ray& ray; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Occluded1EpilogMU(Ray& ray, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (const vbool& valid, Hit& hit) const + { + /* ray mask 
test */ + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + if ((geometry->mask & ray.mask) == 0) return false; +#endif + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) + { + hit.finalize(); + for (size_t m=movemask(valid), i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) + { + const Vec2f uv = hit.uv(i); + const float old_t = ray.tfar; + ray.tfar = hit.t(i); + HitK<1> h(context->user,geomID,primID,uv.x,uv.y,hit.Ng(i)); + if (runOcclusionFilter1(geometry,ray,context,h)) return true; + ray.tfar = old_t; + } + return false; + } +#endif + return true; + } + }; + + template + struct IntersectKEpilogM + { + RayHitK& ray; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + const size_t i; + + __forceinline IntersectKEpilogM(RayHitK& ray, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs, + size_t i) + : ray(ray), context(context), geomIDs(geomIDs), primIDs(primIDs), i(i) {} + + template + __forceinline vbool operator() (const vbool& valid_i, const Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + + vfloat u, v, t; + Vec3vf Ng; + vbool valid = valid_i; + + std::tie(u,v,t,Ng) = hit(); + + const unsigned int geomID = geomIDs[i]; + const unsigned int primID = primIDs[i]; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); + + /* ray masking test */ +#if defined(EMBREE_RAY_MASK) + valid &= (geometry->mask & ray.mask) != 0; + if (unlikely(none(valid))) return false; +#endif + + /* occlusion filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) { + HitK h(context->user,geomID,primID,u,v,Ng); + const vfloat old_t = ray.tfar; + ray.tfar = select(valid,t,ray.tfar); + const vbool m_accept = runIntersectionFilter(valid,geometry,ray,context,h); + ray.tfar = select(m_accept,ray.tfar,old_t); + return m_accept; + } + } +#endif + + /* update hit information */ + vfloat::store(valid,&ray.tfar,t); + vfloat::store(valid,&ray.Ng.x,Ng.x); + vfloat::store(valid,&ray.Ng.y,Ng.y); + vfloat::store(valid,&ray.Ng.z,Ng.z); + vfloat::store(valid,&ray.u,u); + vfloat::store(valid,&ray.v,v); + vuint::store(valid,&ray.primID,primID); + vuint::store(valid,&ray.geomID,geomID); + instance_id_stack::copy*, const vbool&>(context->user->instID, ray.instID, valid); + return valid; + } + }; + + template + struct OccludedKEpilogM + { + vbool& valid0; + RayK& ray; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + const size_t i; + + __forceinline OccludedKEpilogM(vbool& valid0, + RayK& ray, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs, + size_t i) + : valid0(valid0), ray(ray), context(context), geomIDs(geomIDs), primIDs(primIDs), i(i) {} + + template + __forceinline vbool operator() (const vbool& valid_i, const Hit& hit) const + { + vbool valid = valid_i; + + /* ray masking test */ + Scene* scene MAYBE_UNUSED = context->scene; + const unsigned int geomID = geomIDs[i]; + const unsigned int primID = primIDs[i]; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + valid &= (geometry->mask & ray.mask) != 0; + if (unlikely(none(valid))) return valid; +#endif + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || 
geometry->hasOcclusionFilter())) + { + vfloat u, v, t; + Vec3vf Ng; + std::tie(u,v,t,Ng) = hit(); + HitK h(context->user,geomID,primID,u,v,Ng); + const vfloat old_t = ray.tfar; + ray.tfar = select(valid,t,ray.tfar); + valid = runOcclusionFilter(valid,geometry,ray,context,h); + ray.tfar = select(valid,ray.tfar,old_t); + } + } +#endif + + /* update occlusion */ + valid0 = valid0 & !valid; + return valid; + } + }; + + template + struct IntersectKEpilogMU + { + RayHitK& ray; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline IntersectKEpilogMU(RayHitK& ray, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline vbool operator() (const vbool& valid_org, const Hit& hit) const + { + vbool valid = valid_org; + vfloat u, v, t; + Vec3vf Ng; + std::tie(u,v,t,Ng) = hit(); + + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); + + /* ray masking test */ +#if defined(EMBREE_RAY_MASK) + valid &= (geometry->mask & ray.mask) != 0; + if (unlikely(none(valid))) return false; +#endif + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) { + HitK h(context->user,geomID,primID,u,v,Ng); + const vfloat old_t = ray.tfar; + ray.tfar = select(valid,t,ray.tfar); + const vbool m_accept = runIntersectionFilter(valid,geometry,ray,context,h); + ray.tfar = select(m_accept,ray.tfar,old_t); + return m_accept; + } + } +#endif + + /* update hit information */ + vfloat::store(valid,&ray.tfar,t); + vfloat::store(valid,&ray.Ng.x,Ng.x); + vfloat::store(valid,&ray.Ng.y,Ng.y); + vfloat::store(valid,&ray.Ng.z,Ng.z); + vfloat::store(valid,&ray.u,u); + vfloat::store(valid,&ray.v,v); + vuint::store(valid,&ray.primID,primID); + vuint::store(valid,&ray.geomID,geomID); + instance_id_stack::copy*, const vbool&>(context->user->instID, ray.instID, valid); + + return valid; + } + }; + + template + struct OccludedKEpilogMU + { + vbool& valid0; + RayK& ray; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline OccludedKEpilogMU(vbool& valid0, + RayK& ray, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : valid0(valid0), ray(ray), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline vbool operator() (const vbool& valid_i, const Hit& hit) const + { + vbool valid = valid_i; + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); + +#if defined(EMBREE_RAY_MASK) + valid &= (geometry->mask & ray.mask) != 0; + if (unlikely(none(valid))) return false; +#endif + + /* occlusion filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) + { + vfloat u, v, t; + Vec3vf Ng; + std::tie(u,v,t,Ng) = hit(); + HitK h(context->user,geomID,primID,u,v,Ng); + const vfloat old_t = ray.tfar; + ray.tfar = select(valid,t,ray.tfar); + valid = runOcclusionFilter(valid,geometry,ray,context,h); + ray.tfar = select(valid,ray.tfar,old_t); + } + } +#endif + + /* update occlusion */ + valid0 = valid0 & !valid; + return valid; + } + }; + + template + struct Intersect1KEpilogM + { + RayHitK& ray; + size_t k; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + + 
__forceinline Intersect1KEpilogM(RayHitK& ray, size_t k, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs) + : ray(ray), k(k), context(context), geomIDs(geomIDs), primIDs(primIDs) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + vbool valid = valid_i; + hit.finalize(); + if (Mx > M) valid &= (1<get(geomID); + +#if defined(EMBREE_RAY_MASK) + /* goto next hit if mask test fails */ + if ((geometry->mask & ray.mask[k]) == 0) { + clear(valid,i); + continue; + } +#endif + +#if defined(EMBREE_FILTER_FUNCTION) + /* call intersection filter function */ + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasIntersectionFilter())) { + assert(i h(context->user,geomID,primIDs[i],uv.x,uv.y,hit.Ng(i)); + const float old_t = ray.tfar[k]; + ray.tfar[k] = hit.t(i); + const bool found = any(runIntersectionFilter(vbool(1<(hit.vNg.x),vfloat(hit.vNg.y),vfloat(hit.vNg.z),geomID,vuint(primIDs)); +#else + const Vec2f uv = hit.uv(i); + ray.tfar[k] = hit.t(i); + ray.Ng.x[k] = hit.vNg.x[i]; + ray.Ng.y[k] = hit.vNg.y[i]; + ray.Ng.z[k] = hit.vNg.z[i]; + ray.u[k] = uv.x; + ray.v[k] = uv.y; + ray.primID[k] = primIDs[i]; + ray.geomID[k] = geomID; + instance_id_stack::copy*, const size_t&>(context->user->instID, ray.instID, k); +#endif + return true; + } + }; + + template + struct Occluded1KEpilogM + { + RayK& ray; + size_t k; + IntersectContext* context; + const vuint& geomIDs; + const vuint& primIDs; + + __forceinline Occluded1KEpilogM(RayK& ray, size_t k, + IntersectContext* context, + const vuint& geomIDs, + const vuint& primIDs) + : ray(ray), k(k), context(context), geomIDs(geomIDs), primIDs(primIDs) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) || defined(EMBREE_RAY_MASK) + if (unlikely(filter)) + hit.finalize(); /* called only once */ + + vbool valid = valid_i; + if (Mx > M) valid &= (1<get(geomID); + +#if defined(EMBREE_RAY_MASK) + /* goto next hit if mask test fails */ + if ((geometry->mask & ray.mask[k]) == 0) { + m=btc(m,i); + continue; + } +#endif + +#if defined(EMBREE_FILTER_FUNCTION) + /* execute occlusion filer */ + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) + { + const Vec2f uv = hit.uv(i); + const float old_t = ray.tfar[k]; + ray.tfar[k] = hit.t(i); + HitK h(context->user,geomID,primIDs[i],uv.x,uv.y,hit.Ng(i)); + if (any(runOcclusionFilter(vbool(1< + struct Intersect1KEpilogMU + { + RayHitK& ray; + size_t k; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Intersect1KEpilogMU(RayHitK& ray, size_t k, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), k(k), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + /* ray mask test */ + if ((geometry->mask & ray.mask[k]) == 0) + return false; +#endif + + /* finalize hit calculation */ + vbool valid = valid_i; + hit.finalize(); + size_t i = select_min(valid,hit.vt); + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || 
geometry->hasIntersectionFilter())) + { + bool foundhit = false; + while (true) + { + const Vec2f uv = hit.uv(i); + const float old_t = ray.tfar[k]; + ray.tfar[k] = hit.t(i); + HitK h(context->user,geomID,primID,uv.x,uv.y,hit.Ng(i)); + const bool found = any(runIntersectionFilter(vbool(1<(Ng.x),vfloat(Ng.y),vfloat(Ng.z),geomID,vuint(primID)); +#else + const Vec2f uv = hit.uv(i); + const Vec3fa Ng = hit.Ng(i); + ray.tfar[k] = hit.t(i); + ray.Ng.x[k] = Ng.x; + ray.Ng.y[k] = Ng.y; + ray.Ng.z[k] = Ng.z; + ray.u[k] = uv.x; + ray.v[k] = uv.y; + ray.primID[k] = primID; + ray.geomID[k] = geomID; + instance_id_stack::copy*, const size_t&>(context->user->instID, ray.instID, k); +#endif + return true; + } + }; + + template + struct Occluded1KEpilogMU + { + RayK& ray; + size_t k; + IntersectContext* context; + const unsigned int geomID; + const unsigned int primID; + + __forceinline Occluded1KEpilogMU(RayK& ray, size_t k, + IntersectContext* context, + const unsigned int geomID, + const unsigned int primID) + : ray(ray), k(k), context(context), geomID(geomID), primID(primID) {} + + template + __forceinline bool operator() (const vbool& valid_i, Hit& hit) const + { + Scene* scene MAYBE_UNUSED = context->scene; + Geometry* geometry MAYBE_UNUSED = scene->get(geomID); +#if defined(EMBREE_RAY_MASK) + /* ray mask test */ + if ((geometry->mask & ray.mask[k]) == 0) + return false; +#endif + + /* intersection filter test */ +#if defined(EMBREE_FILTER_FUNCTION) + if (filter) { + if (unlikely(context->hasContextFilter() || geometry->hasOcclusionFilter())) + { + hit.finalize(); + for (size_t m=movemask(valid_i), i=bsf(m); m!=0; m=btc(m,i), i=bsf(m)) + { + const Vec2f uv = hit.uv(i); + const float old_t = ray.tfar[k]; + ray.tfar[k] = hit.t(i); + HitK h(context->user,geomID,primID,uv.x,uv.y,hit.Ng(i)); + if (any(runOcclusionFilter(vbool(1< + struct ArrayIntersector1 + { + typedef typename Intersector::Primitive Primitive; + typedef typename Intersector::Precalculations Precalculations; + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + for (size_t i=0; i + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + for (size_t i=0; i + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t num, const TravPointQuery &tquery, size_t& lazy_node) + { + bool changed = false; + for (size_t i=0; i + static __forceinline void intersectK(const vbool& valid, /* PrecalculationsK& pre, */ RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + } + + template + static __forceinline vbool occludedK(const vbool& valid, /* PrecalculationsK& pre, */ RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + return valid; + } + }; + + template + struct ArrayIntersectorK_1 + { + typedef typename Intersector::Primitive Primitive; + typedef typename Intersector::Precalculations Precalculations; + + template + static __forceinline void intersect(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + for 
(size_t i=0; i + static __forceinline vbool occluded(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + vbool valid0 = valid; + for (size_t i=0; i + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + for (size_t i=0; i + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + for (size_t i=0; i + struct ArrayIntersectorKStream + { + typedef typename IntersectorK::Primitive PrimitiveK; + typedef typename IntersectorK::Precalculations PrecalculationsK; + + static __forceinline void intersectK(const vbool& valid, const Accel::Intersectors* This, /* PrecalculationsK& pre, */ RayHitK& ray, IntersectContext* context, const PrimitiveK* prim, size_t num, size_t& lazy_node) + { + PrecalculationsK pre(valid,ray); // FIXME: might cause trouble + + for (size_t i=0; i occludedK(const vbool& valid, const Accel::Intersectors* This, /* PrecalculationsK& pre, */ RayK& ray, IntersectContext* context, const PrimitiveK* prim, size_t num, size_t& lazy_node) + { + PrecalculationsK pre(valid,ray); // FIXME: might cause trouble + vbool valid0 = valid; + for (size_t i=0; i& ray, size_t k, IntersectContext* context, const PrimitiveK* prim, size_t num, size_t& lazy_node) + { + PrecalculationsK pre(ray.tnear() <= ray.tfar,ray); // FIXME: might cause trouble + for (size_t i=0; i& ray, size_t k, IntersectContext* context, const PrimitiveK* prim, size_t num, size_t& lazy_node) + { + PrecalculationsK pre(ray.tnear() <= ray.tfar,ray); // FIXME: might cause trouble + for (size_t i=0; i** __restrict__ inputPackets, IntersectContext* context, const PrimitiveK* prim, size_t num, size_t& lazy_node) + { + size_t m_occluded = 0; + for (size_t i=0; i &ray = *inputPackets[rayID / K]; + const size_t k = rayID % K; + PrecalculationsK pre(ray.tnear() <= ray.tfar,ray); // FIXME: might cause trouble + if (IntersectorK::occluded(pre,ray,k,context,prim[i])) + { + m_occluded |= (size_t)1 << rayID; + ray.tfar[k] = neg_inf; + } + } + } + return m_occluded; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/line_intersector.h b/thirdparty/embree/kernels/geometry/line_intersector.h new file mode 100644 index 000000000000..eef5b0b1fd1c --- /dev/null +++ b/thirdparty/embree/kernels/geometry/line_intersector.h @@ -0,0 +1,141 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + namespace isa + { + template + struct LineIntersectorHitM + { + __forceinline LineIntersectorHitM() {} + + __forceinline LineIntersectorHitM(const vfloat& u, const vfloat& v, const vfloat& t, const Vec3vf& Ng) + : vu(u), vv(v), vt(t), vNg(Ng) {} + + __forceinline void finalize() {} + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + public: + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + struct FlatLinearCurveIntersector1 + { + 
typedef CurvePrecalculations1 Precalculations; + + template + static __forceinline bool intersect(const vbool& valid_i, + Ray& ray, + IntersectContext* context, + const LineSegments* geom, + const Precalculations& pre, + const Vec4vf& v0i, const Vec4vf& v1i, + const Epilog& epilog) + { + /* transform end points into ray space */ + vbool valid = valid_i; + vfloat depth_scale = pre.depth_scale; + LinearSpace3> ray_space = pre.ray_space; + + const Vec3vf ray_org ((Vec3fa)ray.org); + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec4vf v1 = enlargeRadiusToMinWidth(context,geom,ray_org,v1i); + + Vec4vf p0(xfmVector(ray_space,v0.xyz()-ray_org), v0.w); + Vec4vf p1(xfmVector(ray_space,v1.xyz()-ray_org), v1.w); + + /* approximative intersection with cone */ + const Vec4vf v = p1-p0; + const Vec4vf w = -p0; + const vfloat d0 = madd(w.x,v.x,w.y*v.y); + const vfloat d1 = madd(v.x,v.x,v.y*v.y); + const vfloat u = clamp(d0*rcp(d1),vfloat(zero),vfloat(one)); + const Vec4vf p = madd(u,v,p0); + const vfloat t = p.z; + const vfloat d2 = madd(p.x,p.x,p.y*p.y); + const vfloat r = p.w; + const vfloat r2 = r*r; + valid &= (d2 <= r2) & (vfloat(ray.tnear()) <= t) & (t <= vfloat(ray.tfar)); + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) + valid &= t > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR)*r*depth_scale; // ignore self intersections + if (unlikely(none(valid))) return false; + + /* ignore denormalized segments */ + const Vec3vf T = v1.xyz()-v0.xyz(); + valid &= (T.x != vfloat(zero)) | (T.y != vfloat(zero)) | (T.z != vfloat(zero)); + if (unlikely(none(valid))) return false; + + /* update hit information */ + LineIntersectorHitM hit(u,zero,t,T); + return epilog(valid,hit); + } + }; + + template + struct FlatLinearCurveIntersectorK + { + typedef CurvePrecalculationsK Precalculations; + + template + static __forceinline bool intersect(const vbool& valid_i, + RayK& ray, size_t k, + IntersectContext* context, + const LineSegments* geom, + const Precalculations& pre, + const Vec4vf& v0i, const Vec4vf& v1i, + const Epilog& epilog) + { + /* transform end points into ray space */ + vbool valid = valid_i; + vfloat depth_scale = pre.depth_scale[k]; + LinearSpace3> ray_space = pre.ray_space[k]; + const Vec3vf ray_org(ray.org.x[k],ray.org.y[k],ray.org.z[k]); + const Vec3vf ray_dir(ray.dir.x[k],ray.dir.y[k],ray.dir.z[k]); + + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec4vf v1 = enlargeRadiusToMinWidth(context,geom,ray_org,v1i); + + Vec4vf p0(xfmVector(ray_space,v0.xyz()-ray_org), v0.w); + Vec4vf p1(xfmVector(ray_space,v1.xyz()-ray_org), v1.w); + + /* approximative intersection with cone */ + const Vec4vf v = p1-p0; + const Vec4vf w = -p0; + const vfloat d0 = madd(w.x,v.x,w.y*v.y); + const vfloat d1 = madd(v.x,v.x,v.y*v.y); + const vfloat u = clamp(d0*rcp(d1),vfloat(zero),vfloat(one)); + const Vec4vf p = madd(u,v,p0); + const vfloat t = p.z; + const vfloat d2 = madd(p.x,p.x,p.y*p.y); + const vfloat r = p.w; + const vfloat r2 = r*r; + valid &= (d2 <= r2) & (vfloat(ray.tnear()[k]) <= t) & (t <= vfloat(ray.tfar[k])); + if (EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR != 0.0f) + valid &= t > float(EMBREE_CURVE_SELF_INTERSECTION_AVOIDANCE_FACTOR)*r*depth_scale; // ignore self intersections + if (unlikely(none(valid))) return false; + + /* ignore denormalized segments */ + const Vec3vf T = v1.xyz()-v0.xyz(); + valid &= (T.x != vfloat(zero)) | (T.y != vfloat(zero)) | (T.z != vfloat(zero)); + if (unlikely(none(valid))) return false; + + 
/* update hit information */ + LineIntersectorHitM hit(u,zero,t,T); + return epilog(valid,hit); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/linei.h b/thirdparty/embree/kernels/geometry/linei.h new file mode 100644 index 000000000000..a72029ca53d7 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/linei.h @@ -0,0 +1,709 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + template + struct LineMi + { + /* Virtual interface to query information about the line segment type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored line segments */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N line segments */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + /* Returns required number of bytes for N line segments */ + static __forceinline size_t bytes(size_t N) { return blocks(N)*sizeof(LineMi); } + + public: + + /* Default constructor */ + __forceinline LineMi() { } + + /* Construction from vertices and IDs */ + __forceinline LineMi(const vuint& v0, unsigned short leftExists, unsigned short rightExists, const vuint& geomIDs, const vuint& primIDs, Geometry::GType gtype) + : gtype((unsigned char)gtype), m((unsigned char)popcnt(vuint(primIDs) != vuint(-1))), sharedGeomID(geomIDs[0]), leftExists (leftExists), rightExists(rightExists), v0(v0), primIDs(primIDs) + { + assert(all(vuint(geomID()) == geomIDs)); + } + + /* Returns a mask that tells which line segments are valid */ + __forceinline vbool valid() const { return primIDs != vuint(-1); } + + /* Returns a mask that tells which line segments are valid */ + template + __forceinline vbool valid() const { return vuint(primIDs) != vuint(-1); } + + /* Returns if the specified line segment is valid */ + __forceinline bool valid(const size_t i) const { assert(i + //static __forceinline T unmask(T &index) { return index & 0x3fffffff; } + + __forceinline unsigned int geomID(unsigned int i = 0) const { return sharedGeomID; } + //__forceinline vuint geomID() { return unmask(geomIDs); } + //__forceinline const vuint geomID() const { return unmask(geomIDs); } + //__forceinline unsigned int geomID(const size_t i) const { assert(i& primID() { return primIDs; } + __forceinline const vuint& primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(i& p0, + Vec4vf& p1, + const LineSegments* geom) const; + + __forceinline void gatheri(Vec4vf& p0, + Vec4vf& p1, + const LineSegments* geom, + const int itime) const; + + __forceinline void gather(Vec4vf& p0, + Vec4vf& p1, + const LineSegments* geom, + float time) const; + + /* gather the line segments with lateral info */ + __forceinline void gather(Vec4vf& p0, + Vec4vf& p1, + Vec4vf& pL, + Vec4vf& pR, + const LineSegments* geom) const; + + __forceinline void gatheri(Vec4vf& p0, + Vec4vf& p1, + Vec4vf& pL, + Vec4vf& pR, + const LineSegments* geom, + const int itime) const; + + __forceinline void gather(Vec4vf& p0, + Vec4vf& p1, + Vec4vf& pL, + Vec4vf& pR, + const LineSegments* geom, + float time) const; + + __forceinline void 
gather(Vec4vf& p0, + Vec4vf& p1, + vbool& cL, + vbool& cR, + const LineSegments* geom) const; + + __forceinline void gatheri(Vec4vf& p0, + Vec4vf& p1, + vbool& cL, + vbool& cR, + const LineSegments* geom, + const int itime) const; + + __forceinline void gather(Vec4vf& p0, + Vec4vf& p1, + vbool& cL, + vbool& cR, + const LineSegments* geom, + float time) const; + + /* Calculate the bounds of the line segments */ + __forceinline const BBox3fa bounds(const Scene* scene, size_t itime = 0) const + { + BBox3fa bounds = empty; + for (size_t i=0; iget(geomID(i)); + const Vec3ff& p0 = geom->vertex(v0[i]+0,itime); + const Vec3ff& p1 = geom->vertex(v0[i]+1,itime); + BBox3fa b = merge(BBox3fa(p0),BBox3fa(p1)); + b = enlarge(b,Vec3fa(max(p0.w,p1.w))); + bounds.extend(b); + } + return bounds; + } + + /* Calculate the linear bounds of the primitive */ + __forceinline LBBox3fa linearBounds(const Scene* scene, size_t itime) { + return LBBox3fa(bounds(scene,itime+0), bounds(scene,itime+1)); + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, size_t itime, size_t numTimeSteps) { + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID(i)); + allBounds.extend(geom->linearBounds(primID(i), itime, numTimeSteps)); + } + return allBounds; + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, const BBox1f time_range) + { + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID((unsigned int)i)); + allBounds.extend(geom->linearBounds(primID(i), time_range)); + } + return allBounds; + } + + /* Fill line segment from line segment list */ + template + __forceinline void fill(const PrimRefT* prims, size_t& begin, size_t end, Scene* scene) + { + Geometry::GType gty = scene->get(prims[begin].geomID())->getType(); + vuint geomID, primID; + vuint v0; + unsigned short leftExists = 0; + unsigned short rightExists = 0; + const PrimRefT* prim = &prims[begin]; + + for (size_t i=0; iget(prim->geomID()); + if (begingeomID(); + primID[i] = prim->primID(); + v0[i] = geom->segment(prim->primID()); + leftExists |= geom->segmentLeftExists(primID[i]) << i; + rightExists |= geom->segmentRightExists(primID[i]) << i; + begin++; + } else { + assert(i); + if (i>0) { + geomID[i] = geomID[i-1]; + primID[i] = -1; + v0[i] = v0[i-1]; + } + } + if (begin + __forceinline static typename BVH::NodeRef createLeaf (BVH* bvh, const PrimRef* prims, const range& set, const Allocator& alloc) + { + size_t start = set.begin(); + size_t items = LineMi::blocks(set.size()); + size_t numbytes = LineMi::bytes(set.size()); + LineMi* accel = (LineMi*) alloc.malloc1(numbytes,M*sizeof(float)); + for (size_t i=0; iscene); + } + return bvh->encodeLeaf((char*)accel,items); + }; + + __forceinline LBBox3fa fillMB(const PrimRef* prims, size_t& begin, size_t end, Scene* scene, size_t itime) + { + fill(prims,begin,end,scene); + return linearBounds(scene,itime); + } + + __forceinline LBBox3fa fillMB(const PrimRefMB* prims, size_t& begin, size_t end, Scene* scene, const BBox1f time_range) + { + fill(prims,begin,end,scene); + return linearBounds(scene,time_range); + } + + template + __forceinline static typename BVH::NodeRecordMB4D createLeafMB(BVH* bvh, const SetMB& prims, const Allocator& alloc) + { + size_t start = prims.begin(); + size_t end = prims.end(); + size_t items = LineMi::blocks(prims.size()); + size_t numbytes = LineMi::bytes(prims.size()); + LineMi* accel = (LineMi*) alloc.malloc1(numbytes,M*sizeof(float)); + const typename BVH::NodeRef node = bvh->encodeLeaf((char*)accel,items); + + LBBox3fa bounds = empty; + for (size_t 
i=0; idata(),start,end,bvh->scene,prims.time_range)); + + return typename BVH::NodeRecordMB4D(node,bounds,prims.time_range); + }; + + /* Updates the primitive */ + __forceinline BBox3fa update(LineSegments* geom) + { + BBox3fa bounds = empty; + for (size_t i=0; ivertex(v0[i]+0); + const Vec3ff& p1 = geom->vertex(v0[i]+1); + BBox3fa b = merge(BBox3fa(p0),BBox3fa(p1)); + b = enlarge(b,Vec3fa(max(p0.w,p1.w))); + bounds.extend(b); + } + return bounds; + } + + /*! output operator */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const LineMi& line) { + return cout << "Line" << M << "i {" << line.v0 << ", " << line.geomID() << ", " << line.primID() << "}"; + } + + public: + unsigned char gtype; + unsigned char m; + unsigned int sharedGeomID; + unsigned short leftExists, rightExists; + vuint v0; // index of start vertex + private: + vuint primIDs; // primitive ID + }; + + template<> + __forceinline void LineMi<4>::gather(Vec4vf4& p0, + Vec4vf4& p1, + const LineSegments* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0])); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1])); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2])); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3])); + transpose(a0,a1,a2,a3,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1)); + transpose(b0,b1,b2,b3,p1.x,p1.y,p1.z,p1.w); + } + + template<> + __forceinline void LineMi<4>::gatheri(Vec4vf4& p0, + Vec4vf4& p1, + const LineSegments* geom, + const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0],itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1],itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2],itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3],itime)); + transpose(a0,a1,a2,a3,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1,itime)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1,itime)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1,itime)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1,itime)); + transpose(b0,b1,b2,b3,p1.x,p1.y,p1.z,p1.w); + } + + template<> + __forceinline void LineMi<4>::gather(Vec4vf4& p0, + Vec4vf4& p1, + const LineSegments* geom, + float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf4 a0,a1; + gatheri(a0,a1,geom,itime); + Vec4vf4 b0,b1; + gatheri(b0,b1,geom,itime+1); + p0 = lerp(a0,b0,vfloat4(ftime)); + p1 = lerp(a1,b1,vfloat4(ftime)); + } + + template<> + __forceinline void LineMi<4>::gather(Vec4vf4& p0, + Vec4vf4& p1, + vbool4& cL, + vbool4& cR, + const LineSegments* geom) const + { + gather(p0,p1,geom); + cL = !vbool4(leftExists); + cR = !vbool4(rightExists); + } + + template<> + __forceinline void LineMi<4>::gatheri(Vec4vf4& p0, + Vec4vf4& p1, + vbool4& cL, + vbool4& cR, + const LineSegments* geom, + const int itime) const + { + gatheri(p0,p1,geom,itime); + cL = !vbool4(leftExists); + cR = !vbool4(rightExists); + } + + template<> + __forceinline void LineMi<4>::gather(Vec4vf4& p0, + Vec4vf4& p1, + vbool4& cL, + vbool4& cR, + const LineSegments* geom, + float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf4 a0,a1; + gatheri(a0,a1,geom,itime); + Vec4vf4 b0,b1; + 
gatheri(b0,b1,geom,itime+1); + p0 = lerp(a0,b0,vfloat4(ftime)); + p1 = lerp(a1,b1,vfloat4(ftime)); + cL = !vbool4(leftExists); + cR = !vbool4(rightExists); + } + + template<> + __forceinline void LineMi<4>::gather(Vec4vf4& p0, + Vec4vf4& p1, + Vec4vf4& pL, + Vec4vf4& pR, + const LineSegments* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0])); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1])); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2])); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3])); + transpose(a0,a1,a2,a3,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1)); + transpose(b0,b1,b2,b3,p1.x,p1.y,p1.z,p1.w); + + const vfloat4 l0 = (leftExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]-1)) : vfloat4(inf); + const vfloat4 l1 = (leftExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]-1)) : vfloat4(inf); + const vfloat4 l2 = (leftExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]-1)) : vfloat4(inf); + const vfloat4 l3 = (leftExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]-1)) : vfloat4(inf); + transpose(l0,l1,l2,l3,pL.x,pL.y,pL.z,pL.w); + + const vfloat4 r0 = (rightExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]+2)) : vfloat4(inf); + const vfloat4 r1 = (rightExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]+2)) : vfloat4(inf); + const vfloat4 r2 = (rightExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]+2)) : vfloat4(inf); + const vfloat4 r3 = (rightExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]+2)) : vfloat4(inf); + transpose(r0,r1,r2,r3,pR.x,pR.y,pR.z,pR.w); + } + + template<> + __forceinline void LineMi<4>::gatheri(Vec4vf4& p0, + Vec4vf4& p1, + Vec4vf4& pL, + Vec4vf4& pR, + const LineSegments* geom, + const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0],itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1],itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2],itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3],itime)); + transpose(a0,a1,a2,a3,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1,itime)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1,itime)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1,itime)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1,itime)); + transpose(b0,b1,b2,b3,p1.x,p1.y,p1.z,p1.w); + + const vfloat4 l0 = (leftExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]-1,itime)) : vfloat4(inf); + const vfloat4 l1 = (leftExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]-1,itime)) : vfloat4(inf); + const vfloat4 l2 = (leftExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]-1,itime)) : vfloat4(inf); + const vfloat4 l3 = (leftExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]-1,itime)) : vfloat4(inf); + transpose(l0,l1,l2,l3,pL.x,pL.y,pL.z,pL.w); + + const vfloat4 r0 = (rightExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]+2,itime)) : vfloat4(inf); + const vfloat4 r1 = (rightExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]+2,itime)) : vfloat4(inf); + const vfloat4 r2 = (rightExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]+2,itime)) : vfloat4(inf); + const vfloat4 r3 = (rightExists & (1<<3)) ? 
vfloat4::loadu(geom->vertexPtr(v0[3]+2,itime)) : vfloat4(inf); + transpose(r0,r1,r2,r3,pR.x,pR.y,pR.z,pR.w); + } + + template<> + __forceinline void LineMi<4>::gather(Vec4vf4& p0, + Vec4vf4& p1, + Vec4vf4& pL, + Vec4vf4& pR, + const LineSegments* geom, + float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf4 a0,a1,aL,aR; + gatheri(a0,a1,aL,aR,geom,itime); + Vec4vf4 b0,b1,bL,bR; + gatheri(b0,b1,bL,bR,geom,itime+1); + p0 = lerp(a0,b0,vfloat4(ftime)); + p1 = lerp(a1,b1,vfloat4(ftime)); + pL = lerp(aL,bL,vfloat4(ftime)); + pR = lerp(aR,bR,vfloat4(ftime)); + } + +#if defined(__AVX__) + + template<> + __forceinline void LineMi<8>::gather(Vec4vf8& p0, + Vec4vf8& p1, + const LineSegments* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0])); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1])); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2])); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3])); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(v0[4])); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(v0[5])); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(v0[6])); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(v0[7])); + transpose(a0,a1,a2,a3,a4,a5,a6,a7,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1)); + const vfloat4 b4 = vfloat4::loadu(geom->vertexPtr(v0[4]+1)); + const vfloat4 b5 = vfloat4::loadu(geom->vertexPtr(v0[5]+1)); + const vfloat4 b6 = vfloat4::loadu(geom->vertexPtr(v0[6]+1)); + const vfloat4 b7 = vfloat4::loadu(geom->vertexPtr(v0[7]+1)); + transpose(b0,b1,b2,b3,b4,b5,b6,b7,p1.x,p1.y,p1.z,p1.w); + } + + template<> + __forceinline void LineMi<8>::gatheri(Vec4vf8& p0, + Vec4vf8& p1, + const LineSegments* geom, + const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0],itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1],itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2],itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3],itime)); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(v0[4],itime)); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(v0[5],itime)); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(v0[6],itime)); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(v0[7],itime)); + transpose(a0,a1,a2,a3,a4,a5,a6,a7,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1,itime)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1,itime)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1,itime)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1,itime)); + const vfloat4 b4 = vfloat4::loadu(geom->vertexPtr(v0[4]+1,itime)); + const vfloat4 b5 = vfloat4::loadu(geom->vertexPtr(v0[5]+1,itime)); + const vfloat4 b6 = vfloat4::loadu(geom->vertexPtr(v0[6]+1,itime)); + const vfloat4 b7 = vfloat4::loadu(geom->vertexPtr(v0[7]+1,itime)); + transpose(b0,b1,b2,b3,b4,b5,b6,b7,p1.x,p1.y,p1.z,p1.w); + } + + template<> + __forceinline void LineMi<8>::gather(Vec4vf8& p0, + Vec4vf8& p1, + const LineSegments* geom, + float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf8 a0,a1; + gatheri(a0,a1,geom,itime); + Vec4vf8 b0,b1; + gatheri(b0,b1,geom,itime+1); + p0 = lerp(a0,b0,vfloat8(ftime)); + p1 = 
lerp(a1,b1,vfloat8(ftime)); + } + + template<> + __forceinline void LineMi<8>::gather(Vec4vf8& p0, + Vec4vf8& p1, + Vec4vf8& pL, + Vec4vf8& pR, + const LineSegments* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0])); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1])); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2])); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3])); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(v0[4])); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(v0[5])); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(v0[6])); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(v0[7])); + transpose(a0,a1,a2,a3,a4,a5,a6,a7,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1)); + const vfloat4 b4 = vfloat4::loadu(geom->vertexPtr(v0[4]+1)); + const vfloat4 b5 = vfloat4::loadu(geom->vertexPtr(v0[5]+1)); + const vfloat4 b6 = vfloat4::loadu(geom->vertexPtr(v0[6]+1)); + const vfloat4 b7 = vfloat4::loadu(geom->vertexPtr(v0[7]+1)); + transpose(b0,b1,b2,b3,b4,b5,b6,b7,p1.x,p1.y,p1.z,p1.w); + + const vfloat4 l0 = (leftExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]-1)) : vfloat4(inf); + const vfloat4 l1 = (leftExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]-1)) : vfloat4(inf); + const vfloat4 l2 = (leftExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]-1)) : vfloat4(inf); + const vfloat4 l3 = (leftExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]-1)) : vfloat4(inf); + const vfloat4 l4 = (leftExists & (1<<4)) ? vfloat4::loadu(geom->vertexPtr(v0[4]-1)) : vfloat4(inf); + const vfloat4 l5 = (leftExists & (1<<5)) ? vfloat4::loadu(geom->vertexPtr(v0[5]-1)) : vfloat4(inf); + const vfloat4 l6 = (leftExists & (1<<6)) ? vfloat4::loadu(geom->vertexPtr(v0[6]-1)) : vfloat4(inf); + const vfloat4 l7 = (leftExists & (1<<7)) ? vfloat4::loadu(geom->vertexPtr(v0[7]-1)) : vfloat4(inf); + transpose(l0,l1,l2,l3,l4,l5,l6,l7,pL.x,pL.y,pL.z,pL.w); + + const vfloat4 r0 = (rightExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]+2)) : vfloat4(inf); + const vfloat4 r1 = (rightExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]+2)) : vfloat4(inf); + const vfloat4 r2 = (rightExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]+2)) : vfloat4(inf); + const vfloat4 r3 = (rightExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]+2)) : vfloat4(inf); + const vfloat4 r4 = (rightExists & (1<<4)) ? vfloat4::loadu(geom->vertexPtr(v0[4]+2)) : vfloat4(inf); + const vfloat4 r5 = (rightExists & (1<<5)) ? vfloat4::loadu(geom->vertexPtr(v0[5]+2)) : vfloat4(inf); + const vfloat4 r6 = (rightExists & (1<<6)) ? vfloat4::loadu(geom->vertexPtr(v0[6]+2)) : vfloat4(inf); + const vfloat4 r7 = (rightExists & (1<<7)) ? 
vfloat4::loadu(geom->vertexPtr(v0[7]+2)) : vfloat4(inf); + transpose(r0,r1,r2,r3,r4,r5,r6,r7,pR.x,pR.y,pR.z,pR.w); + } + + template<> + __forceinline void LineMi<8>::gatheri(Vec4vf8& p0, + Vec4vf8& p1, + Vec4vf8& pL, + Vec4vf8& pR, + const LineSegments* geom, + const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(v0[0],itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(v0[1],itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(v0[2],itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(v0[3],itime)); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(v0[4],itime)); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(v0[5],itime)); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(v0[6],itime)); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(v0[7],itime)); + transpose(a0,a1,a2,a3,a4,a5,a6,a7,p0.x,p0.y,p0.z,p0.w); + + const vfloat4 b0 = vfloat4::loadu(geom->vertexPtr(v0[0]+1,itime)); + const vfloat4 b1 = vfloat4::loadu(geom->vertexPtr(v0[1]+1,itime)); + const vfloat4 b2 = vfloat4::loadu(geom->vertexPtr(v0[2]+1,itime)); + const vfloat4 b3 = vfloat4::loadu(geom->vertexPtr(v0[3]+1,itime)); + const vfloat4 b4 = vfloat4::loadu(geom->vertexPtr(v0[4]+1,itime)); + const vfloat4 b5 = vfloat4::loadu(geom->vertexPtr(v0[5]+1,itime)); + const vfloat4 b6 = vfloat4::loadu(geom->vertexPtr(v0[6]+1,itime)); + const vfloat4 b7 = vfloat4::loadu(geom->vertexPtr(v0[7]+1,itime)); + transpose(b0,b1,b2,b3,b4,b5,b6,b7,p1.x,p1.y,p1.z,p1.w); + + const vfloat4 l0 = (leftExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]-1,itime)) : vfloat4(inf); + const vfloat4 l1 = (leftExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]-1,itime)) : vfloat4(inf); + const vfloat4 l2 = (leftExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]-1,itime)) : vfloat4(inf); + const vfloat4 l3 = (leftExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]-1,itime)) : vfloat4(inf); + const vfloat4 l4 = (leftExists & (1<<4)) ? vfloat4::loadu(geom->vertexPtr(v0[4]-1,itime)) : vfloat4(inf); + const vfloat4 l5 = (leftExists & (1<<5)) ? vfloat4::loadu(geom->vertexPtr(v0[5]-1,itime)) : vfloat4(inf); + const vfloat4 l6 = (leftExists & (1<<6)) ? vfloat4::loadu(geom->vertexPtr(v0[6]-1,itime)) : vfloat4(inf); + const vfloat4 l7 = (leftExists & (1<<7)) ? vfloat4::loadu(geom->vertexPtr(v0[7]-1,itime)) : vfloat4(inf); + transpose(l0,l1,l2,l3,l4,l5,l6,l7,pL.x,pL.y,pL.z,pL.w); + + const vfloat4 r0 = (rightExists & (1<<0)) ? vfloat4::loadu(geom->vertexPtr(v0[0]+2,itime)) : vfloat4(inf); + const vfloat4 r1 = (rightExists & (1<<1)) ? vfloat4::loadu(geom->vertexPtr(v0[1]+2,itime)) : vfloat4(inf); + const vfloat4 r2 = (rightExists & (1<<2)) ? vfloat4::loadu(geom->vertexPtr(v0[2]+2,itime)) : vfloat4(inf); + const vfloat4 r3 = (rightExists & (1<<3)) ? vfloat4::loadu(geom->vertexPtr(v0[3]+2,itime)) : vfloat4(inf); + const vfloat4 r4 = (rightExists & (1<<4)) ? vfloat4::loadu(geom->vertexPtr(v0[4]+2,itime)) : vfloat4(inf); + const vfloat4 r5 = (rightExists & (1<<5)) ? vfloat4::loadu(geom->vertexPtr(v0[5]+2,itime)) : vfloat4(inf); + const vfloat4 r6 = (rightExists & (1<<6)) ? vfloat4::loadu(geom->vertexPtr(v0[6]+2,itime)) : vfloat4(inf); + const vfloat4 r7 = (rightExists & (1<<7)) ? 
vfloat4::loadu(geom->vertexPtr(v0[7]+2,itime)) : vfloat4(inf); + transpose(r0,r1,r2,r3,r4,r5,r6,r7,pR.x,pR.y,pR.z,pR.w); + } + + template<> + __forceinline void LineMi<8>::gather(Vec4vf8& p0, + Vec4vf8& p1, + Vec4vf8& pL, + Vec4vf8& pR, + const LineSegments* geom, + float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf8 a0,a1,aL,aR; + gatheri(a0,a1,aL,aR,geom,itime); + Vec4vf8 b0,b1,bL,bR; + gatheri(b0,b1,bL,bR,geom,itime+1); + p0 = lerp(a0,b0,vfloat8(ftime)); + p1 = lerp(a1,b1,vfloat8(ftime)); + pL = lerp(aL,bL,vfloat8(ftime)); + pR = lerp(aR,bR,vfloat8(ftime)); + } + + template<> + __forceinline void LineMi<8>::gather(Vec4vf8& p0, + Vec4vf8& p1, + vbool8& cL, + vbool8& cR, + const LineSegments* geom) const + { + gather(p0,p1,geom); + cL = !vbool8(leftExists); + cR = !vbool8(rightExists); + } + + template<> + __forceinline void LineMi<8>::gatheri(Vec4vf8& p0, + Vec4vf8& p1, + vbool8& cL, + vbool8& cR, + const LineSegments* geom, + const int itime) const + { + gatheri(p0,p1,geom,itime); + cL = !vbool8(leftExists); + cR = !vbool8(rightExists); + } + + template<> + __forceinline void LineMi<8>::gather(Vec4vf8& p0, + Vec4vf8& p1, + vbool8& cL, + vbool8& cR, + const LineSegments* geom, + float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf8 a0,a1; + gatheri(a0,a1,geom,itime); + Vec4vf8 b0,b1; + gatheri(b0,b1,geom,itime+1); + p0 = lerp(a0,b0,vfloat8(ftime)); + p1 = lerp(a1,b1,vfloat8(ftime)); + cL = !vbool8(leftExists); + cR = !vbool8(rightExists); + } + +#endif + + template + typename LineMi::Type LineMi::type; + + typedef LineMi<4> Line4i; + typedef LineMi<8> Line8i; +} diff --git a/thirdparty/embree/kernels/geometry/linei_intersector.h b/thirdparty/embree/kernels/geometry/linei_intersector.h new file mode 100644 index 000000000000..a431796a88d5 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/linei_intersector.h @@ -0,0 +1,124 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "linei.h" +#include "line_intersector.h" +#include "intersector_epilog.h" + +namespace embree +{ + namespace isa + { + template + struct FlatLinearCurveMiIntersector1 + { + typedef LineMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom); + const vbool valid = line.template valid(); + FlatLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,Intersect1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom); + const vbool valid = line.template valid(); + return FlatLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,Occluded1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& line) + { + return PrimitivePointQuery1::pointQuery(query, context, line); + } + }; + + template + struct FlatLinearCurveMiMBIntersector1 + { + typedef LineMi Primitive; + typedef CurvePrecalculations1 
Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom,ray.time()); + const vbool valid = line.template valid(); + FlatLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,Intersect1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom,ray.time()); + const vbool valid = line.template valid(); + return FlatLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,Occluded1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& line) + { + return PrimitivePointQuery1::pointQuery(query, context, line); + } + }; + + template + struct FlatLinearCurveMiIntersectorK + { + typedef LineMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom); + const vbool valid = line.template valid(); + FlatLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,Intersect1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom); + const vbool valid = line.template valid(); + return FlatLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,Occluded1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + }; + + template + struct FlatLinearCurveMiMBIntersectorK + { + typedef LineMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom,ray.time()[k]); + const vbool valid = line.template valid(); + FlatLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,Intersect1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1; line.gather(v0,v1,geom,ray.time()[k]); + const vbool valid = line.template valid(); + return FlatLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,Occluded1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/object.h b/thirdparty/embree/kernels/geometry/object.h new file mode 100644 index 000000000000..f26391de5239 --- /dev/null +++ 
b/thirdparty/embree/kernels/geometry/object.h @@ -0,0 +1,84 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + struct Object + { + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored primitives */ + static __forceinline size_t max_size() { return 1; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return N; } + + public: + + /*! constructs a virtual object */ + Object (unsigned geomID, unsigned primID) + : _geomID(geomID), _primID(primID) {} + + __forceinline unsigned geomID() const { + return _geomID; + } + + __forceinline unsigned primID() const { + return _primID; + } + + /*! fill triangle from triangle list */ + __forceinline void fill(const PrimRef* prims, size_t& i, size_t end, Scene* scene) + { + const PrimRef& prim = prims[i]; i++; + new (this) Object(prim.geomID(), prim.primID()); + } + + /*! fill triangle from triangle list */ + __forceinline LBBox3fa fillMB(const PrimRef* prims, size_t& i, size_t end, Scene* scene, size_t itime) + { + const PrimRef& prim = prims[i]; i++; + const unsigned geomID = prim.geomID(); + const unsigned primID = prim.primID(); + new (this) Object(geomID, primID); + AccelSet* accel = (AccelSet*) scene->get(geomID); + return accel->linearBounds(primID,itime); + } + + /*! fill triangle from triangle list */ + __forceinline LBBox3fa fillMB(const PrimRefMB* prims, size_t& i, size_t end, Scene* scene, const BBox1f time_range) + { + const PrimRefMB& prim = prims[i]; i++; + const unsigned geomID = prim.geomID(); + const unsigned primID = prim.primID(); + new (this) Object(geomID, primID); + AccelSet* accel = (AccelSet*) scene->get(geomID); + return accel->linearBounds(primID,time_range); + } + + /* Updates the primitive */ + __forceinline BBox3fa update(AccelSet* mesh) { + return mesh->bounds(primID()); + } + + private: + unsigned int _geomID; //!< geometry ID + unsigned int _primID; //!< primitive ID + }; +} diff --git a/thirdparty/embree/kernels/geometry/object_intersector.h b/thirdparty/embree/kernels/geometry/object_intersector.h new file mode 100644 index 000000000000..97882e0e599e --- /dev/null +++ b/thirdparty/embree/kernels/geometry/object_intersector.h @@ -0,0 +1,127 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "object.h" +#include "../common/ray.h" + +namespace embree +{ + namespace isa + { + template + struct ObjectIntersector1 + { + typedef Object Primitive; + + static const bool validIntersectorK = false; + + struct Precalculations { + __forceinline Precalculations() {} + __forceinline Precalculations (const Ray& ray, const void *ptr) {} + }; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& prim) + { + AccelSet* accel = (AccelSet*) context->scene->get(prim.geomID()); + + /* perform ray mask test */ +#if defined(EMBREE_RAY_MASK) + if ((ray.mask & accel->mask) == 0) + return; +#endif + + accel->intersect(ray,prim.geomID(),prim.primID(),context,reportIntersection1); + } + + static __forceinline bool occluded(const 
Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& prim) + { + AccelSet* accel = (AccelSet*) context->scene->get(prim.geomID()); + /* perform ray mask test */ +#if defined(EMBREE_RAY_MASK) + if ((ray.mask & accel->mask) == 0) + return false; +#endif + + accel->occluded(ray,prim.geomID(),prim.primID(),context,&reportOcclusion1); + return ray.tfar < 0.0f; + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& prim) + { + AccelSet* accel = (AccelSet*)context->scene->get(prim.geomID()); + context->geomID = prim.geomID(); + context->primID = prim.primID(); + return accel->pointQuery(query, context); + } + + template + static __forceinline void intersectK(const vbool& valid, /* PrecalculationsK& pre, */ RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + assert(false); + } + + template + static __forceinline vbool occludedK(const vbool& valid, /* PrecalculationsK& pre, */ RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, size_t& lazy_node) + { + assert(false); + return valid; + } + }; + + template + struct ObjectIntersectorK + { + typedef Object Primitive; + + struct Precalculations { + __forceinline Precalculations (const vbool& valid, const RayK& ray) {} + }; + + static __forceinline void intersect(const vbool& valid_i, const Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& prim) + { + vbool valid = valid_i; + AccelSet* accel = (AccelSet*) context->scene->get(prim.geomID()); + + /* perform ray mask test */ +#if defined(EMBREE_RAY_MASK) + valid &= (ray.mask & accel->mask) != 0; + if (none(valid)) return; +#endif + accel->intersect(valid,ray,prim.geomID(),prim.primID(),context,&reportIntersection1); + } + + static __forceinline vbool occluded(const vbool& valid_i, const Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& prim) + { + vbool valid = valid_i; + AccelSet* accel = (AccelSet*) context->scene->get(prim.geomID()); + + /* perform ray mask test */ +#if defined(EMBREE_RAY_MASK) + valid &= (ray.mask & accel->mask) != 0; + if (none(valid)) return false; +#endif + accel->occluded(valid,ray,prim.geomID(),prim.primID(),context,&reportOcclusion1); + return ray.tfar < 0.0f; + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& prim) { + intersect(vbool(1<& ray, size_t k, IntersectContext* context, const Primitive& prim) { + occluded(vbool(1< ObjectIntersector4; + typedef ObjectIntersectorK<8,false> ObjectIntersector8; + typedef ObjectIntersectorK<16,false> ObjectIntersector16; + + typedef ObjectIntersectorK<4,true> ObjectIntersector4MB; + typedef ObjectIntersectorK<8,true> ObjectIntersector8MB; + typedef ObjectIntersectorK<16,true> ObjectIntersector16MB; + } +} diff --git a/thirdparty/embree/kernels/geometry/plane.h b/thirdparty/embree/kernels/geometry/plane.h new file mode 100644 index 000000000000..ebe45db55850 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/plane.h @@ -0,0 +1,57 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" + +namespace embree +{ + namespace isa + { + struct HalfPlane + { + const Vec3fa P; //!< plane origin + const Vec3fa N; //!< plane normal + + __forceinline HalfPlane(const Vec3fa& P, const Vec3fa& N) + : P(P), N(N) {} + + __forceinline BBox1f intersect(const Vec3fa& ray_org, const Vec3fa& ray_dir) const 
+ { + Vec3fa O = Vec3fa(ray_org) - P; + Vec3fa D = Vec3fa(ray_dir); + float ON = dot(O,N); + float DN = dot(D,N); + bool eps = abs(DN) < min_rcp_input; + float t = -ON*rcp(DN); + float lower = select(eps || DN < 0.0f, float(neg_inf), t); + float upper = select(eps || DN > 0.0f, float(pos_inf), t); + return BBox1f(lower,upper); + } + }; + + template + struct HalfPlaneN + { + const Vec3vf P; //!< plane origin + const Vec3vf N; //!< plane normal + + __forceinline HalfPlaneN(const Vec3vf& P, const Vec3vf& N) + : P(P), N(N) {} + + __forceinline BBox> intersect(const Vec3fa& ray_org, const Vec3fa& ray_dir) const + { + Vec3vf O = Vec3vf((Vec3fa)ray_org) - P; + Vec3vf D = Vec3vf((Vec3fa)ray_dir); + vfloat ON = dot(O,N); + vfloat DN = dot(D,N); + vbool eps = abs(DN) < min_rcp_input; + vfloat t = -ON*rcp(DN); + vfloat lower = select(eps | DN < 0.0f, vfloat(neg_inf), t); + vfloat upper = select(eps | DN > 0.0f, vfloat(pos_inf), t); + return BBox>(lower,upper); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/pointi.h b/thirdparty/embree/kernels/geometry/pointi.h new file mode 100644 index 000000000000..4ba298e86b68 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/pointi.h @@ -0,0 +1,417 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + template + struct PointMi + { + /* Virtual interface to query information about the line segment type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored line segments */ + static __forceinline size_t max_size() + { + return M; + } + + /* Returns required number of primitive blocks for N line segments */ + static __forceinline size_t blocks(size_t N) + { + return (N + max_size() - 1) / max_size(); + } + + /* Returns required number of bytes for N line segments */ + static __forceinline size_t bytes(size_t N) + { + return blocks(N) * sizeof(PointMi); + } + + public: + /* Default constructor */ + __forceinline PointMi() {} + + /* Construction from vertices and IDs */ + __forceinline PointMi(const vuint& geomIDs, const vuint& primIDs, Geometry::GType gtype, uint32_t numPrimitives) + : gtype((unsigned char)gtype), + numPrimitives(numPrimitives), + sharedGeomID(geomIDs[0]), + primIDs(primIDs) + { + assert(all(vuint(geomID()) == geomIDs)); + } + + /* Returns a mask that tells which line segments are valid */ + __forceinline vbool valid() const { + return vint(step) < vint(numPrimitives); + } + + /* Returns a mask that tells which line segments are valid */ + template __forceinline vbool valid() const { + return vint(step) < vint(numPrimitives); + } + + /* Returns if the specified line segment is valid */ + __forceinline bool valid(const size_t i) const + { + assert(i < M); + return i < numPrimitives; + } + + /* Returns the number of stored line segments */ + __forceinline size_t size() const { + return numPrimitives; + } + + __forceinline unsigned int geomID(unsigned int i = 0) const { + return sharedGeomID; + } + + __forceinline vuint& primID() { + return primIDs; + } + __forceinline const vuint& primID() const { + return primIDs; + } + __forceinline unsigned int primID(const size_t i) const { + assert(i < M); + return primIDs[i]; + 
} + + /* gather the line segments */ + __forceinline void gather(Vec4vf& p0, const Points* geom) const; + __forceinline void gather(Vec4vf& p0, Vec3vf& n0, const Points* geom) const; + + __forceinline void gatheri(Vec4vf& p0, const Points* geom, const int itime) const; + __forceinline void gatheri(Vec4vf& p0, Vec3vf& n0, const Points* geom, const int itime) const; + + __forceinline void gather(Vec4vf& p0, const Points* geom, float time) const; + __forceinline void gather(Vec4vf& p0, Vec3vf& n0, const Points* geom, float time) const; + + /* Calculate the bounds of the line segments */ + __forceinline const BBox3fa bounds(const Scene* scene, size_t itime = 0) const + { + BBox3fa bounds = empty; + for (size_t i = 0; i < M && valid(i); i++) { + const Points* geom = scene->get(geomID(i)); + bounds.extend(geom->bounds(primID(i),itime)); + } + return bounds; + } + + /* Calculate the linear bounds of the primitive */ + __forceinline LBBox3fa linearBounds(const Scene* scene, size_t itime) { + return LBBox3fa(bounds(scene, itime + 0), bounds(scene, itime + 1)); + } + + __forceinline LBBox3fa linearBounds(const Scene* const scene, size_t itime, size_t numTimeSteps) + { + LBBox3fa allBounds = empty; + for (size_t i = 0; i < M && valid(i); i++) { + const Points* geom = scene->get(geomID(i)); + allBounds.extend(geom->linearBounds(primID(i), itime, numTimeSteps)); + } + return allBounds; + } + + __forceinline LBBox3fa linearBounds(const Scene* const scene, const BBox1f time_range) + { + LBBox3fa allBounds = empty; + for (size_t i = 0; i < M && valid(i); i++) { + const Points* geom = scene->get(geomID((unsigned int)i)); + allBounds.extend(geom->linearBounds(primID(i), time_range)); + } + return allBounds; + } + + /* Fill line segment from line segment list */ + template + __forceinline void fill(const PrimRefT* prims, size_t& begin, size_t end, Scene* scene) + { + Geometry::GType gty = scene->get(prims[begin].geomID())->getType(); + vuint geomID, primID; + vuint v0; + const PrimRefT* prim = &prims[begin]; + + int numPrimitives = 0; + for (size_t i = 0; i < M; i++) { + if (begin < end) { + geomID[i] = prim->geomID(); + primID[i] = prim->primID(); + begin++; + numPrimitives++; + } else { + assert(i); + if (i > 0) { + geomID[i] = geomID[i - 1]; + primID[i] = primID[i - 1]; + } + } + if (begin < end) + prim = &prims[begin]; // FIXME: remove this line + } + new (this) PointMi(geomID, primID, gty, numPrimitives); // FIXME: use non temporal store + } + + template + __forceinline static typename BVH::NodeRef createLeaf(BVH* bvh, + const PrimRef* prims, + const range& set, + const Allocator& alloc) + { + size_t start = set.begin(); + size_t items = PointMi::blocks(set.size()); + size_t numbytes = PointMi::bytes(set.size()); + PointMi* accel = (PointMi*)alloc.malloc1(numbytes, M * sizeof(float)); + for (size_t i = 0; i < items; i++) { + accel[i].fill(prims, start, set.end(), bvh->scene); + } + return bvh->encodeLeaf((char*)accel, items); + }; + + __forceinline LBBox3fa fillMB(const PrimRef* prims, size_t& begin, size_t end, Scene* scene, size_t itime) + { + fill(prims, begin, end, scene); + return linearBounds(scene, itime); + } + + __forceinline LBBox3fa fillMB( + const PrimRefMB* prims, size_t& begin, size_t end, Scene* scene, const BBox1f time_range) + { + fill(prims, begin, end, scene); + return linearBounds(scene, time_range); + } + + template + __forceinline static typename BVH::NodeRecordMB4D createLeafMB(BVH* bvh, const SetMB& prims, const Allocator& alloc) + { + size_t start = prims.object_range.begin(); + 
size_t end = prims.object_range.end(); + size_t items = PointMi::blocks(prims.object_range.size()); + size_t numbytes = PointMi::bytes(prims.object_range.size()); + PointMi* accel = (PointMi*)alloc.malloc1(numbytes, M * sizeof(float)); + const typename BVH::NodeRef node = bvh->encodeLeaf((char*)accel, items); + + LBBox3fa bounds = empty; + for (size_t i = 0; i < items; i++) + bounds.extend(accel[i].fillMB(prims.prims->data(), start, end, bvh->scene, prims.time_range)); + + return typename BVH::NodeRecordMB4D(node, bounds, prims.time_range); + }; + + /*! output operator */ + friend __forceinline embree_ostream operator<<(embree_ostream cout, const PointMi& line) + { + return cout << "Line" << M << "i {" << line.v0 << ", " << line.geomID() << ", " << line.primID() << "}"; + } + + public: + unsigned char gtype; + unsigned char numPrimitives; + unsigned int sharedGeomID; + + private: + vuint primIDs; // primitive ID + }; + + template<> + __forceinline void PointMi<4>::gather(Vec4vf4& p0, const Points* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0))); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1))); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2))); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3))); + transpose(a0, a1, a2, a3, p0.x, p0.y, p0.z, p0.w); + } + + template<> + __forceinline void PointMi<4>::gather(Vec4vf4& p0, Vec3vf4& n0, const Points* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0))); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1))); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2))); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3))); + transpose(a0, a1, a2, a3, p0.x, p0.y, p0.z, p0.w); + const vfloat4 b0 = vfloat4(geom->normal(primID(0))); + const vfloat4 b1 = vfloat4(geom->normal(primID(1))); + const vfloat4 b2 = vfloat4(geom->normal(primID(2))); + const vfloat4 b3 = vfloat4(geom->normal(primID(3))); + transpose(b0, b1, b2, b3, n0.x, n0.y, n0.z); + } + + template<> + __forceinline void PointMi<4>::gatheri(Vec4vf4& p0, const Points* geom, const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0), itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1), itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2), itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3), itime)); + transpose(a0, a1, a2, a3, p0.x, p0.y, p0.z, p0.w); + } + + template<> + __forceinline void PointMi<4>::gatheri(Vec4vf4& p0, Vec3vf4& n0, const Points* geom, const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0), itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1), itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2), itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3), itime)); + transpose(a0, a1, a2, a3, p0.x, p0.y, p0.z, p0.w); + const vfloat4 b0 = vfloat4(geom->normal(primID(0), itime)); + const vfloat4 b1 = vfloat4(geom->normal(primID(1), itime)); + const vfloat4 b2 = vfloat4(geom->normal(primID(2), itime)); + const vfloat4 b3 = vfloat4(geom->normal(primID(3), itime)); + transpose(b0, b1, b2, b3, n0.x, n0.y, n0.z); + } + + template<> + __forceinline void PointMi<4>::gather(Vec4vf4& p0, const Points* geom, float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf4 a0; gatheri(a0, geom, itime); + Vec4vf4 b0; gatheri(b0, geom, itime + 1); + p0 = lerp(a0, b0, 
vfloat4(ftime)); + } + + template<> + __forceinline void PointMi<4>::gather(Vec4vf4& p0, Vec3vf4& n0, const Points* geom, float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf4 a0, b0; + Vec3vf4 norm0, norm1; + gatheri(a0, norm0, geom, itime); + gatheri(b0, norm1, geom, itime + 1); + p0 = lerp(a0, b0, vfloat4(ftime)); + n0 = lerp(norm0, norm1, vfloat4(ftime)); + } + +#if defined(__AVX__) + + template<> + __forceinline void PointMi<8>::gather(Vec4vf8& p0, const Points* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0))); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1))); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2))); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3))); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(primID(4))); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(primID(5))); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(primID(6))); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(primID(7))); + transpose(a0, a1, a2, a3, a4, a5, a6, a7, p0.x, p0.y, p0.z, p0.w); + } + + template<> + __forceinline void PointMi<8>::gather(Vec4vf8& p0, Vec3vf8& n0, const Points* geom) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0))); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1))); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2))); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3))); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(primID(4))); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(primID(5))); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(primID(6))); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(primID(7))); + transpose(a0, a1, a2, a3, a4, a5, a6, a7, p0.x, p0.y, p0.z, p0.w); + const vfloat4 b0 = vfloat4(geom->normal(primID(0))); + const vfloat4 b1 = vfloat4(geom->normal(primID(1))); + const vfloat4 b2 = vfloat4(geom->normal(primID(2))); + const vfloat4 b3 = vfloat4(geom->normal(primID(3))); + const vfloat4 b4 = vfloat4(geom->normal(primID(4))); + const vfloat4 b5 = vfloat4(geom->normal(primID(5))); + const vfloat4 b6 = vfloat4(geom->normal(primID(6))); + const vfloat4 b7 = vfloat4(geom->normal(primID(7))); + transpose(b0, b1, b2, b3, b4, b5, b6, b7, n0.x, n0.y, n0.z); + } + + template<> + __forceinline void PointMi<8>::gatheri(Vec4vf8& p0, const Points* geom, const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0), itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1), itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2), itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3), itime)); + const vfloat4 a4 = vfloat4::loadu(geom->vertexPtr(primID(4), itime)); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(primID(5), itime)); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(primID(6), itime)); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(primID(7), itime)); + transpose(a0, a1, a2, a3, a4, a5, a6, a7, p0.x, p0.y, p0.z, p0.w); + } + + template<> + __forceinline void PointMi<8>::gatheri(Vec4vf8& p0, Vec3vf8& n0, const Points* geom, const int itime) const + { + const vfloat4 a0 = vfloat4::loadu(geom->vertexPtr(primID(0), itime)); + const vfloat4 a1 = vfloat4::loadu(geom->vertexPtr(primID(1), itime)); + const vfloat4 a2 = vfloat4::loadu(geom->vertexPtr(primID(2), itime)); + const vfloat4 a3 = vfloat4::loadu(geom->vertexPtr(primID(3), itime)); + const vfloat4 a4 = 
vfloat4::loadu(geom->vertexPtr(primID(4), itime)); + const vfloat4 a5 = vfloat4::loadu(geom->vertexPtr(primID(5), itime)); + const vfloat4 a6 = vfloat4::loadu(geom->vertexPtr(primID(6), itime)); + const vfloat4 a7 = vfloat4::loadu(geom->vertexPtr(primID(7), itime)); + transpose(a0, a1, a2, a3, a4, a5, a6, a7, p0.x, p0.y, p0.z, p0.w); + const vfloat4 b0 = vfloat4(geom->normal(primID(0), itime)); + const vfloat4 b1 = vfloat4(geom->normal(primID(1), itime)); + const vfloat4 b2 = vfloat4(geom->normal(primID(2), itime)); + const vfloat4 b3 = vfloat4(geom->normal(primID(3), itime)); + const vfloat4 b4 = vfloat4(geom->normal(primID(4), itime)); + const vfloat4 b5 = vfloat4(geom->normal(primID(5), itime)); + const vfloat4 b6 = vfloat4(geom->normal(primID(6), itime)); + const vfloat4 b7 = vfloat4(geom->normal(primID(7), itime)); + transpose(b0, b1, b2, b3, b4, b5, b6, b7, n0.x, n0.y, n0.z); + } + + template<> + __forceinline void PointMi<8>::gather(Vec4vf8& p0, const Points* geom, float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf8 a0; + gatheri(a0, geom, itime); + Vec4vf8 b0; + gatheri(b0, geom, itime + 1); + p0 = lerp(a0, b0, vfloat8(ftime)); + } + + template<> + __forceinline void PointMi<8>::gather(Vec4vf8& p0, Vec3vf8& n0, const Points* geom, float time) const + { + float ftime; + const int itime = geom->timeSegment(time, ftime); + + Vec4vf8 a0, b0; + Vec3vf8 norm0, norm1; + gatheri(a0, norm0, geom, itime); + gatheri(b0, norm1, geom, itime + 1); + p0 = lerp(a0, b0, vfloat8(ftime)); + n0 = lerp(norm0, norm1, vfloat8(ftime)); + } +#endif + + template + typename PointMi::Type PointMi::type; + + typedef PointMi<4> Point4i; + typedef PointMi<8> Point8i; + +} // namespace embree diff --git a/thirdparty/embree/kernels/geometry/primitive.h b/thirdparty/embree/kernels/geometry/primitive.h new file mode 100644 index 000000000000..41e5b2b3042c --- /dev/null +++ b/thirdparty/embree/kernels/geometry/primitive.h @@ -0,0 +1,49 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "../common/scene.h" +#include "../../common/simd/simd.h" +#include "../common/primref.h" +#include "../common/primref_mb.h" + +namespace embree +{ + struct PrimitiveType + { + /*! returns name of this primitive type */ + virtual const char* name() const = 0; + + /*! Returns the number of stored active primitives in a block. */ + virtual size_t sizeActive(const char* This) const = 0; + + /*! Returns the number of stored active and inactive primitives in a block. */ + virtual size_t sizeTotal(const char* This) const = 0; + + /*! Returns the number of bytes of block. 
*/ + virtual size_t getBytes(const char* This) const = 0; + }; + + template + struct PrimitivePointQuery1 + { + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& prim) + { + bool changed = false; + for (size_t i = 0; i < Primitive::max_size(); i++) + { + if (!prim.valid(i)) break; + STAT3(point_query.trav_prims,1,1,1); + AccelSet* accel = (AccelSet*)context->scene->get(prim.geomID(i)); + context->geomID = prim.geomID(i); + context->primID = prim.primID(i); + changed |= accel->pointQuery(query, context); + } + return changed; + } + + static __forceinline void pointQueryNoop(PointQuery* query, PointQueryContext* context, const Primitive& prim) { } + }; +} diff --git a/thirdparty/embree/kernels/geometry/primitive4.cpp b/thirdparty/embree/kernels/geometry/primitive4.cpp new file mode 100644 index 000000000000..f93574c9c8d0 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/primitive4.cpp @@ -0,0 +1,379 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "primitive.h" +#include "curveNv.h" +#include "curveNi.h" +#include "curveNi_mb.h" +#include "linei.h" +#include "triangle.h" +#include "trianglev.h" +#include "trianglev_mb.h" +#include "trianglei.h" +#include "quadv.h" +#include "quadi.h" +#include "subdivpatch1.h" +#include "object.h" +#include "instance.h" +#include "subgrid.h" + +namespace embree +{ + /********************** Curve4v **************************/ + + template<> + const char* Curve4v::Type::name () const { + return "curve4v"; + } + + template<> + size_t Curve4v::Type::sizeActive(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return ((Line4i*)This)->size(); + else + return ((Curve4v*)This)->N; + } + + template<> + size_t Curve4v::Type::sizeTotal(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return 4; + else + return ((Curve4v*)This)->N; + } + + template<> + size_t Curve4v::Type::getBytes(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return Line4i::bytes(sizeActive(This)); + else + return Curve4v::bytes(sizeActive(This)); + } + + /********************** Curve4i **************************/ + + template<> + const char* Curve4i::Type::name () const { + return "curve4i"; + } + + template<> + size_t Curve4i::Type::sizeActive(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return ((Line4i*)This)->size(); + else + return ((Curve4i*)This)->N; + } + + template<> + size_t Curve4i::Type::sizeTotal(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return 4; + else + return ((Curve4i*)This)->N; + } + + template<> + size_t Curve4i::Type::getBytes(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return Line4i::bytes(sizeActive(This)); + else + return Curve4i::bytes(sizeActive(This)); + } + + /********************** Curve4iMB **************************/ + + template<> + const char* Curve4iMB::Type::name () const { + return "curve4imb"; + } + + template<> + size_t Curve4iMB::Type::sizeActive(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return ((Line4i*)This)->size(); + else + return ((Curve4iMB*)This)->N; + } + + template<> + 
size_t Curve4iMB::Type::sizeTotal(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return 4; + else + return ((Curve4iMB*)This)->N; + } + + template<> + size_t Curve4iMB::Type::getBytes(const char* This) const + { + if ((*This & Geometry::GType::GTY_BASIS_MASK) == Geometry::GType::GTY_BASIS_LINEAR) + return Line4i::bytes(sizeActive(This)); + else + return Curve4iMB::bytes(sizeActive(This)); + } + + /********************** Line4i **************************/ + + template<> + const char* Line4i::Type::name () const { + return "line4i"; + } + + template<> + size_t Line4i::Type::sizeActive(const char* This) const { + return ((Line4i*)This)->size(); + } + + template<> + size_t Line4i::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Line4i::Type::getBytes(const char* This) const { + return sizeof(Line4i); + } + + /********************** Triangle4 **************************/ + + template<> + const char* Triangle4::Type::name () const { + return "triangle4"; + } + + template<> + size_t Triangle4::Type::sizeActive(const char* This) const { + return ((Triangle4*)This)->size(); + } + + template<> + size_t Triangle4::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Triangle4::Type::getBytes(const char* This) const { + return sizeof(Triangle4); + } + + /********************** Triangle4v **************************/ + + template<> + const char* Triangle4v::Type::name () const { + return "triangle4v"; + } + + template<> + size_t Triangle4v::Type::sizeActive(const char* This) const { + return ((Triangle4v*)This)->size(); + } + + template<> + size_t Triangle4v::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Triangle4v::Type::getBytes(const char* This) const { + return sizeof(Triangle4v); + } + + /********************** Triangle4i **************************/ + + template<> + const char* Triangle4i::Type::name () const { + return "triangle4i"; + } + + template<> + size_t Triangle4i::Type::sizeActive(const char* This) const { + return ((Triangle4i*)This)->size(); + } + + template<> + size_t Triangle4i::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Triangle4i::Type::getBytes(const char* This) const { + return sizeof(Triangle4i); + } + + /********************** Triangle4vMB **************************/ + + template<> + const char* Triangle4vMB::Type::name () const { + return "triangle4vmb"; + } + + template<> + size_t Triangle4vMB::Type::sizeActive(const char* This) const { + return ((Triangle4vMB*)This)->size(); + } + + template<> + size_t Triangle4vMB::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Triangle4vMB::Type::getBytes(const char* This) const { + return sizeof(Triangle4vMB); + } + + /********************** Quad4v **************************/ + + template<> + const char* Quad4v::Type::name () const { + return "quad4v"; + } + + template<> + size_t Quad4v::Type::sizeActive(const char* This) const { + return ((Quad4v*)This)->size(); + } + + template<> + size_t Quad4v::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Quad4v::Type::getBytes(const char* This) const { + return sizeof(Quad4v); + } + + /********************** Quad4i **************************/ + + template<> + const char* Quad4i::Type::name () const { + return "quad4i"; + } + + template<> + size_t Quad4i::Type::sizeActive(const char* This) const { + return ((Quad4i*)This)->size(); 
+ } + + template<> + size_t Quad4i::Type::sizeTotal(const char* This) const { + return 4; + } + + template<> + size_t Quad4i::Type::getBytes(const char* This) const { + return sizeof(Quad4i); + } + + /********************** SubdivPatch1 **************************/ + + const char* SubdivPatch1::Type::name () const { + return "subdivpatch1"; + } + + size_t SubdivPatch1::Type::sizeActive(const char* This) const { + return 1; + } + + size_t SubdivPatch1::Type::sizeTotal(const char* This) const { + return 1; + } + + size_t SubdivPatch1::Type::getBytes(const char* This) const { + return sizeof(SubdivPatch1); + } + + SubdivPatch1::Type SubdivPatch1::type; + + /********************** Virtual Object **************************/ + + const char* Object::Type::name () const { + return "object"; + } + + size_t Object::Type::sizeActive(const char* This) const { + return 1; + } + + size_t Object::Type::sizeTotal(const char* This) const { + return 1; + } + + size_t Object::Type::getBytes(const char* This) const { + return sizeof(Object); + } + + Object::Type Object::type; + + /********************** Instance **************************/ + + const char* InstancePrimitive::Type::name () const { + return "instance"; + } + + size_t InstancePrimitive::Type::sizeActive(const char* This) const { + return 1; + } + + size_t InstancePrimitive::Type::sizeTotal(const char* This) const { + return 1; + } + + size_t InstancePrimitive::Type::getBytes(const char* This) const { + return sizeof(InstancePrimitive); + } + + InstancePrimitive::Type InstancePrimitive::type; + + /********************** SubGrid **************************/ + + const char* SubGrid::Type::name () const { + return "subgrid"; + } + + size_t SubGrid::Type::sizeActive(const char* This) const { + return 1; + } + + size_t SubGrid::Type::sizeTotal(const char* This) const { + return 1; + } + + size_t SubGrid::Type::getBytes(const char* This) const { + return sizeof(SubGrid); + } + + SubGrid::Type SubGrid::type; + + /********************** SubGridQBVH4 **************************/ + + template<> + const char* SubGridQBVH4::Type::name () const { + return "SubGridQBVH4"; + } + + template<> + size_t SubGridQBVH4::Type::sizeActive(const char* This) const { + return 1; + } + + template<> + size_t SubGridQBVH4::Type::sizeTotal(const char* This) const { + return 1; + } + + template<> + size_t SubGridQBVH4::Type::getBytes(const char* This) const { + return sizeof(SubGridQBVH4); + } +} diff --git a/thirdparty/embree/kernels/geometry/quad_intersector.h b/thirdparty/embree/kernels/geometry/quad_intersector.h new file mode 100644 index 000000000000..57ff4e60e5f5 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/quad_intersector.h @@ -0,0 +1,76 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + namespace isa + { + /*! Intersects a ray with a quad with backface culling + * enabled. The quad v0,v1,v2,v3 is split into two triangles + * v0,v1,v3 and v2,v3,v1. The edge v1,v2 decides which of the two + * triangles gets intersected. 
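+ * (Added illustrative note: in the code below, WW = dot(cross(v3 - org, v1 - org), D)
+ * measures on which side of the plane spanned by the ray origin and the quad
+ * diagonal v1-v3 the ray direction points; WW <= 0 selects triangle v0,v1,v3,
+ * otherwise v2,v3,v1, and the u,v of the second triangle are mirrored to
+ * 1-u, 1-v at the end.)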
*/ + template + __forceinline vbool intersect_quad_backface_culling(const vbool& valid0, + const Vec3fa& ray_org, + const Vec3fa& ray_dir, + const float ray_tnear, + const float ray_tfar, + const Vec3vf& quad_v0, + const Vec3vf& quad_v1, + const Vec3vf& quad_v2, + const Vec3vf& quad_v3, + vfloat& u_o, + vfloat& v_o, + vfloat& t_o) + { + /* calculate vertices relative to ray origin */ + vbool valid = valid0; + const Vec3vf O = Vec3vf(ray_org); + const Vec3vf D = Vec3vf(ray_dir); + const Vec3vf va = quad_v0-O; + const Vec3vf vb = quad_v1-O; + const Vec3vf vc = quad_v2-O; + const Vec3vf vd = quad_v3-O; + + const Vec3vf edb = vb-vd; + const vfloat WW = dot(cross(vd,edb),D); + const Vec3vf v0 = select(WW <= 0.0f,va,vc); + const Vec3vf v1 = select(WW <= 0.0f,vb,vd); + const Vec3vf v2 = select(WW <= 0.0f,vd,vb); + + /* calculate edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + + /* perform edge tests */ + const vfloat U = dot(cross(v0,e0),D); + const vfloat V = dot(cross(v1,e1),D); + valid &= max(U,V) <= 0.0f; + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = cross(e1,e0); + const vfloat den = dot(Ng,D); + const vfloat rcpDen = rcp(den); + + /* perform depth test */ + const vfloat t = rcpDen*dot(v0,Ng); + valid &= vfloat(ray_tnear) <= t & t <= vfloat(ray_tfar); + if (unlikely(none(valid))) return false; + + /* avoid division by 0 */ + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + t_o = t; + u_o = U * rcpDen; + v_o = V * rcpDen; + u_o = select(WW <= 0.0f,u_o,1.0f-u_o); + v_o = select(WW <= 0.0f,v_o,1.0f-v_o); + return valid; + } + } +} diff --git a/thirdparty/embree/kernels/geometry/quad_intersector_moeller.h b/thirdparty/embree/kernels/geometry/quad_intersector_moeller.h new file mode 100644 index 000000000000..74e8c7720c4b --- /dev/null +++ b/thirdparty/embree/kernels/geometry/quad_intersector_moeller.h @@ -0,0 +1,566 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "quadv.h" +#include "triangle_intersector_moeller.h" + +namespace embree +{ + namespace isa + { + template + struct QuadHitM + { + __forceinline QuadHitM() {} + + __forceinline QuadHitM(const vbool& valid, + const vfloat& U, + const vfloat& V, + const vfloat& T, + const vfloat& absDen, + const Vec3vf& Ng, + const vbool& flags) + : U(U), V(V), T(T), absDen(absDen), tri_Ng(Ng), valid(valid), flags(flags) {} + + __forceinline void finalize() + { + const vfloat rcpAbsDen = rcp(absDen); + vt = T * rcpAbsDen; + const vfloat u = min(U * rcpAbsDen,1.0f); + const vfloat v = min(V * rcpAbsDen,1.0f); + const vfloat u1 = vfloat(1.0f) - u; + const vfloat v1 = vfloat(1.0f) - v; +#if !defined(__AVX__) || defined(EMBREE_BACKFACE_CULLING) + vu = select(flags,u1,u); + vv = select(flags,v1,v); + vNg = Vec3vf(tri_Ng.x,tri_Ng.y,tri_Ng.z); +#else + const vfloat flip = select(flags,vfloat(-1.0f),vfloat(1.0f)); + vv = select(flags,u1,v); + vu = select(flags,v1,u); + vNg = Vec3vf(flip*tri_Ng.x,flip*tri_Ng.y,flip*tri_Ng.z); +#endif + } + + __forceinline Vec2f uv(const size_t i) + { + const float u = vu[i]; + const float v = vv[i]; + return Vec2f(u,v); + } + + __forceinline float t(const size_t i) { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + private: + vfloat U; + vfloat V; + vfloat T; + vfloat absDen; + Vec3vf tri_Ng; + + public: + vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf 
vNg; + + public: + const vbool flags; + }; + + template + struct QuadHitK + { + __forceinline QuadHitK(const vfloat& U, + const vfloat& V, + const vfloat& T, + const vfloat& absDen, + const Vec3vf& Ng, + const vbool& flags) + : U(U), V(V), T(T), absDen(absDen), flags(flags), tri_Ng(Ng) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vfloat rcpAbsDen = rcp(absDen); + const vfloat t = T * rcpAbsDen; + const vfloat u0 = min(U * rcpAbsDen,1.0f); + const vfloat v0 = min(V * rcpAbsDen,1.0f); + const vfloat u1 = vfloat(1.0f) - u0; + const vfloat v1 = vfloat(1.0f) - v0; + const vfloat u = select(flags,u1,u0); + const vfloat v = select(flags,v1,v0); + const Vec3vf Ng(tri_Ng.x,tri_Ng.y,tri_Ng.z); + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat T; + const vfloat absDen; + const vbool flags; + const Vec3vf tri_Ng; + }; + + /* ----------------------------- */ + /* -- single ray intersectors -- */ + /* ----------------------------- */ + + + template + struct QuadMIntersector1MoellerTrumbore; + + /*! Intersects M quads with 1 ray */ + template + struct QuadMIntersector1MoellerTrumbore + { + __forceinline QuadMIntersector1MoellerTrumbore() {} + + __forceinline QuadMIntersector1MoellerTrumbore(const Ray& ray, const void* ptr) {} + + __forceinline void intersect(RayHit& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + MoellerTrumboreHitM hit; + MoellerTrumboreIntersector1 intersector(ray,nullptr); + Intersect1EpilogM epilog(ray,context,geomID,primID); + + /* intersect first triangle */ + if (intersector.intersect(ray,v0,v1,v3,hit)) + epilog(hit.valid,hit); + + /* intersect second triangle */ + if (intersector.intersect(ray,v2,v3,v1,hit)) + { + hit.U = hit.absDen - hit.U; + hit.V = hit.absDen - hit.V; + epilog(hit.valid,hit); + } + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + MoellerTrumboreHitM hit; + MoellerTrumboreIntersector1 intersector(ray,nullptr); + Occluded1EpilogM epilog(ray,context,geomID,primID); + + /* intersect first triangle */ + if (intersector.intersect(ray,v0,v1,v3,hit)) + { + if (epilog(hit.valid,hit)) + return true; + } + + /* intersect second triangle */ + if (intersector.intersect(ray,v2,v3,v1,hit)) + { + hit.U = hit.absDen - hit.U; + hit.V = hit.absDen - hit.V; + if (epilog(hit.valid,hit)) + return true; + } + return false; + } + }; + +#if defined(__AVX512ER__) // KNL + + /*! 
Intersects 4 quads with 1 ray using AVX512 */ + template + struct QuadMIntersector1MoellerTrumbore<4,filter> + { + __forceinline QuadMIntersector1MoellerTrumbore() {} + + __forceinline QuadMIntersector1MoellerTrumbore(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf16 vtx0(select(0x0f0f,vfloat16(v0.x),vfloat16(v2.x)), + select(0x0f0f,vfloat16(v0.y),vfloat16(v2.y)), + select(0x0f0f,vfloat16(v0.z),vfloat16(v2.z))); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf16 vtx1(vfloat16(v1.x),vfloat16(v1.y),vfloat16(v1.z)); + const Vec3vf16 vtx2(vfloat16(v3.x),vfloat16(v3.y),vfloat16(v3.z)); +#else + const Vec3vf16 vtx1(select(0x0f0f,vfloat16(v1.x),vfloat16(v3.x)), + select(0x0f0f,vfloat16(v1.y),vfloat16(v3.y)), + select(0x0f0f,vfloat16(v1.z),vfloat16(v3.z))); + const Vec3vf16 vtx2(select(0x0f0f,vfloat16(v3.x),vfloat16(v1.x)), + select(0x0f0f,vfloat16(v3.y),vfloat16(v1.y)), + select(0x0f0f,vfloat16(v3.z),vfloat16(v1.z))); +#endif + const vbool16 flags(0xf0f0); + + MoellerTrumboreHitM<16> hit; + MoellerTrumboreIntersector1<16> intersector(ray,nullptr); + if (unlikely(intersector.intersect(ray,vtx0,vtx1,vtx2,hit))) + { + vfloat16 U = hit.U, V = hit.V, absDen = hit.absDen; +#if !defined(EMBREE_BACKFACE_CULLING) + hit.U = select(flags,absDen-V,U); + hit.V = select(flags,absDen-U,V); + hit.vNg *= select(flags,vfloat16(-1.0f),vfloat16(1.0f)); // FIXME: use XOR +#else + hit.U = select(flags,absDen-U,U); + hit.V = select(flags,absDen-V,V); +#endif + if (likely(epilog(hit.valid,hit))) + return true; + } + return false; + } + + __forceinline bool intersect(RayHit& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Intersect1EpilogM<8,16,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Occluded1EpilogM<8,16,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + }; + +#elif defined(__AVX__) + + /*! 
Intersects 4 quads with 1 ray using AVX */ + template + struct QuadMIntersector1MoellerTrumbore<4,filter> + { + __forceinline QuadMIntersector1MoellerTrumbore() {} + + __forceinline QuadMIntersector1MoellerTrumbore(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + MoellerTrumboreHitM<8> hit; + MoellerTrumboreIntersector1<8> intersector(ray,nullptr); + const vbool8 flags(0,0,0,0,1,1,1,1); + if (unlikely(intersector.intersect(ray,vtx0,vtx1,vtx2,hit))) + { + vfloat8 U = hit.U, V = hit.V, absDen = hit.absDen; + +#if !defined(EMBREE_BACKFACE_CULLING) + hit.U = select(flags,absDen-V,U); + hit.V = select(flags,absDen-U,V); + hit.vNg *= select(flags,vfloat8(-1.0f),vfloat8(1.0f)); // FIXME: use XOR +#else + hit.U = select(flags,absDen-U,U); + hit.V = select(flags,absDen-V,V); +#endif + if (unlikely(epilog(hit.valid,hit))) + return true; + } + return false; + } + + __forceinline bool intersect(RayHit& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Intersect1EpilogM<8,8,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Occluded1EpilogM<8,8,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + }; + +#endif + + /* ----------------------------- */ + /* -- ray packet intersectors -- */ + /* ----------------------------- */ + + + struct MoellerTrumboreIntersector1KTriangleM + { + /*! Intersect k'th ray from ray packet of size K with M triangles. 
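+ * (Added note on the convention used below: U, V and T are kept scaled by
+ * absDen = |dot(Ng,D)| and sign-corrected with sgnDen, so the edge tests become
+ * U >= 0, V >= 0, U+V <= absDen and the depth test absDen*tnear < T <= absDen*tfar;
+ * the unscaled barycentrics and hit distance are recovered later via rcp(absDen)
+ * in QuadHitM::finalize / QuadHitK::operator().)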
*/ + template + static __forceinline bool intersect(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + const vbool& flags, + const Epilog& epilog) + { + /* calculate denominator */ + const Vec3vf O = broadcast>(ray.org,k); + const Vec3vf D = broadcast>(ray.dir,k); + const Vec3vf C = Vec3vf(tri_v0) - O; + const Vec3vf R = cross(C,D); + const vfloat den = dot(Vec3vf(tri_Ng),D); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* perform edge tests */ + const vfloat U = dot(R,Vec3vf(tri_e2)) ^ sgnDen; + const vfloat V = dot(R,Vec3vf(tri_e1)) ^ sgnDen; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = (den < vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#else + vbool valid = (den != vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#endif + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(Vec3vf(tri_Ng),C) ^ sgnDen; + valid &= (absDen*vfloat(ray.tnear()[k]) < T) & (T <= absDen*vfloat(ray.tfar[k])); + if (likely(none(valid))) return false; + + /* calculate hit information */ + QuadHitM hit(valid,U,V,T,absDen,tri_Ng,flags); + return epilog(valid,hit); + } + + template + static __forceinline bool intersect1(RayK& ray, + size_t k, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const vbool& flags, + const Epilog& epilog) + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + const Vec3vf Ng = cross(e2,e1); + return intersect(ray,k,v0,e1,e2,Ng,flags,epilog); + } + }; + + template + struct QuadMIntersectorKMoellerTrumboreBase + { + __forceinline QuadMIntersectorKMoellerTrumboreBase(const vbool& valid, const RayK& ray) {} + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + const vbool& flags, + const Epilog& epilog) const + { + /* calculate denominator */ + vbool valid = valid0; + const Vec3vf C = tri_v0 - ray.org; + const Vec3vf R = cross(C,ray.dir); + const vfloat den = dot(tri_Ng,ray.dir); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* test against edge p2 p0 */ + const vfloat U = dot(R,tri_e2) ^ sgnDen; + valid &= U >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p0 p1 */ + const vfloat V = dot(R,tri_e1) ^ sgnDen; + valid &= V >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p1 p2 */ + const vfloat W = absDen-U-V; + valid &= W >= 0.0f; + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(tri_Ng,C) ^ sgnDen; + valid &= (absDen*ray.tnear() < T) & (T <= absDen*ray.tfar); + if (unlikely(none(valid))) return false; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + valid &= den < vfloat(zero); + if (unlikely(none(valid))) return false; +#else + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; +#endif + + /* calculate hit information */ + QuadHitK hit(U,V,T,absDen,tri_Ng,flags); + return epilog(valid,hit); + } + + /*! Intersects K rays with one of M quads. 
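        A quad (v0,v1,v2,v3) is handled as the two triangles (v0,v1,v3) and (v2,v3,v1); the
        'flags' mask records which lanes belong to the second triangle so that the hit classes
        can remap its local (u,v) to (1-u,1-v) and report coordinates in the parameterization
        of the whole quad (see the pair of intersectK calls further below).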
*/ + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const vbool& flags, + const Epilog& epilog) const + { + const Vec3vf e1 = tri_v0-tri_v1; + const Vec3vf e2 = tri_v2-tri_v0; + const Vec3vf Ng = cross(e2,e1); + return intersectK(valid0,ray,tri_v0,e1,e2,Ng,flags,epilog); + } + + /*! Intersects K rays with one of M quads. */ + template + __forceinline bool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Vec3vf& v3, + const Epilog& epilog) const + { + intersectK(valid0,ray,v0,v1,v3,vbool(false),epilog); + if (none(valid0)) return true; + intersectK(valid0,ray,v2,v3,v1,vbool(true ),epilog); + return none(valid0); + } + }; + + template + struct QuadMIntersectorKMoellerTrumbore : public QuadMIntersectorKMoellerTrumboreBase + { + __forceinline QuadMIntersectorKMoellerTrumbore(const vbool& valid, const RayK& ray) + : QuadMIntersectorKMoellerTrumboreBase(valid,ray) {} + + __forceinline void intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + Intersect1KEpilogM epilog(ray,k,context,geomID,primID); + MoellerTrumboreIntersector1KTriangleM::intersect1(ray,k,v0,v1,v3,vbool(false),epilog); + MoellerTrumboreIntersector1KTriangleM::intersect1(ray,k,v2,v3,v1,vbool(true ),epilog); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + Occluded1KEpilogM epilog(ray,k,context,geomID,primID); + if (MoellerTrumboreIntersector1KTriangleM::intersect1(ray,k,v0,v1,v3,vbool(false),epilog)) return true; + if (MoellerTrumboreIntersector1KTriangleM::intersect1(ray,k,v2,v3,v1,vbool(true ),epilog)) return true; + return false; + } + }; + + +#if defined(__AVX512ER__) // KNL + + /*! 
Intersects 4 quads with 1 ray using AVX512 */ + template + struct QuadMIntersectorKMoellerTrumbore<4,K,filter> : public QuadMIntersectorKMoellerTrumboreBase<4,K,filter> + { + __forceinline QuadMIntersectorKMoellerTrumbore(const vbool& valid, const RayK& ray) + : QuadMIntersectorKMoellerTrumboreBase<4,K,filter>(valid,ray) {} + + template + __forceinline bool intersect1(RayK& ray, size_t k, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf16 vtx0(select(0x0f0f,vfloat16(v0.x),vfloat16(v2.x)), + select(0x0f0f,vfloat16(v0.y),vfloat16(v2.y)), + select(0x0f0f,vfloat16(v0.z),vfloat16(v2.z))); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf16 vtx1(vfloat16(v1.x),vfloat16(v1.y),vfloat16(v1.z)); + const Vec3vf16 vtx2(vfloat16(v3.x),vfloat16(v3.y),vfloat16(v3.z)); +#else + const Vec3vf16 vtx1(select(0x0f0f,vfloat16(v1.x),vfloat16(v3.x)), + select(0x0f0f,vfloat16(v1.y),vfloat16(v3.y)), + select(0x0f0f,vfloat16(v1.z),vfloat16(v3.z))); + const Vec3vf16 vtx2(select(0x0f0f,vfloat16(v3.x),vfloat16(v1.x)), + select(0x0f0f,vfloat16(v3.y),vfloat16(v1.y)), + select(0x0f0f,vfloat16(v3.z),vfloat16(v1.z))); +#endif + const vbool16 flags(0xf0f0); + return MoellerTrumboreIntersector1KTriangleM::intersect1(ray,k,vtx0,vtx1,vtx2,flags,epilog); + } + + __forceinline bool intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Intersect1KEpilogM<8,16,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Occluded1KEpilogM<8,16,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + }; + +#elif defined(__AVX__) + + /*! 
Intersects 4 quads with 1 ray using AVX */ + template + struct QuadMIntersectorKMoellerTrumbore<4,K,filter> : public QuadMIntersectorKMoellerTrumboreBase<4,K,filter> + { + __forceinline QuadMIntersectorKMoellerTrumbore(const vbool& valid, const RayK& ray) + : QuadMIntersectorKMoellerTrumboreBase<4,K,filter>(valid,ray) {} + + template + __forceinline bool intersect1(RayK& ray, size_t k, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + const vbool8 flags(0,0,0,0,1,1,1,1); + return MoellerTrumboreIntersector1KTriangleM::intersect1(ray,k,vtx0,vtx1,vtx2,flags,epilog); + } + + __forceinline bool intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Intersect1KEpilogM<8,8,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Occluded1KEpilogM<8,8,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + }; + +#endif + } +} diff --git a/thirdparty/embree/kernels/geometry/quad_intersector_pluecker.h b/thirdparty/embree/kernels/geometry/quad_intersector_pluecker.h new file mode 100644 index 000000000000..7ca3aed0a040 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/quad_intersector_pluecker.h @@ -0,0 +1,529 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "quad_intersector_moeller.h" + +/*! Modified Pluecker ray/triangle intersector. The test first shifts + * the ray origin into the origin of the coordinate system and then + * uses Pluecker coordinates for the intersection. Due to the shift, + * the Pluecker coordinate calculation simplifies and the tests get + * numerically stable. The edge equations are watertight along the + * edge for neighboring triangles. 
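   After the shift the three edge tests reduce to the signed volumes
     U = dot(cross(e0, v2+v0), D),  V = dot(cross(e1, v0+v1), D),  W = dot(cross(e2, v1+v2), D),
   and a hit is accepted when U, V and W agree in sign (up to a few ulps of |U+V+W|). The
   barycentric coordinates follow as u = U/(U+V+W), v = V/(U+V+W), and the distance as
   t = 2*dot(v0,Ng) / (2*dot(Ng,D)), as in PlueckerIntersectorTriangle1 below. For intuition,
   a minimal scalar sketch under the same conventions (Vec3, dot and cross are assumed
   stand-ins for the SIMD types; the epsilon handling and the watertight
   stable_triangle_normal of the real code are simplified away):

     bool intersect_pluecker(Vec3 O, Vec3 D, Vec3 a, Vec3 b, Vec3 c,
                             float& t, float& u, float& v)
     {
       Vec3 v0 = a - O, v1 = b - O, v2 = c - O;          // shift ray origin into the origin
       Vec3 e0 = v2 - v0, e1 = v0 - v1, e2 = v1 - v2;    // triangle edges
       float U = dot(cross(e0, v2 + v0), D);
       float V = dot(cross(e1, v0 + v1), D);
       float W = dot(cross(e2, v1 + v2), D);
       float UVW = U + V + W;
       bool inside = (U >= 0 && V >= 0 && W >= 0) || (U <= 0 && V <= 0 && W <= 0);
       if (!inside || UVW == 0.0f) return false;
       Vec3  Ng  = cross(e0, e1);                        // un-normalized geometric normal
       float den = 2.0f * dot(Ng, D);
       if (den == 0.0f) return false;
       t = 2.0f * dot(v0, Ng) / den;
       u = U / UVW;
       v = V / UVW;
       return true;
     }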
*/ + +namespace embree +{ + namespace isa + { + template + struct QuadHitPlueckerM + { + __forceinline QuadHitPlueckerM() {} + + __forceinline QuadHitPlueckerM(const vbool& valid, + const vfloat& U, + const vfloat& V, + const vfloat& UVW, + const vfloat& t, + const Vec3vf& Ng, + const vbool& flags) + : U(U), V(V), UVW(UVW), tri_Ng(Ng), valid(valid), vt(t), flags(flags) {} + + __forceinline void finalize() + { + const vbool invalid = abs(UVW) < min_rcp_input; + const vfloat rcpUVW = select(invalid,vfloat(0.0f),rcp(UVW)); + const vfloat u = min(U * rcpUVW,1.0f); + const vfloat v = min(V * rcpUVW,1.0f); + const vfloat u1 = vfloat(1.0f) - u; + const vfloat v1 = vfloat(1.0f) - v; +#if !defined(__AVX__) || defined(EMBREE_BACKFACE_CULLING) + vu = select(flags,u1,u); + vv = select(flags,v1,v); + vNg = Vec3vf(tri_Ng.x,tri_Ng.y,tri_Ng.z); +#else + const vfloat flip = select(flags,vfloat(-1.0f),vfloat(1.0f)); + vv = select(flags,u1,v); + vu = select(flags,v1,u); + vNg = Vec3vf(flip*tri_Ng.x,flip*tri_Ng.y,flip*tri_Ng.z); +#endif + } + + __forceinline Vec2f uv(const size_t i) + { + const float u = vu[i]; + const float v = vv[i]; + return Vec2f(u,v); + } + + __forceinline float t(const size_t i) { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + private: + vfloat U; + vfloat V; + vfloat UVW; + Vec3vf tri_Ng; + + public: + vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + + public: + const vbool flags; + }; + + template + struct QuadHitPlueckerK + { + __forceinline QuadHitPlueckerK(const vfloat& U, + const vfloat& V, + const vfloat& UVW, + const vfloat& t, + const Vec3vf& Ng, + const vbool& flags) + : U(U), V(V), UVW(UVW), t(t), flags(flags), tri_Ng(Ng) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vbool invalid = abs(UVW) < min_rcp_input; + const vfloat rcpUVW = select(invalid,vfloat(0.0f),rcp(UVW)); + const vfloat u0 = min(U * rcpUVW,1.0f); + const vfloat v0 = min(V * rcpUVW,1.0f); + const vfloat u1 = vfloat(1.0f) - u0; + const vfloat v1 = vfloat(1.0f) - v0; + const vfloat u = select(flags,u1,u0); + const vfloat v = select(flags,v1,v0); + const Vec3vf Ng(tri_Ng.x,tri_Ng.y,tri_Ng.z); + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat UVW; + const vfloat t; + const vbool flags; + const Vec3vf tri_Ng; + }; + + struct PlueckerIntersectorTriangle1 + { + template + static __forceinline bool intersect(Ray& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const vbool& flags, + const Epilog& epilog) + { + /* calculate vertices relative to ray origin */ + const Vec3vf O = Vec3vf((Vec3fa)ray.org); + const Vec3vf D = Vec3vf((Vec3fa)ray.dir); + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(cross(e0,v2+v0),D); + const vfloat V = dot(cross(e1,v0+v1),D); + const vfloat W = dot(cross(e2,v1+v2),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = max(U,V,W) <= eps; +#else + vbool valid = (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Ng,D)); + + /* perform depth 
test */ + const vfloat T = twice(dot(v0,Ng)); + const vfloat t = rcp(den)*T; + valid &= vfloat(ray.tnear()) <= t & t <= vfloat(ray.tfar); + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + QuadHitPlueckerM hit(valid,U,V,UVW,t,Ng,flags); + return epilog(valid,hit); + } + }; + + /*! Intersects M quads with 1 ray */ + template + struct QuadMIntersector1Pluecker + { + __forceinline QuadMIntersector1Pluecker() {} + + __forceinline QuadMIntersector1Pluecker(const Ray& ray, const void* ptr) {} + + __forceinline void intersect(RayHit& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + Intersect1EpilogM epilog(ray,context,geomID,primID); + PlueckerIntersectorTriangle1::intersect(ray,v0,v1,v3,vbool(false),epilog); + PlueckerIntersectorTriangle1::intersect(ray,v2,v3,v1,vbool(true),epilog); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + Occluded1EpilogM epilog(ray,context,geomID,primID); + if (PlueckerIntersectorTriangle1::intersect(ray,v0,v1,v3,vbool(false),epilog)) return true; + if (PlueckerIntersectorTriangle1::intersect(ray,v2,v3,v1,vbool(true ),epilog)) return true; + return false; + } + }; + +#if defined(__AVX512ER__) // KNL + + /*! Intersects 4 quads with 1 ray using AVX512 */ + template + struct QuadMIntersector1Pluecker<4,filter> + { + __forceinline QuadMIntersector1Pluecker() {} + + __forceinline QuadMIntersector1Pluecker(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf16 vtx0(select(0x0f0f,vfloat16(v0.x),vfloat16(v2.x)), + select(0x0f0f,vfloat16(v0.y),vfloat16(v2.y)), + select(0x0f0f,vfloat16(v0.z),vfloat16(v2.z))); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf16 vtx1(vfloat16(v1.x),vfloat16(v1.y),vfloat16(v1.z)); + const Vec3vf16 vtx2(vfloat16(v3.x),vfloat16(v3.y),vfloat16(v3.z)); +#else + const Vec3vf16 vtx1(select(0x0f0f,vfloat16(v1.x),vfloat16(v3.x)), + select(0x0f0f,vfloat16(v1.y),vfloat16(v3.y)), + select(0x0f0f,vfloat16(v1.z),vfloat16(v3.z))); + const Vec3vf16 vtx2(select(0x0f0f,vfloat16(v3.x),vfloat16(v1.x)), + select(0x0f0f,vfloat16(v3.y),vfloat16(v1.y)), + select(0x0f0f,vfloat16(v3.z),vfloat16(v1.z))); +#endif + const vbool16 flags(0xf0f0); + return PlueckerIntersectorTriangle1::intersect(ray,vtx0,vtx1,vtx2,flags,epilog); + } + + __forceinline bool intersect(RayHit& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Intersect1EpilogM<8,16,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Occluded1EpilogM<8,16,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + }; + +#elif defined(__AVX__) + + /*! 
Intersects 4 quads with 1 ray using AVX */ + template + struct QuadMIntersector1Pluecker<4,filter> + { + __forceinline QuadMIntersector1Pluecker() {} + + __forceinline QuadMIntersector1Pluecker(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + const vbool8 flags(0,0,0,0,1,1,1,1); + return PlueckerIntersectorTriangle1::intersect(ray,vtx0,vtx1,vtx2,flags,epilog); + } + + __forceinline bool intersect(RayHit& ray, IntersectContext* context, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Intersect1EpilogM<8,8,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect(ray,v0,v1,v2,v3,Occluded1EpilogM<8,8,filter>(ray,context,vuint8(geomID),vuint8(primID))); + } + }; + +#endif + + + /* ----------------------------- */ + /* -- ray packet intersectors -- */ + /* ----------------------------- */ + + struct PlueckerIntersector1KTriangleM + { + /*! Intersect k'th ray from ray packet of size K with M triangles. */ + template + static __forceinline bool intersect1(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const vbool& flags, + const Epilog& epilog) + { + /* calculate vertices relative to ray origin */ + const Vec3vf O = broadcast>(ray.org,k); + const Vec3vf D = broadcast>(ray.dir,k); + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(cross(e0,v2+v0),D); + const vfloat V = dot(cross(e1,v0+v1),D); + const vfloat W = dot(cross(e2,v1+v2),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = max(U,V,W) <= eps; +#else + vbool valid = (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Ng,D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Ng)); + const vfloat t = rcp(den)*T; + valid &= vfloat(ray.tnear()[k]) <= t & t <= vfloat(ray.tfar[k]); + if (unlikely(none(valid))) return false; + + /* avoid division by 0 */ + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + QuadHitPlueckerM hit(valid,U,V,UVW,t,Ng,flags); + return epilog(valid,hit); + } + }; + + template + struct QuadMIntersectorKPlueckerBase + { + __forceinline QuadMIntersectorKPlueckerBase(const vbool& valid, const RayK& ray) {} + + /*! Intersects K rays with one of M triangles. 
*/ + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const vbool& flags, + const Epilog& epilog) const + { + /* calculate vertices relative to ray origin */ + vbool valid = valid0; + const Vec3vf O = ray.org; + const Vec3vf D = ray.dir; + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(Vec3vf(cross(e0,v2+v0)),D); + const vfloat V = dot(Vec3vf(cross(e1,v0+v1)),D); + const vfloat W = dot(Vec3vf(cross(e2,v1+v2)),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + valid &= max(U,V,W) <= eps; +#else + valid &= (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Vec3vf(Ng),D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Vec3vf(Ng))); + const vfloat t = rcp(den)*T; + valid &= ray.tnear() <= t & t <= ray.tfar; + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* calculate hit information */ + QuadHitPlueckerK hit(U,V,UVW,t,Ng,flags); + return epilog(valid,hit); + } + + /*! Intersects K rays with one of M quads. */ + template + __forceinline bool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Vec3vf& v3, + const Epilog& epilog) const + { + intersectK(valid0,ray,v0,v1,v3,vbool(false),epilog); + if (none(valid0)) return true; + intersectK(valid0,ray,v2,v3,v1,vbool(true ),epilog); + return none(valid0); + } + }; + + template + struct QuadMIntersectorKPluecker : public QuadMIntersectorKPlueckerBase + { + __forceinline QuadMIntersectorKPluecker(const vbool& valid, const RayK& ray) + : QuadMIntersectorKPlueckerBase(valid,ray) {} + + __forceinline void intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + Intersect1KEpilogM epilog(ray,k,context,geomID,primID); + PlueckerIntersector1KTriangleM::intersect1(ray,k,v0,v1,v3,vbool(false),epilog); + PlueckerIntersector1KTriangleM::intersect1(ray,k,v2,v3,v1,vbool(true ),epilog); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const vuint& geomID, const vuint& primID) const + { + Occluded1KEpilogM epilog(ray,k,context,geomID,primID); + if (PlueckerIntersector1KTriangleM::intersect1(ray,k,v0,v1,v3,vbool(false),epilog)) return true; + if (PlueckerIntersector1KTriangleM::intersect1(ray,k,v2,v3,v1,vbool(true ),epilog)) return true; + return false; + } + }; + +#if defined(__AVX512ER__) // KNL + + /*! 
Intersects 4 quads with 1 ray using AVX512 */ + template + struct QuadMIntersectorKPluecker<4,K,filter> : public QuadMIntersectorKPlueckerBase<4,K,filter> + { + __forceinline QuadMIntersectorKPluecker(const vbool& valid, const RayK& ray) + : QuadMIntersectorKPlueckerBase<4,K,filter>(valid,ray) {} + + template + __forceinline bool intersect1(RayK& ray, size_t k, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf16 vtx0(select(0x0f0f,vfloat16(v0.x),vfloat16(v2.x)), + select(0x0f0f,vfloat16(v0.y),vfloat16(v2.y)), + select(0x0f0f,vfloat16(v0.z),vfloat16(v2.z))); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf16 vtx1(vfloat16(v1.x),vfloat16(v1.y),vfloat16(v1.z)); + const Vec3vf16 vtx2(vfloat16(v3.x),vfloat16(v3.y),vfloat16(v3.z)); +#else + const Vec3vf16 vtx1(select(0x0f0f,vfloat16(v1.x),vfloat16(v3.x)), + select(0x0f0f,vfloat16(v1.y),vfloat16(v3.y)), + select(0x0f0f,vfloat16(v1.z),vfloat16(v3.z))); + const Vec3vf16 vtx2(select(0x0f0f,vfloat16(v3.x),vfloat16(v1.x)), + select(0x0f0f,vfloat16(v3.y),vfloat16(v1.y)), + select(0x0f0f,vfloat16(v3.z),vfloat16(v1.z))); +#endif + + const vbool16 flags(0xf0f0); + return PlueckerIntersector1KTriangleM::intersect1(ray,k,vtx0,vtx1,vtx2,flags,epilog); + } + + __forceinline bool intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Intersect1KEpilogM<8,16,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Occluded1KEpilogM<8,16,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + }; + +#elif defined(__AVX__) + + /*! 
Intersects 4 quads with 1 ray using AVX */ + template + struct QuadMIntersectorKPluecker<4,K,filter> : public QuadMIntersectorKPlueckerBase<4,K,filter> + { + __forceinline QuadMIntersectorKPluecker(const vbool& valid, const RayK& ray) + : QuadMIntersectorKPlueckerBase<4,K,filter>(valid,ray) {} + + template + __forceinline bool intersect1(RayK& ray, size_t k, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); + const vbool8 flags(0,0,0,0,1,1,1,1); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + return PlueckerIntersector1KTriangleM::intersect1(ray,k,vtx0,vtx1,vtx2,flags,epilog); + } + + __forceinline bool intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Intersect1KEpilogM<8,8,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const vuint4& geomID, const vuint4& primID) const + { + return intersect1(ray,k,v0,v1,v2,v3,Occluded1KEpilogM<8,8,K,filter>(ray,k,context,vuint8(geomID),vuint8(primID))); + } + }; + +#endif + } +} diff --git a/thirdparty/embree/kernels/geometry/quadi.h b/thirdparty/embree/kernels/geometry/quadi.h new file mode 100644 index 000000000000..741ec519ab4c --- /dev/null +++ b/thirdparty/embree/kernels/geometry/quadi.h @@ -0,0 +1,483 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "../common/scene.h" + +namespace embree +{ + /* Stores M quads from an indexed face set */ + template + struct QuadMi + { + /* Virtual interface to query information about the quad type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored quads */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline QuadMi() { } + + /* Construction from vertices and IDs */ + __forceinline QuadMi(const vuint& v0, + const vuint& v1, + const vuint& v2, + const vuint& v3, + const vuint& geomIDs, + const vuint& primIDs) +#if defined(EMBREE_COMPACT_POLYS) + : geomIDs(geomIDs), primIDs(primIDs) {} +#else + : v0_(v0),v1_(v1), v2_(v2), v3_(v3), geomIDs(geomIDs), primIDs(primIDs) {} +#endif + + /* Returns a mask that tells which quads are valid */ + __forceinline vbool valid() const { return primIDs != vuint(-1); } + + /* Returns if the specified quad is valid */ + __forceinline bool valid(const size_t i) const { assert(i& 
geomID() { return geomIDs; } + __forceinline const vuint& geomID() const { return geomIDs; } + __forceinline unsigned int geomID(const size_t i) const { assert(i& primID() { return primIDs; } + __forceinline const vuint& primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(iget(geomID(i)); + bounds.extend(mesh->bounds(primID(i),itime)); + } + return bounds; + } + + /* Calculate the linear bounds of the primitive */ + __forceinline LBBox3fa linearBounds(const Scene* const scene, const size_t itime) { + return LBBox3fa(bounds(scene,itime+0),bounds(scene,itime+1)); + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, size_t itime, size_t numTimeSteps) + { + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID(i)); + allBounds.extend(mesh->linearBounds(primID(i), itime, numTimeSteps)); + } + return allBounds; + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, const BBox1f time_range) + { + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID(i)); + allBounds.extend(mesh->linearBounds(primID(i), time_range)); + } + return allBounds; + } + + /* Fill quad from quad list */ + template + __forceinline void fill(const PrimRefT* prims, size_t& begin, size_t end, Scene* scene) + { + vuint geomID = -1, primID = -1; + const PrimRefT* prim = &prims[begin]; + vuint v0 = zero, v1 = zero, v2 = zero, v3 = zero; + + for (size_t i=0; igeomID(); + primID[i] = prim->primID(); +#if !defined(EMBREE_COMPACT_POLYS) + const QuadMesh* mesh = scene->get(prim->geomID()); + const QuadMesh::Quad& q = mesh->quad(prim->primID()); + unsigned int_stride = mesh->vertices0.getStride()/4; + v0[i] = q.v[0] * int_stride; + v1[i] = q.v[1] * int_stride; + v2[i] = q.v[2] * int_stride; + v3[i] = q.v[3] * int_stride; +#endif + begin++; + } else { + assert(i); + if (likely(i > 0)) { + geomID[i] = geomID[0]; // always valid geomIDs + primID[i] = -1; // indicates invalid data + v0[i] = v0[0]; + v1[i] = v0[0]; + v2[i] = v0[0]; + v3[i] = v0[0]; + } + } + if (begin( " +#if !defined(EMBREE_COMPACT_POLYS) + << "v0 = " << quad.v0_ << ", v1 = " << quad.v1_ << ", v2 = " << quad.v2_ << ", v3 = " << quad.v3_ << ", " +#endif + << "geomID = " << quad.geomIDs << ", primID = " << quad.primIDs << " )"; + } + + protected: +#if !defined(EMBREE_COMPACT_POLYS) + vuint v0_; // 4 byte offset of 1st vertex + vuint v1_; // 4 byte offset of 2nd vertex + vuint v2_; // 4 byte offset of 3rd vertex + vuint v3_; // 4 byte offset of 4th vertex +#endif + vuint geomIDs; // geometry ID of mesh + vuint primIDs; // primitive ID of primitive inside mesh + }; + + namespace isa + { + + template + struct QuadMi : public embree::QuadMi + { +#if !defined(EMBREE_COMPACT_POLYS) + using embree::QuadMi::v0_; + using embree::QuadMi::v1_; + using embree::QuadMi::v2_; + using embree::QuadMi::v3_; +#endif + using embree::QuadMi::geomIDs; + using embree::QuadMi::primIDs; + using embree::QuadMi::geomID; + using embree::QuadMi::primID; + using embree::QuadMi::valid; + + template + __forceinline Vec3f getVertex(const size_t index, const Scene *const scene) const + { +#if defined(EMBREE_COMPACT_POLYS) + const QuadMesh* mesh = scene->get(geomID(index)); + const QuadMesh::Quad& quad = mesh->quad(primID(index)); + return (Vec3f) mesh->vertices[0][quad.v[vid]]; +#else + const vuint& v = getVertexOffset(); + const float* vertices = scene->vertices[geomID(index)]; + return (Vec3f&) vertices[v[index]]; +#endif + } + + template + __forceinline Vec3 getVertex(const size_t index, const Scene *const scene, 
const size_t itime, const T& ftime) const + { +#if defined(EMBREE_COMPACT_POLYS) + const QuadMesh* mesh = scene->get(geomID(index)); + const QuadMesh::Quad& quad = mesh->quad(primID(index)); + const Vec3fa v0 = mesh->vertices[itime+0][quad.v[vid]]; + const Vec3fa v1 = mesh->vertices[itime+1][quad.v[vid]]; +#else + const vuint& v = getVertexOffset(); + const QuadMesh* mesh = scene->get(geomID(index)); + const float* vertices0 = (const float*) mesh->vertexPtr(0,itime+0); + const float* vertices1 = (const float*) mesh->vertexPtr(0,itime+1); + const Vec3fa v0 = Vec3fa::loadu(vertices0+v[index]); + const Vec3fa v1 = Vec3fa::loadu(vertices1+v[index]); +#endif + const Vec3 p0(v0.x,v0.y,v0.z); + const Vec3 p1(v1.x,v1.y,v1.z); + return lerp(p0,p1,ftime); + } + + template + __forceinline Vec3 getVertex(const vbool& valid, const size_t index, const Scene *const scene, const vint& itime, const T& ftime) const + { + Vec3 p0, p1; + const QuadMesh* mesh = scene->get(geomID(index)); + + for (size_t mask=movemask(valid), i=bsf(mask); mask; mask=btc(mask,i), i=bsf(mask)) + { +#if defined(EMBREE_COMPACT_POLYS) + const QuadMesh::Quad& quad = mesh->quad(primID(index)); + const Vec3fa v0 = mesh->vertices[itime[i]+0][quad.v[vid]]; + const Vec3fa v1 = mesh->vertices[itime[i]+1][quad.v[vid]]; +#else + const vuint& v = getVertexOffset(); + const float* vertices0 = (const float*) mesh->vertexPtr(0,itime[i]+0); + const float* vertices1 = (const float*) mesh->vertexPtr(0,itime[i]+1); + const Vec3fa v0 = Vec3fa::loadu(vertices0+v[index]); + const Vec3fa v1 = Vec3fa::loadu(vertices1+v[index]); +#endif + p0.x[i] = v0.x; p0.y[i] = v0.y; p0.z[i] = v0.z; + p1.x[i] = v1.x; p1.y[i] = v1.y; p1.z[i] = v1.z; + } + return (T(one)-ftime)*p0 + ftime*p1; + } + + struct Quad { + vfloat4 v0,v1,v2,v3; + }; + +#if defined(EMBREE_COMPACT_POLYS) + + __forceinline Quad loadQuad(const int i, const Scene* const scene) const + { + const unsigned int geomID = geomIDs[i]; + const unsigned int primID = primIDs[i]; + if (unlikely(primID == -1)) return { zero, zero, zero, zero }; + const QuadMesh* mesh = scene->get(geomID); + const QuadMesh::Quad& quad = mesh->quad(primID); + const vfloat4 v0 = (vfloat4) mesh->vertices0[quad.v[0]]; + const vfloat4 v1 = (vfloat4) mesh->vertices0[quad.v[1]]; + const vfloat4 v2 = (vfloat4) mesh->vertices0[quad.v[2]]; + const vfloat4 v3 = (vfloat4) mesh->vertices0[quad.v[3]]; + return { v0, v1, v2, v3 }; + } + + __forceinline Quad loadQuad(const int i, const int itime, const Scene* const scene) const + { + const unsigned int geomID = geomIDs[i]; + const unsigned int primID = primIDs[i]; + if (unlikely(primID == -1)) return { zero, zero, zero, zero }; + const QuadMesh* mesh = scene->get(geomID); + const QuadMesh::Quad& quad = mesh->quad(primID); + const vfloat4 v0 = (vfloat4) mesh->vertices[itime][quad.v[0]]; + const vfloat4 v1 = (vfloat4) mesh->vertices[itime][quad.v[1]]; + const vfloat4 v2 = (vfloat4) mesh->vertices[itime][quad.v[2]]; + const vfloat4 v3 = (vfloat4) mesh->vertices[itime][quad.v[3]]; + return { v0, v1, v2, v3 }; + } + +#else + + __forceinline Quad loadQuad(const int i, const Scene* const scene) const + { + const float* vertices = scene->vertices[geomID(i)]; + const vfloat4 v0 = vfloat4::loadu(vertices + v0_[i]); + const vfloat4 v1 = vfloat4::loadu(vertices + v1_[i]); + const vfloat4 v2 = vfloat4::loadu(vertices + v2_[i]); + const vfloat4 v3 = vfloat4::loadu(vertices + v3_[i]); + return { v0, v1, v2, v3 }; + } + + __forceinline Quad loadQuad(const int i, const int itime, const Scene* const scene) const 
+ { + const unsigned int geomID = geomIDs[i]; + const QuadMesh* mesh = scene->get(geomID); + const float* vertices = (const float*) mesh->vertexPtr(0,itime); + const vfloat4 v0 = vfloat4::loadu(vertices + v0_[i]); + const vfloat4 v1 = vfloat4::loadu(vertices + v1_[i]); + const vfloat4 v2 = vfloat4::loadu(vertices + v2_[i]); + const vfloat4 v3 = vfloat4::loadu(vertices + v3_[i]); + return { v0, v1, v2, v3 }; + } + +#endif + + /* Gather the quads */ + __forceinline void gather(Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + Vec3vf& p3, + const Scene *const scene) const; + +#if defined(__AVX512F__) + __forceinline void gather(Vec3vf16& p0, + Vec3vf16& p1, + Vec3vf16& p2, + Vec3vf16& p3, + const Scene *const scene) const; +#endif + + template +#if defined(__INTEL_COMPILER) && (__INTEL_COMPILER < 2000) // workaround for compiler bug in ICC 2019 + __noinline +#else + __forceinline +#endif + void gather(const vbool& valid, + Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + Vec3vf& p3, + const size_t index, + const Scene* const scene, + const vfloat& time) const + { + const QuadMesh* mesh = scene->get(geomID(index)); + + vfloat ftime; + const vint itime = mesh->timeSegment(time, ftime); + + const size_t first = bsf(movemask(valid)); + if (likely(all(valid,itime[first] == itime))) + { + p0 = getVertex<0>(index, scene, itime[first], ftime); + p1 = getVertex<1>(index, scene, itime[first], ftime); + p2 = getVertex<2>(index, scene, itime[first], ftime); + p3 = getVertex<3>(index, scene, itime[first], ftime); + } + else + { + p0 = getVertex<0>(valid, index, scene, itime, ftime); + p1 = getVertex<1>(valid, index, scene, itime, ftime); + p2 = getVertex<2>(valid, index, scene, itime, ftime); + p3 = getVertex<3>(valid, index, scene, itime, ftime); + } + } + + __forceinline void gather(Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + Vec3vf& p3, + const QuadMesh* mesh, + const Scene *const scene, + const int itime) const; + + __forceinline void gather(Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + Vec3vf& p3, + const Scene *const scene, + const float time) const; + + /* Updates the primitive */ + __forceinline BBox3fa update(QuadMesh* mesh) + { + BBox3fa bounds = empty; + for (size_t i=0; iquad(primId); + const Vec3fa p0 = mesh->vertex(q.v[0]); + const Vec3fa p1 = mesh->vertex(q.v[1]); + const Vec3fa p2 = mesh->vertex(q.v[2]); + const Vec3fa p3 = mesh->vertex(q.v[3]); + bounds.extend(merge(BBox3fa(p0),BBox3fa(p1),BBox3fa(p2),BBox3fa(p3))); + } + return bounds; + } + + private: +#if !defined(EMBREE_COMPACT_POLYS) + template const vuint& getVertexOffset() const; +#endif + }; + +#if !defined(EMBREE_COMPACT_POLYS) + template<> template<> __forceinline const vuint<4>& QuadMi<4>::getVertexOffset<0>() const { return v0_; } + template<> template<> __forceinline const vuint<4>& QuadMi<4>::getVertexOffset<1>() const { return v1_; } + template<> template<> __forceinline const vuint<4>& QuadMi<4>::getVertexOffset<2>() const { return v2_; } + template<> template<> __forceinline const vuint<4>& QuadMi<4>::getVertexOffset<3>() const { return v3_; } +#endif + + template<> + __forceinline void QuadMi<4>::gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const Scene *const scene) const + { + prefetchL1(((char*)this)+0*64); + prefetchL1(((char*)this)+1*64); + const Quad tri0 = loadQuad(0,scene); + const Quad tri1 = loadQuad(1,scene); + const Quad tri2 = loadQuad(2,scene); + const Quad tri3 = loadQuad(3,scene); + transpose(tri0.v0,tri1.v0,tri2.v0,tri3.v0,p0.x,p0.y,p0.z); + transpose(tri0.v1,tri1.v1,tri2.v1,tri3.v1,p1.x,p1.y,p1.z); + 
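        // These transposes convert the AoS quad loads above (one vfloat4 per vertex per quad)
        // into SoA form: p0..p3 end up holding the x/y/z of the same vertex slot across all
        // four quads, which is the layout the 4-wide quad intersectors consume.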
transpose(tri0.v2,tri1.v2,tri2.v2,tri3.v2,p2.x,p2.y,p2.z); + transpose(tri0.v3,tri1.v3,tri2.v3,tri3.v3,p3.x,p3.y,p3.z); + } + + template<> + __forceinline void QuadMi<4>::gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const QuadMesh* mesh, + const Scene *const scene, + const int itime) const + { + // FIXME: for trianglei there all geometries are identical, is this the case here too? + + const Quad tri0 = loadQuad(0,itime,scene); + const Quad tri1 = loadQuad(1,itime,scene); + const Quad tri2 = loadQuad(2,itime,scene); + const Quad tri3 = loadQuad(3,itime,scene); + transpose(tri0.v0,tri1.v0,tri2.v0,tri3.v0,p0.x,p0.y,p0.z); + transpose(tri0.v1,tri1.v1,tri2.v1,tri3.v1,p1.x,p1.y,p1.z); + transpose(tri0.v2,tri1.v2,tri2.v2,tri3.v2,p2.x,p2.y,p2.z); + transpose(tri0.v3,tri1.v3,tri2.v3,tri3.v3,p3.x,p3.y,p3.z); + } + + template<> + __forceinline void QuadMi<4>::gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const Scene *const scene, + const float time) const + { + const QuadMesh* mesh = scene->get(geomID(0)); // in mblur mode all geometries are identical + + float ftime; + const int itime = mesh->timeSegment(time, ftime); + + Vec3vf4 a0,a1,a2,a3; gather(a0,a1,a2,a3,mesh,scene,itime); + Vec3vf4 b0,b1,b2,b3; gather(b0,b1,b2,b3,mesh,scene,itime+1); + p0 = lerp(a0,b0,vfloat4(ftime)); + p1 = lerp(a1,b1,vfloat4(ftime)); + p2 = lerp(a2,b2,vfloat4(ftime)); + p3 = lerp(a3,b3,vfloat4(ftime)); + } + } + + template + typename QuadMi::Type QuadMi::type; + + typedef QuadMi<4> Quad4i; +} diff --git a/thirdparty/embree/kernels/geometry/quadi_intersector.h b/thirdparty/embree/kernels/geometry/quadi_intersector.h new file mode 100644 index 000000000000..96cf7f1ca2d3 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/quadi_intersector.h @@ -0,0 +1,350 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "quadi.h" +#include "quad_intersector_moeller.h" +#include "quad_intersector_pluecker.h" + +namespace embree +{ + namespace isa + { + /*! Intersects M quads with 1 ray */ + template + struct QuadMiIntersector1Moeller + { + typedef QuadMi Primitive; + typedef QuadMIntersector1MoellerTrumbore Precalculations; + + /*! Intersect a ray with the M quads and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + pre.intersect(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of M quads. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + return pre.occluded(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& quad) + { + return PrimitivePointQuery1::pointQuery(query, context, quad); + } + }; + + /*! Intersects M triangles with K rays. */ + template + struct QuadMiIntersectorKMoeller + { + typedef QuadMi Primitive; + typedef QuadMIntersectorKMoellerTrumbore Precalculations; + + /*! Intersects K rays with M triangles. 
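        Each loop iteration below broadcasts one quad of the block: getVertex<0..3>(i,scene)
        fetches the i-th quad's four corners through the stored vertex offsets and splats them
        across the K ray lanes, so a single quad is tested against the whole ray packet at once.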
*/ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const QuadMi& quad) + { + Scene* scene = context->scene; + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf p0 = quad.template getVertex<0>(i,scene); + const Vec3vf p1 = quad.template getVertex<1>(i,scene); + const Vec3vf p2 = quad.template getVertex<2>(i,scene); + const Vec3vf p3 = quad.template getVertex<3>(i,scene); + pre.intersectK(valid_i,ray,p0,p1,p2,p3,IntersectKEpilogM(ray,context,quad.geomID(),quad.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const QuadMi& quad) + { + Scene* scene = context->scene; + vbool valid0 = valid_i; + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf p0 = quad.template getVertex<0>(i,scene); + const Vec3vf p1 = quad.template getVertex<1>(i,scene); + const Vec3vf p2 = quad.template getVertex<2>(i,scene); + const Vec3vf p3 = quad.template getVertex<3>(i,scene); + if (pre.intersectK(valid0,ray,p0,p1,p2,p3,OccludedKEpilogM(valid0,ray,context,quad.geomID(),quad.primID(),i))) + break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf4 v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + pre.intersect1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf4 v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + }; + + /*! Intersects M quads with 1 ray */ + template + struct QuadMiIntersector1Pluecker + { + typedef QuadMi Primitive; + typedef QuadMIntersector1Pluecker Precalculations; + + /*! Intersect a ray with the M quads and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + pre.intersect(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of M quads. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + return pre.occluded(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& quad) + { + return PrimitivePointQuery1::pointQuery(query, context, quad); + } + }; + + /*! Intersects M triangles with K rays. */ + template + struct QuadMiIntersectorKPluecker + { + typedef QuadMi Primitive; + typedef QuadMIntersectorKPluecker Precalculations; + + /*! Intersects K rays with M triangles. 
*/ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const QuadMi& quad) + { + Scene* scene = context->scene; + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf p0 = quad.template getVertex<0>(i,scene); + const Vec3vf p1 = quad.template getVertex<1>(i,scene); + const Vec3vf p2 = quad.template getVertex<2>(i,scene); + const Vec3vf p3 = quad.template getVertex<3>(i,scene); + pre.intersectK(valid_i,ray,p0,p1,p2,p3,IntersectKEpilogM(ray,context,quad.geomID(),quad.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const QuadMi& quad) + { + Scene* scene = context->scene; + vbool valid0 = valid_i; + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf p0 = quad.template getVertex<0>(i,scene); + const Vec3vf p1 = quad.template getVertex<1>(i,scene); + const Vec3vf p2 = quad.template getVertex<2>(i,scene); + const Vec3vf p3 = quad.template getVertex<3>(i,scene); + if (pre.intersectK(valid0,ray,p0,p1,p2,p3,OccludedKEpilogM(valid0,ray,context,quad.geomID(),quad.primID(),i))) + break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf4 v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + pre.intersect1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf4 v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + }; + + /*! Intersects M motion blur quads with 1 ray */ + template + struct QuadMiMBIntersector1Moeller + { + typedef QuadMi Primitive; + typedef QuadMIntersector1MoellerTrumbore Precalculations; + + /*! Intersect a ray with the M quads and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()); + pre.intersect(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of M quads. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()); + return pre.occluded(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& quad) + { + return PrimitivePointQuery1::pointQuery(query, context, quad); + } + }; + + /*! Intersects M motion blur quads with K rays. */ + template + struct QuadMiMBIntersectorKMoeller + { + typedef QuadMi Primitive; + typedef QuadMIntersectorKMoellerTrumbore Precalculations; + + /*! Intersects K rays with M quads. 
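        In the motion-blur variants the vertices are fetched per ray time: gather(...) maps each
        lane's ray.time() to a time segment, loads the two bracketing vertex sets and lerps
        between them (see the time-aware QuadMi::gather above), before the regular
        Moeller-Trumbore quad test is applied.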
*/ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const QuadMi& quad) + { + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + Vec3vf v0,v1,v2,v3; quad.gather(valid_i,v0,v1,v2,v3,i,context->scene,ray.time()); + pre.intersectK(valid_i,ray,v0,v1,v2,v3,IntersectKEpilogM(ray,context,quad.geomID(),quad.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M quads. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const QuadMi& quad) + { + vbool valid0 = valid_i; + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + Vec3vf v0,v1,v2,v3; quad.gather(valid_i,v0,v1,v2,v3,i,context->scene,ray.time()); + if (pre.intersectK(valid0,ray,v0,v1,v2,v3,OccludedKEpilogM(valid0,ray,context,quad.geomID(),quad.primID(),i))) + break; + } + return !valid0; + } + + /*! Intersect a ray with M quads and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()[k]); + pre.intersect1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of the M quads. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()[k]); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + }; + + /*! Intersects M motion blur quads with 1 ray */ + template + struct QuadMiMBIntersector1Pluecker + { + typedef QuadMi Primitive; + typedef QuadMIntersector1Pluecker Precalculations; + + /*! Intersect a ray with the M quads and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()); + pre.intersect(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of M quads. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()); + return pre.occluded(ray,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& quad) + { + return PrimitivePointQuery1::pointQuery(query, context, quad); + } + }; + + /*! Intersects M motion blur quads with K rays. */ + template + struct QuadMiMBIntersectorKPluecker + { + typedef QuadMi Primitive; + typedef QuadMIntersectorKPluecker Precalculations; + + /*! Intersects K rays with M quads. 
*/ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const QuadMi& quad) + { + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + Vec3vf v0,v1,v2,v3; quad.gather(valid_i,v0,v1,v2,v3,i,context->scene,ray.time()); + pre.intersectK(valid_i,ray,v0,v1,v2,v3,IntersectKEpilogM(ray,context,quad.geomID(),quad.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M quads. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const QuadMi& quad) + { + vbool valid0 = valid_i; + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + Vec3vf v0,v1,v2,v3; quad.gather(valid_i,v0,v1,v2,v3,i,context->scene,ray.time()); + if (pre.intersectK(valid0,ray,v0,v1,v2,v3,OccludedKEpilogM(valid0,ray,context,quad.geomID(),quad.primID(),i))) + break; + } + return !valid0; + } + + /*! Intersect a ray with M quads and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()[k]); + pre.intersect1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of the M quads. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const QuadMi& quad) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2,v3; quad.gather(v0,v1,v2,v3,context->scene,ray.time()[k]); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,quad.geomID(),quad.primID()); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/quadv.h b/thirdparty/embree/kernels/geometry/quadv.h new file mode 100644 index 000000000000..0a1fe4d128a7 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/quadv.h @@ -0,0 +1,165 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + /* Stores the vertices of M quads in struct of array layout */ + template + struct QuadMv + { + public: + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* Returns maximum number of stored quads */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline QuadMv() {} + + /* Construction from vertices and IDs */ + __forceinline QuadMv(const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, const vuint& geomIDs, const vuint& primIDs) + : v0(v0), v1(v1), v2(v2), v3(v3), geomIDs(geomIDs), primIDs(primIDs) {} + + /* Returns a mask that tells which quads are valid */ + __forceinline vbool valid() const { return geomIDs != vuint(-1); } + + /* Returns true if the specified quad is valid */ + __forceinline bool valid(const size_t i) const { assert(i& geomID() { return geomIDs; } + __forceinline const vuint& geomID() const { return geomIDs; } + __forceinline unsigned int geomID(const size_t i) const { 
assert(i primID() { return primIDs; } + __forceinline const vuint primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(i lower = min(v0,v1,v2,v3); + Vec3vf upper = max(v0,v1,v2,v3); + vbool mask = valid(); + lower.x = select(mask,lower.x,vfloat(pos_inf)); + lower.y = select(mask,lower.y,vfloat(pos_inf)); + lower.z = select(mask,lower.z,vfloat(pos_inf)); + upper.x = select(mask,upper.x,vfloat(neg_inf)); + upper.y = select(mask,upper.y,vfloat(neg_inf)); + upper.z = select(mask,upper.z,vfloat(neg_inf)); + return BBox3fa(Vec3fa(reduce_min(lower.x),reduce_min(lower.y),reduce_min(lower.z)), + Vec3fa(reduce_max(upper.x),reduce_max(upper.y),reduce_max(upper.z))); + } + + /* Non temporal store */ + __forceinline static void store_nt(QuadMv* dst, const QuadMv& src) + { + vfloat::store_nt(&dst->v0.x,src.v0.x); + vfloat::store_nt(&dst->v0.y,src.v0.y); + vfloat::store_nt(&dst->v0.z,src.v0.z); + vfloat::store_nt(&dst->v1.x,src.v1.x); + vfloat::store_nt(&dst->v1.y,src.v1.y); + vfloat::store_nt(&dst->v1.z,src.v1.z); + vfloat::store_nt(&dst->v2.x,src.v2.x); + vfloat::store_nt(&dst->v2.y,src.v2.y); + vfloat::store_nt(&dst->v2.z,src.v2.z); + vfloat::store_nt(&dst->v3.x,src.v3.x); + vfloat::store_nt(&dst->v3.y,src.v3.y); + vfloat::store_nt(&dst->v3.z,src.v3.z); + vuint::store_nt(&dst->geomIDs,src.geomIDs); + vuint::store_nt(&dst->primIDs,src.primIDs); + } + + /* Fill quad from quad list */ + __forceinline void fill(const PrimRef* prims, size_t& begin, size_t end, Scene* scene) + { + vuint vgeomID = -1, vprimID = -1; + Vec3vf v0 = zero, v1 = zero, v2 = zero, v3 = zero; + + for (size_t i=0; iget(geomID); + const QuadMesh::Quad& quad = mesh->quad(primID); + const Vec3fa& p0 = mesh->vertex(quad.v[0]); + const Vec3fa& p1 = mesh->vertex(quad.v[1]); + const Vec3fa& p2 = mesh->vertex(quad.v[2]); + const Vec3fa& p3 = mesh->vertex(quad.v[3]); + vgeomID [i] = geomID; + vprimID [i] = primID; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + v3.x[i] = p3.x; v3.y[i] = p3.y; v3.z[i] = p3.z; + } + QuadMv::store_nt(this,QuadMv(v0,v1,v2,v3,vgeomID,vprimID)); + } + + /* Updates the primitive */ + __forceinline BBox3fa update(QuadMesh* mesh) + { + BBox3fa bounds = empty; + vuint vgeomID = -1, vprimID = -1; + Vec3vf v0 = zero, v1 = zero, v2 = zero; + + for (size_t i=0; iquad(primId); + const Vec3fa p0 = mesh->vertex(quad.v[0]); + const Vec3fa p1 = mesh->vertex(quad.v[1]); + const Vec3fa p2 = mesh->vertex(quad.v[2]); + const Vec3fa p3 = mesh->vertex(quad.v[3]); + bounds.extend(merge(BBox3fa(p0),BBox3fa(p1),BBox3fa(p2),BBox3fa(p3))); + vgeomID [i] = geomId; + vprimID [i] = primId; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + v3.x[i] = p3.x; v3.y[i] = p3.y; v3.z[i] = p3.z; + } + new (this) QuadMv(v0,v1,v2,v3,vgeomID,vprimID); + return bounds; + } + + public: + Vec3vf v0; // 1st vertex of the quads + Vec3vf v1; // 2nd vertex of the quads + Vec3vf v2; // 3rd vertex of the quads + Vec3vf v3; // 4rd vertex of the quads + private: + vuint geomIDs; // geometry ID + vuint primIDs; // primitive ID + }; + + template + typename QuadMv::Type QuadMv::type; + + typedef QuadMv<4> Quad4v; +} diff --git a/thirdparty/embree/kernels/geometry/quadv_intersector.h b/thirdparty/embree/kernels/geometry/quadv_intersector.h new file mode 100644 index 000000000000..30a24b291a7c --- /dev/null +++ 
b/thirdparty/embree/kernels/geometry/quadv_intersector.h @@ -0,0 +1,181 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "quadv.h" +#include "quad_intersector_moeller.h" +#include "quad_intersector_pluecker.h" + +namespace embree +{ + namespace isa + { + /*! Intersects M quads with 1 ray */ + template + struct QuadMvIntersector1Moeller + { + typedef QuadMv Primitive; + typedef QuadMIntersector1MoellerTrumbore Precalculations; + + /*! Intersect a ray with the M quads and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect(ray,context,quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of M quads. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.occluded(ray,context, quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& quad) + { + return PrimitivePointQuery1::pointQuery(query, context, quad); + } + }; + + /*! Intersects M triangles with K rays. */ + template + struct QuadMvIntersectorKMoeller + { + typedef QuadMv Primitive; + typedef QuadMIntersectorKMoellerTrumbore Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const QuadMv& quad) + { + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf p0 = broadcast>(quad.v0,i); + const Vec3vf p1 = broadcast>(quad.v1,i); + const Vec3vf p2 = broadcast>(quad.v2,i); + const Vec3vf p3 = broadcast>(quad.v3,i); + pre.intersectK(valid_i,ray,p0,p1,p2,p3,IntersectKEpilogM(ray,context,quad.geomID(),quad.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const QuadMv& quad) + { + vbool valid0 = valid_i; + + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf p0 = broadcast>(quad.v0,i); + const Vec3vf p1 = broadcast>(quad.v1,i); + const Vec3vf p2 = broadcast>(quad.v2,i); + const Vec3vf p3 = broadcast>(quad.v3,i); + if (pre.intersectK(valid0,ray,p0,p1,p2,p3,OccludedKEpilogM(valid0,ray,context,quad.geomID(),quad.primID(),i))) + break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const QuadMv& quad) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect1(ray,k,context,quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const QuadMv& quad) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.occluded1(ray,k,context,quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + }; + + /*! 
Intersects M quads with 1 ray */ + template + struct QuadMvIntersector1Pluecker + { + typedef QuadMv Primitive; + typedef QuadMIntersector1Pluecker Precalculations; + + /*! Intersect a ray with the M quads and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect(ray,context,quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of M quads. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& quad) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.occluded(ray,context, quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& quad) + { + return PrimitivePointQuery1::pointQuery(query, context, quad); + } + }; + + /*! Intersects M triangles with K rays. */ + template + struct QuadMvIntersectorKPluecker + { + typedef QuadMv Primitive; + typedef QuadMIntersectorKPluecker Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const QuadMv& quad) + { + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf p0 = broadcast>(quad.v0,i); + const Vec3vf p1 = broadcast>(quad.v1,i); + const Vec3vf p2 = broadcast>(quad.v2,i); + const Vec3vf p3 = broadcast>(quad.v3,i); + pre.intersectK(valid_i,ray,p0,p1,p2,p3,IntersectKEpilogM(ray,context,quad.geomID(),quad.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const QuadMv& quad) + { + vbool valid0 = valid_i; + + for (size_t i=0; i::max_size(); i++) + { + if (!quad.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf p0 = broadcast>(quad.v0,i); + const Vec3vf p1 = broadcast>(quad.v1,i); + const Vec3vf p2 = broadcast>(quad.v2,i); + const Vec3vf p3 = broadcast>(quad.v3,i); + if (pre.intersectK(valid0,ray,p0,p1,p2,p3,OccludedKEpilogM(valid0,ray,context,quad.geomID(),quad.primID(),i))) + break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const QuadMv& quad) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect1(ray,k,context,quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + + /*! Test if the ray is occluded by one of the M triangles. 
*/ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const QuadMv& quad) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.occluded1(ray,k,context,quad.v0,quad.v1,quad.v2,quad.v3,quad.geomID(),quad.primID()); + } + }; + } +} + diff --git a/thirdparty/embree/kernels/geometry/roundline_intersector.h b/thirdparty/embree/kernels/geometry/roundline_intersector.h new file mode 100644 index 000000000000..cdf68f486bb7 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/roundline_intersector.h @@ -0,0 +1,710 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "curve_intersector_precalculations.h" + + +/* + + This file implements the intersection of a ray with a round linear + curve segment. We define the geometry of such a round linear curve + segment from point p0 with radius r0 to point p1 with radius r1 + using the cone that touches spheres p0/r0 and p1/r1 tangentially + plus the sphere p1/r1. We denote the tangentially touching cone from + p0/r0 to p1/r1 with cone(p0,r0,p1,r1) and the cone plus the ending + sphere with cone_sphere(p0,r0,p1,r1). + + For multiple connected round linear curve segments this construction + yields a proper shape when viewed from the outside. Using the + following CSG we can also handle the interior in most common cases: + + round_linear_curve(pl,rl,p0,r0,p1,r1,pr,rr) = + cone_sphere(p0,r0,p1,r1) - cone(pl,rl,p0,r0) - cone(p1,r1,pr,rr) + + Thus by subtracting the neighboring cone geometries, we cut away + parts of the center cone_sphere surface which lie inside the + combined curve. This approach works as long as the geometry of the + current cone_sphere penetrates into direct neighbor segments only, + and not into segments further away. + + To construct a cone that touches two spheres at p0 and p1 with r0 + and r1, one has to increase the cone radius at r0 and r1 to obtain + larger radii w0 and w1, such that the infinite cone properly touches + the spheres. From the paper "Ray Tracing Generalized Tube + Primitives: Method and Applications" + (https://www.researchgate.net/publication/334378683_Ray_Tracing_Generalized_Tube_Primitives_Method_and_Applications) + one can derive the following equations for these increased + radii: + + sr = 1.0f / sqrt(1-sqr(dr)/sqr(p1-p0)) + w0 = sr*r0 + w1 = sr*r1 + + Further, we want the cone to start where it touches the sphere at p0 + and to end where it touches the sphere at p1. Therefore, we need to + construct clipping locations y0 and y1 for the start and end of the + cone. These start and end clipping locations of the cone can get + calculated as: + + Y0 = - r0 * (r1-r0) / length(p1-p0) + Y1 = length(p1-p0) - r1 * (r1-r0) / length(p1-p0) + + Here the cone starts a distance Y0 and ends a distance Y1 away from + point p0 along the cone center. The distance between Y0 and Y1 can get + calculated as: + + dY = length(p1-p0) - (r1-r0)^2 / length(p1-p0) + + In the code below, Y will always be scaled by length(p1-p0) to + obtain y and you will find the terms r0*(r1-r0) and + (p1-p0)^2-(r1-r0)^2.
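+ + As an illustrative sanity check (an assumed example, not part of the + original derivation), take p0 = (0,0,0), r0 = 1 and p1 = (4,0,0), r1 = 2, + so length(p1-p0) = 4 and dr = 1: + + sr = 1/sqrt(1 - 1/16) ~= 1.033, w0 ~= 1.033, w1 ~= 2.066 + Y0 = -1*1/4 = -0.25, Y1 = 4 - 2*1/4 = 3.5, dY = 4 - 1/4 = 3.75 = Y1-Y0 + + In the scaled form used below this becomes y0 = -r0*(r1-r0) = -1 and + y1 = (p1-p0)^2 - r1*(r1-r0) = 14, i.e. Y0 and Y1 multiplied by length(p1-p0).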
+ + */ + +namespace embree +{ + namespace isa + { + template + struct RoundLineIntersectorHitM + { + __forceinline RoundLineIntersectorHitM() {} + + __forceinline RoundLineIntersectorHitM(const vfloat& u, const vfloat& v, const vfloat& t, const Vec3vf& Ng) + : vu(u), vv(v), vt(t), vNg(Ng) {} + + __forceinline void finalize() {} + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + public: + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + namespace __roundline_internal + { + template + struct ConeGeometry + { + ConeGeometry (const Vec4vf& a, const Vec4vf& b) + : p0(a.xyz()), p1(b.xyz()), dP(p1-p0), dPdP(dot(dP,dP)), r0(a.w), sqr_r0(sqr(r0)), r1(b.w), dr(r1-r0), drdr(dr*dr), r0dr (r0*dr), g(dPdP - drdr) {} + + /* + + This function tests if a point is accepted by first cone + clipping plane. + + First, we need to project the point onto the line p0->p1: + + Y = (p-p0)*(p1-p0)/length(p1-p0) + + This value y is the distance to the projection point from + p0. The clip distances are calculated as: + + Y0 = - r0 * (r1-r0) / length(p1-p0) + Y1 = length(p1-p0) - r1 * (r1-r0) / length(p1-p0) + + Thus to test if the point p is accepted by the first + clipping plane we need to test Y > Y0 and to test if it + is accepted by the second clipping plane we need to test + Y < Y1. + + By multiplying the calculations with length(p1-p0) these + calculation can get simplied to: + + y = (p-p0)*(p1-p0) + y0 = - r0 * (r1-r0) + y1 = (p1-p0)^2 - r1 * (r1-r0) + + and the test y > y0 and y < y1. + + */ + + __forceinline vbool isClippedByPlane (const vbool& valid_i, const Vec3vf& p) const + { + const Vec3vf p0p = p - p0; + const vfloat y = dot(p0p,dP); + const vfloat cap0 = -r0dr; + const vbool inside_cone = y > cap0; + return valid_i & (p0.x != vfloat(inf)) & (p1.x != vfloat(inf)) & inside_cone; + } + + /* + + This function tests whether a point lies inside the capped cone + tangential to its ending spheres. + + Therefore one has to check if the point is inside the + region defined by the cone clipping planes, which is + performed similar as in the previous function. + + To perform the inside cone test we need to project the + point onto the line p0->p1: + + dP = p1-p0 + Y = (p-p0)*dP/length(dP) + + This value Y is the distance to the projection point from + p0. 
To obtain a parameter value u going from 0 to 1 along + the line p0->p1 we calculate: + + U = Y/length(dP) + + The radii to use at points p0 and p1 are: + + w0 = sr * r0 + w1 = sr * r1 + dw = w1-w0 + + Using these radii and u one can directly test if the point + lies inside the cone using the formula dP*dP < wy*wy with: + + wy = w0 + u*dw + py = p0 + u*dP - p + + By multiplying the calculations with length(p1-p0) and + inserting the definition of w can obtain simpler equations: + + y = (p-p0)*dP + ry = r0 + y/dP^2 * dr + wy = sr*ry + py = p0 + y/dP^2*dP - p + y0 = - r0 * dr + y1 = dP^2 - r1 * dr + + Thus for the in-cone test we get: + + py^2 < wy^2 + <=> py^2 < sr^2 * ry^2 + <=> py^2 * ( dP^2 - dr^2 ) < dP^2 * ry^2 + + This can further get simplified to: + + (p0-p)^2 * (dP^2 - dr^2) - y^2 < dP^2 * r0^2 + 2.0f*r0*dr*y; + + */ + + __forceinline vbool isInsideCappedCone (const vbool& valid_i, const Vec3vf& p) const + { + const Vec3vf p0p = p - p0; + const vfloat y = dot(p0p,dP); + const vfloat cap0 = -r0dr+vfloat(ulp); + const vfloat cap1 = -r1*dr + dPdP; + + vbool inside_cone = valid_i & (p0.x != vfloat(inf)) & (p1.x != vfloat(inf)); + inside_cone &= y > cap0; // start clipping plane + inside_cone &= y < cap1; // end clipping plane + inside_cone &= sqr(p0p)*g - sqr(y) < dPdP * sqr_r0 + 2.0f*r0dr*y; // in cone test + return inside_cone; + } + + protected: + Vec3vf p0; + Vec3vf p1; + Vec3vf dP; + vfloat dPdP; + vfloat r0; + vfloat sqr_r0; + vfloat r1; + vfloat dr; + vfloat drdr; + vfloat r0dr; + vfloat g; + }; + + template + struct ConeGeometryIntersector : public ConeGeometry + { + using ConeGeometry::p0; + using ConeGeometry::p1; + using ConeGeometry::dP; + using ConeGeometry::dPdP; + using ConeGeometry::r0; + using ConeGeometry::sqr_r0; + using ConeGeometry::r1; + using ConeGeometry::dr; + using ConeGeometry::r0dr; + using ConeGeometry::g; + + ConeGeometryIntersector (const Vec3vf& ray_org, const Vec3vf& ray_dir, const vfloat& dOdO, const vfloat& rcp_dOdO, const Vec4vf& a, const Vec4vf& b) + : ConeGeometry(a,b), org(ray_org), O(ray_org-p0), dO(ray_dir), dOdO(dOdO), rcp_dOdO(rcp_dOdO), OdP(dot(dP,O)), dOdP(dot(dP,dO)), yp(OdP + r0dr) {} + + /* + + This function intersects a ray with a cone that touches a + start sphere p0/r0 and end sphere p1/r1. + + To find this ray/cone intersections one could just + calculate radii w0 and w1 as described above and use a + standard ray/cone intersection routine with these + radii. However, it turns out that calculations can get + simplified when deriving a specialized ray/cone + intersection for this special case. We perform + calculations relative to the cone origin p0 and define: + + O = ray_org - p0 + dO = ray_dir + dP = p1-p0 + dr = r1-r0 + dw = w1-w0 + + For some t we can compute the potential hit point h = O + t*dO and + project it onto the cone vector dP to obtain u = (h*dP)/(dP*dP). In + case of an intersection, the squared distance from the hit point + projected onto the cone center line to the hit point should be equal + to the squared cone radius at u: + + (u*dP - h)^2 = (w0 + u*dw)^2 + + Inserting the definition of h, u, w0, and dw into this formula, then + factoring out all terms, and sorting by t^2, t^1, and t^0 terms + yields a quadratic equation to solve. 
+ + Inserting u: + ( (h*dP)*dP/dP^2 - h )^2 = ( w0 + (h*dP)*dw/dP^2 )^2 + + Multiplying by dP^4: + ( (h*dP)*dP - h*dP^2 )^2 = ( w0*dP^2 + (h*dP)*dw )^2 + + Inserting w0 and dw: + ( (h*dP)*dP - h*dP^2 )^2 = ( r0*dP^2 + (h*dP)*dr )^2 / (1-dr^2/dP^2) + ( (h*dP)*dP - h*dP^2 )^2 *(dP^2 - dr^2) = dP^2 * ( r0*dP^2 + (h*dP)*dr )^2 + + Now one can insert the definition of h, factor out, and presort by t: + ( ((O + t*dO)*dP)*dP - (O + t*dO)*dP^2 )^2 *(dP^2 - dr^2) = dP^2 * ( r0*dP^2 + ((O + t*dO)*dP)*dr )^2 + ( (O*dP)*dP-O*dP^2 + t*( (dO*dP)*dP - dO*dP^2 ) )^2 *(dP^2 - dr^2) = dP^2 * ( r0*dP^2 + (O*dP)*dr + t*(dO*dP)*dr )^2 + + Factoring out further and sorting by t^2, t^1 and t^0 yields: + + 0 = t^2 * [ ((dO*dP)*dP - dO-dP^2)^2 * (dP^2 - dr^2) - dP^2*(dO*dP)^2*dr^2 ] + + 2*t^1 * [ ((O*dP)*dP - O*dP^2) * ((dO*dP)*dP - dO*dP^2) * (dP^2 - dr^2) - dP^2*(r0*dP^2 + (O*dP)*dr)*(dO*dP)*dr ] + + t^0 * [ ( (O*dP)*dP - O*dP^2)^2 * (dP^2-dr^2) - dP^2*(r0*dP^2 + (O*dP)*dr)^2 ] + + This can be simplified to: + + 0 = t^2 * [ (dP^2 - dr^2)*dO^2 - (dO*dP)^2 ] + + 2*t^1 * [ (dP^2 - dr^2)*(O*dO) - (dO*dP)*(O*dP + r0*dr) ] + + t^0 * [ (dP^2 - dr^2)*O^2 - (O*dP)^2 - r0^2*dP^2 - 2.0f*r0*dr*(O*dP) ] + + Solving this quadratic equation yields the values for t at which the + ray intersects the cone. + + */ + + __forceinline bool intersectCone(vbool& valid, vfloat& lower, vfloat& upper) + { + /* return no hit by default */ + lower = pos_inf; + upper = neg_inf; + + /* compute quadratic equation A*t^2 + B*t + C = 0 */ + const vfloat OO = dot(O,O); + const vfloat OdO = dot(dO,O); + const vfloat A = g * dOdO - sqr(dOdP); + const vfloat B = 2.0f * (g*OdO - dOdP*yp); + const vfloat C = g*OO - sqr(OdP) - sqr_r0*dPdP - 2.0f*r0dr*OdP; + + /* we miss the cone if determinant is smaller than zero */ + const vfloat D = B*B - 4.0f*A*C; + valid &= (D >= 0.0f & g > 0.0f); // if g <= 0 then the cone is inside a sphere end + + /* When rays are parallel to the cone surface, then the + * ray may be inside or outside the cone. We just assume a + * miss in that case, which is fine as rays inside the + * cone would anyway hit the ending spheres in that + * case. */ + valid &= abs(A) > min_rcp_input; + if (unlikely(none(valid))) { + return false; + } + + /* compute distance to front and back hit */ + const vfloat Q = sqrt(D); + const vfloat rcp_2A = rcp(2.0f*A); + t_cone_front = (-B-Q)*rcp_2A; + y_cone_front = yp + t_cone_front*dOdP; + lower = select( (y_cone_front > -(float)ulp) & (y_cone_front <= g) & (g > 0.0f), t_cone_front, vfloat(pos_inf)); +#if !defined (EMBREE_BACKFACE_CULLING_CURVES) + t_cone_back = (-B+Q)*rcp_2A; + y_cone_back = yp + t_cone_back *dOdP; + upper = select( (y_cone_back > -(float)ulp) & (y_cone_back <= g) & (g > 0.0f), t_cone_back , vfloat(neg_inf)); +#endif + return true; + } + + /* + This function intersects the ray with the end sphere at + p1. We already clip away hits that are inside the + neighboring cone segment. 
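+ + For reference (a standard ray/sphere derivation, not spelled out in the + original comment): with O1 = org - p1 the hits with the end sphere solve + |O1 + t*dO|^2 = r1^2, i.e. + + dO^2*t^2 + 2*(O1*dO)*t + (O1^2 - r1^2) = 0 + + whose roots t = ( -(O1*dO) -/+ sqrt( (O1*dO)^2 - dO^2*(O1^2 - r1^2) ) ) / dO^2 + are the front and back hits computed below; a negative value under the + square root means the ray misses the sphere.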
+ + */ + + __forceinline void intersectEndSphere(vbool& valid, + const ConeGeometry& coneR, + vfloat& lower, vfloat& upper) + { + /* calculate front and back hit with end sphere */ + const Vec3vf O1 = org - p1; + const vfloat O1dO = dot(O1,dO); + const vfloat h2 = sqr(O1dO) - dOdO*(sqr(O1) - sqr(r1)); + const vfloat rhs1 = select( h2 >= 0.0f, sqrt(h2), vfloat(neg_inf) ); + + /* clip away front hit if it is inside next cone segment */ + t_sph1_front = (-O1dO - rhs1)*rcp_dOdO; + const Vec3vf hit_front = org + t_sph1_front*dO; + vbool valid_sph1_front = h2 >= 0.0f & yp + t_sph1_front*dOdP > g & !coneR.isClippedByPlane (valid, hit_front); + lower = select(valid_sph1_front, t_sph1_front, vfloat(pos_inf)); + +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + /* clip away back hit if it is inside next cone segment */ + t_sph1_back = (-O1dO + rhs1)*rcp_dOdO; + const Vec3vf hit_back = org + t_sph1_back*dO; + vbool valid_sph1_back = h2 >= 0.0f & yp + t_sph1_back*dOdP > g & !coneR.isClippedByPlane (valid, hit_back); + upper = select(valid_sph1_back, t_sph1_back, vfloat(neg_inf)); +#else + upper = vfloat(neg_inf); +#endif + } + + __forceinline void intersectBeginSphere(const vbool& valid, + vfloat& lower, vfloat& upper) + { + /* calculate front and back hit with end sphere */ + const Vec3vf O1 = org - p0; + const vfloat O1dO = dot(O1,dO); + const vfloat h2 = sqr(O1dO) - dOdO*(sqr(O1) - sqr(r0)); + const vfloat rhs1 = select( h2 >= 0.0f, sqrt(h2), vfloat(neg_inf) ); + + /* clip away front hit if it is inside next cone segment */ + t_sph0_front = (-O1dO - rhs1)*rcp_dOdO; + vbool valid_sph1_front = valid & h2 >= 0.0f & yp + t_sph0_front*dOdP < 0; + lower = select(valid_sph1_front, t_sph0_front, vfloat(pos_inf)); + +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + /* clip away back hit if it is inside next cone segment */ + t_sph0_back = (-O1dO + rhs1)*rcp_dOdO; + vbool valid_sph1_back = valid & h2 >= 0.0f & yp + t_sph0_back*dOdP < 0; + upper = select(valid_sph1_back, t_sph0_back, vfloat(neg_inf)); +#else + upper = vfloat(neg_inf); +#endif + } + + /* + + This function calculates the geometry normal of some cone hit. 
+ + For a given hit point h (relative to p0) with a cone + starting at p0 with radius w0 and ending at p1 with + radius w1 one normally calculates the geometry normal by + first calculating the parmetric u hit location along the + cone: + + u = dot(h,dP)/dP^2 + + Using this value one can now directly calculate the + geometry normal by bending the connection vector (h-u*dP) + from hit to projected hit with some cone dependent value + dw/sqrt(dP^2) * normalize(dP): + + Ng = normalize(h-u*dP) - dw/length(dP) * normalize(dP) + + The length of the vector (h-u*dP) can also get calculated + by interpolating the radii as w0+u*dw which yields: + + Ng = (h-u*dP)/(w0+u*dw) - dw/dP^2 * dP + + Multiplying with (w0+u*dw) yield a scaled Ng': + + Ng' = (h-u*dP) - (w0+u*dw)*dw/dP^2*dP + + Inserting the definition of w0 and dw and refactoring + yield a furhter scaled Ng'': + + Ng'' = (dP^2 - dr^2) (h-q) - (r0+u*dr)*dr*dP + + Now inserting the definition of u gives and multiplying + with the denominator yields: + + Ng''' = (dP^2-dr^2)*(dP^2*h-dot(h,dP)*dP) - (dP^2*r0+dot(h,dP)*dr)*dr*dP + + Factoring out, cancelling terms, dividing by dP^2, and + factoring again yields finally: + + Ng'''' = (dP^2-dr^2)*h - dP*(dot(h,dP) + r0*dr) + + */ + + __forceinline Vec3vf Ng_cone(const vbool& front_hit) const + { +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + const vfloat y = select(front_hit, y_cone_front, y_cone_back); + const vfloat t = select(front_hit, t_cone_front, t_cone_back); + const Vec3vf h = O + t*dO; + return g*h-dP*y; +#else + const Vec3vf h = O + t_cone_front*dO; + return g*h-dP*y_cone_front; +#endif + } + + /* compute geometry normal of sphere hit as the difference + * vector from hit point to sphere center */ + + __forceinline Vec3vf Ng_sphere1(const vbool& front_hit) const + { +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + const vfloat t_sph1 = select(front_hit, t_sph1_front, t_sph1_back); + return org+t_sph1*dO-p1; +#else + return org+t_sph1_front*dO-p1; +#endif + } + + __forceinline Vec3vf Ng_sphere0(const vbool& front_hit) const + { +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + const vfloat t_sph0 = select(front_hit, t_sph0_front, t_sph0_back); + return org+t_sph0*dO-p0; +#else + return org+t_sph0_front*dO-p0; +#endif + } + + /* + This function calculates the u coordinate of a + hit. Therefore we use the hit distance y (which is zero + at the first cone clipping plane) and divide by distance + g between the clipping planes. 
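+ + Continuing the illustrative example from the top of this file (the assumed + segment with length(p1-p0) = 4, r0 = 1, r1 = 2): g = dP^2 - dr^2 = 15, so a + hit on the start clipping plane has y = 0 and u = 0, a hit on the end + clipping plane has y = g and u = 1, and hits in between interpolate + linearly.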
+ + */ + + __forceinline vfloat u_cone(const vbool& front_hit) const + { +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + const vfloat y = select(front_hit, y_cone_front, y_cone_back); + return clamp(y*rcp(g)); +#else + return clamp(y_cone_front*rcp(g)); +#endif + } + + private: + Vec3vf org; + Vec3vf O; + Vec3vf dO; + vfloat dOdO; + vfloat rcp_dOdO; + vfloat OdP; + vfloat dOdP; + + /* for ray/cone intersection */ + private: + vfloat yp; + vfloat y_cone_front; + vfloat t_cone_front; +#if !defined (EMBREE_BACKFACE_CULLING_CURVES) + vfloat y_cone_back; + vfloat t_cone_back; +#endif + + /* for ray/sphere intersection */ + private: + vfloat t_sph1_front; + vfloat t_sph0_front; +#if !defined (EMBREE_BACKFACE_CULLING_CURVES) + vfloat t_sph1_back; + vfloat t_sph0_back; +#endif + }; + + + template + static __forceinline bool intersectConeSphere(const vbool& valid_i, + const Vec3vf& ray_org_in, const Vec3vf& ray_dir, + const vfloat& ray_tnear, const ray_tfar_func& ray_tfar, + const Vec4vf& v0, const Vec4vf& v1, + const Vec4vf& vL, const Vec4vf& vR, + const Epilog& epilog) + { + vbool valid = valid_i; + + /* move ray origin closer to make calculations numerically stable */ + const vfloat dOdO = sqr(ray_dir); + const vfloat rcp_dOdO = rcp(dOdO); + const Vec3vf center = vfloat(0.5f)*(v0.xyz()+v1.xyz()); + const vfloat dt = dot(center-ray_org_in,ray_dir)*rcp_dOdO; + const Vec3vf ray_org = ray_org_in + dt*ray_dir; + + /* intersect with cone from v0 to v1 */ + vfloat t_cone_lower, t_cone_upper; + ConeGeometryIntersector cone (ray_org, ray_dir, dOdO, rcp_dOdO, v0, v1); + vbool validCone = valid; + cone.intersectCone(validCone, t_cone_lower, t_cone_upper); + + valid &= (validCone | (cone.g <= 0.0f)); // if cone is entirely in sphere end - check sphere + if (unlikely(none(valid))) + return false; + + /* cone hits inside the neighboring capped cones are inside the geometry and thus ignored */ + const ConeGeometry coneL (v0, vL); + const ConeGeometry coneR (v1, vR); +#if !defined(EMBREE_BACKFACE_CULLING_CURVES) + const Vec3vf hit_lower = ray_org + t_cone_lower*ray_dir; + const Vec3vf hit_upper = ray_org + t_cone_upper*ray_dir; + t_cone_lower = select (!coneL.isInsideCappedCone (validCone, hit_lower) & !coneR.isInsideCappedCone (validCone, hit_lower), t_cone_lower, vfloat(pos_inf)); + t_cone_upper = select (!coneL.isInsideCappedCone (validCone, hit_upper) & !coneR.isInsideCappedCone (validCone, hit_upper), t_cone_upper, vfloat(neg_inf)); +#endif + + /* intersect ending sphere */ + vfloat t_sph1_lower, t_sph1_upper; + vfloat t_sph0_lower = vfloat(pos_inf); + vfloat t_sph0_upper = vfloat(neg_inf); + cone.intersectEndSphere(valid, coneR, t_sph1_lower, t_sph1_upper); + + const vbool isBeginPoint = valid & (vL[0] == vfloat(pos_inf)); + if (unlikely(any(isBeginPoint))) { + cone.intersectBeginSphere (isBeginPoint, t_sph0_lower, t_sph0_upper); + } + + /* CSG union of cone and end sphere */ + vfloat t_sph_lower = min(t_sph0_lower, t_sph1_lower); + vfloat t_cone_sphere_lower = min(t_cone_lower, t_sph_lower); +#if !defined (EMBREE_BACKFACE_CULLING_CURVES) + vfloat t_sph_upper = max(t_sph0_upper, t_sph1_upper); + vfloat t_cone_sphere_upper = max(t_cone_upper, t_sph_upper); + + /* filter out hits that are not in tnear/tfar range */ + const vbool valid_lower = valid & ray_tnear <= dt+t_cone_sphere_lower & dt+t_cone_sphere_lower <= ray_tfar() & t_cone_sphere_lower != vfloat(pos_inf); + const vbool valid_upper = valid & ray_tnear <= dt+t_cone_sphere_upper & dt+t_cone_sphere_upper <= ray_tfar() & t_cone_sphere_upper != 
vfloat(neg_inf); + + /* check if there is a first hit */ + const vbool valid_first = valid_lower | valid_upper; + if (unlikely(none(valid_first))) + return false; + + /* construct first hit */ + const vfloat t_first = select(valid_lower, t_cone_sphere_lower, t_cone_sphere_upper); + const vbool cone_hit_first = t_first == t_cone_lower | t_first == t_cone_upper; + const vbool sph0_hit_first = t_first == t_sph0_lower | t_first == t_sph0_upper; + const Vec3vf Ng_first = select(cone_hit_first, cone.Ng_cone(valid_lower), select (sph0_hit_first, cone.Ng_sphere0(valid_lower), cone.Ng_sphere1(valid_lower))); + const vfloat u_first = select(cone_hit_first, cone.u_cone(valid_lower), select (sph0_hit_first, vfloat(zero), vfloat(one))); + + /* invoke intersection filter for first hit */ + RoundLineIntersectorHitM hit(u_first,zero,dt+t_first,Ng_first); + const bool is_hit_first = epilog(valid_first, hit); + + /* check for possible second hits before potentially accepted hit */ + const vfloat t_second = t_cone_sphere_upper; + const vbool valid_second = valid_lower & valid_upper & (dt+t_cone_sphere_upper <= ray_tfar()); + if (unlikely(none(valid_second))) + return is_hit_first; + + /* invoke intersection filter for second hit */ + const vbool cone_hit_second = t_second == t_cone_lower | t_second == t_cone_upper; + const vbool sph0_hit_second = t_second == t_sph0_lower | t_second == t_sph0_upper; + const Vec3vf Ng_second = select(cone_hit_second, cone.Ng_cone(false), select (sph0_hit_second, cone.Ng_sphere0(false), cone.Ng_sphere1(false))); + const vfloat u_second = select(cone_hit_second, cone.u_cone(false), select (sph0_hit_second, vfloat(zero), vfloat(one))); + + hit = RoundLineIntersectorHitM(u_second,zero,dt+t_second,Ng_second); + const bool is_hit_second = epilog(valid_second, hit); + + return is_hit_first | is_hit_second; +#else + /* filter out hits that are not in tnear/tfar range */ + const vbool valid_lower = valid & ray_tnear <= dt+t_cone_sphere_lower & dt+t_cone_sphere_lower <= ray_tfar() & t_cone_sphere_lower != vfloat(pos_inf); + + /* check if there is a valid hit */ + if (unlikely(none(valid_lower))) + return false; + + /* construct first hit */ + const vbool cone_hit_first = t_cone_sphere_lower == t_cone_lower | t_cone_sphere_lower == t_cone_upper; + const vbool sph0_hit_first = t_cone_sphere_lower == t_sph0_lower | t_cone_sphere_lower == t_sph0_upper; + const Vec3vf Ng_first = select(cone_hit_first, cone.Ng_cone(valid_lower), select (sph0_hit_first, cone.Ng_sphere0(valid_lower), cone.Ng_sphere1(valid_lower))); + const vfloat u_first = select(cone_hit_first, cone.u_cone(valid_lower), select (sph0_hit_first, vfloat(zero), vfloat(one))); + + /* invoke intersection filter for first hit */ + RoundLineIntersectorHitM hit(u_first,zero,dt+t_cone_sphere_lower,Ng_first); + const bool is_hit_first = epilog(valid_lower, hit); + + return is_hit_first; +#endif + } + + } // end namespace __roundline_internal + + template + struct RoundLinearCurveIntersector1 + { + typedef CurvePrecalculations1 Precalculations; + + struct ray_tfar { + Ray& ray; + __forceinline ray_tfar(Ray& ray) : ray(ray) {} + __forceinline vfloat operator() () const { return ray.tfar; }; + }; + + template + static __forceinline bool intersect(const vbool& valid_i, + Ray& ray, + IntersectContext* context, + const LineSegments* geom, + const Precalculations& pre, + const Vec4vf& v0i, const Vec4vf& v1i, + const Vec4vf& vLi, const Vec4vf& vRi, + const Epilog& epilog) + { + const Vec3vf ray_org(ray.org.x, ray.org.y, ray.org.z); + const 
Vec3vf ray_dir(ray.dir.x, ray.dir.y, ray.dir.z); + const vfloat ray_tnear(ray.tnear()); + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec4vf v1 = enlargeRadiusToMinWidth(context,geom,ray_org,v1i); + const Vec4vf vL = enlargeRadiusToMinWidth(context,geom,ray_org,vLi); + const Vec4vf vR = enlargeRadiusToMinWidth(context,geom,ray_org,vRi); + return __roundline_internal::intersectConeSphere(valid_i,ray_org,ray_dir,ray_tnear,ray_tfar(ray),v0,v1,vL,vR,epilog); + } + }; + + template + struct RoundLinearCurveIntersectorK + { + typedef CurvePrecalculationsK Precalculations; + + struct ray_tfar { + RayK& ray; + size_t k; + __forceinline ray_tfar(RayK& ray, size_t k) : ray(ray), k(k) {} + __forceinline vfloat operator() () const { return ray.tfar[k]; }; + }; + + template + static __forceinline bool intersect(const vbool& valid_i, + RayK& ray, size_t k, + IntersectContext* context, + const LineSegments* geom, + const Precalculations& pre, + const Vec4vf& v0i, const Vec4vf& v1i, + const Vec4vf& vLi, const Vec4vf& vRi, + const Epilog& epilog) + { + const Vec3vf ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const Vec3vf ray_dir(ray.dir.x[k], ray.dir.y[k], ray.dir.z[k]); + const vfloat ray_tnear = ray.tnear()[k]; + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec4vf v1 = enlargeRadiusToMinWidth(context,geom,ray_org,v1i); + const Vec4vf vL = enlargeRadiusToMinWidth(context,geom,ray_org,vLi); + const Vec4vf vR = enlargeRadiusToMinWidth(context,geom,ray_org,vRi); + return __roundline_internal::intersectConeSphere(valid_i,ray_org,ray_dir,ray_tnear,ray_tfar(ray,k),v0,v1,vL,vR,epilog); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/roundlinei_intersector.h b/thirdparty/embree/kernels/geometry/roundlinei_intersector.h new file mode 100644 index 000000000000..079817335ebf --- /dev/null +++ b/thirdparty/embree/kernels/geometry/roundlinei_intersector.h @@ -0,0 +1,136 @@ +// ======================================================================== // +// Copyright 2009-2020 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
// +// ======================================================================== // + +#pragma once + +#include "roundline_intersector.h" +#include "intersector_epilog.h" + +namespace embree +{ + namespace isa + { + template + struct RoundLinearCurveMiIntersector1 + { + typedef LineMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom); + const vbool valid = line.template valid(); + RoundLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,vL,vR,Intersect1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom); + const vbool valid = line.template valid(); + return RoundLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,vL,vR,Occluded1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& line) + { + return PrimitivePointQuery1::pointQuery(query, context, line); + } + }; + + template + struct RoundLinearCurveMiMBIntersector1 + { + typedef LineMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom,ray.time()); + const vbool valid = line.template valid(); + RoundLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,vL,vR,Intersect1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom,ray.time()); + const vbool valid = line.template valid(); + return RoundLinearCurveIntersector1::intersect(valid,ray,context,geom,pre,v0,v1,vL,vR,Occluded1EpilogM(ray,context,line.geomID(),line.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& line) + { + return PrimitivePointQuery1::pointQuery(query, context, line); + } + }; + + template + struct RoundLinearCurveMiIntersectorK + { + typedef LineMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom); + const vbool valid = line.template valid(); + RoundLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,vL,vR,Intersect1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const 
Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom); + const vbool valid = line.template valid(); + return RoundLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,vL,vR,Occluded1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + }; + + template + struct RoundLinearCurveMiMBIntersectorK + { + typedef LineMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(normal.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom,ray.time()[k]); + const vbool valid = line.template valid(); + RoundLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,vL,vR,Intersect1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& line) + { + STAT3(shadow.trav_prims,1,1,1); + const LineSegments* geom = context->scene->get(line.geomID()); + Vec4vf v0,v1,vL,vR; line.gather(v0,v1,vL,vR,geom,ray.time()[k]); + const vbool valid = line.template valid(); + return RoundLinearCurveIntersectorK::intersect(valid,ray,k,context,geom,pre,v0,v1,vL,vR,Occluded1KEpilogM(ray,k,context,line.geomID(),line.primID())); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/sphere_intersector.h b/thirdparty/embree/kernels/geometry/sphere_intersector.h new file mode 100644 index 000000000000..3ab90c29ef71 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/sphere_intersector.h @@ -0,0 +1,183 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "../common/scene_points.h" +#include "curve_intersector_precalculations.h" + +namespace embree +{ + namespace isa + { + template + struct SphereIntersectorHitM + { + __forceinline SphereIntersectorHitM() {} + + __forceinline SphereIntersectorHitM(const vfloat& t, const Vec3vf& Ng) + : vt(t), vNg(Ng) {} + + __forceinline void finalize() {} + + __forceinline Vec2f uv(const size_t i) const { + return Vec2f(0.0f, 0.0f); + } + __forceinline float t(const size_t i) const { + return vt[i]; + } + __forceinline Vec3fa Ng(const size_t i) const { + return Vec3fa(vNg.x[i], vNg.y[i], vNg.z[i]); + } + + public: + vfloat vt; + Vec3vf vNg; + }; + + template + struct SphereIntersector1 + { + typedef CurvePrecalculations1 Precalculations; + + template + static __forceinline bool intersect( + const vbool& valid_i, Ray& ray, + const Precalculations& pre, const Vec4vf& v0, const Epilog& epilog) + { + vbool valid = valid_i; + + const vfloat rd2 = rcp(dot(ray.dir, ray.dir)); + const Vec3vf ray_org(ray.org.x, ray.org.y, ray.org.z); + const Vec3vf ray_dir(ray.dir.x, ray.dir.y, ray.dir.z); + const Vec3vf center = v0.xyz(); + const vfloat radius = v0.w; + + const Vec3vf c0 = center - ray_org; + const vfloat projC0 = dot(c0, ray_dir) * rd2; + const Vec3vf perp = c0 - projC0 * ray_dir; + const vfloat l2 = dot(perp, perp); + const vfloat r2 = radius * radius; + valid &= (l2 <= r2); + if (unlikely(none(valid))) + return false; + + const vfloat td = sqrt((r2 - l2) * rd2); + const vfloat t_front = projC0 - td; + const vfloat t_back = projC0 + td; + + const vbool valid_front = valid & (ray.tnear() <= 
t_front) & (t_front <= ray.tfar); + const vbool valid_back = valid & (ray.tnear() <= t_back ) & (t_back <= ray.tfar); + + /* check if there is a first hit */ + const vbool valid_first = valid_front | valid_back; + if (unlikely(none(valid_first))) + return false; + + /* construct first hit */ + const vfloat td_front = -td; + const vfloat td_back = +td; + const vfloat t_first = select(valid_front, t_front, t_back); + const Vec3vf Ng_first = select(valid_front, td_front, td_back) * ray_dir - perp; + SphereIntersectorHitM hit(t_first, Ng_first); + + /* invoke intersection filter for first hit */ + const bool is_hit_first = epilog(valid_first, hit); + + /* check for possible second hits before potentially accepted hit */ + const vfloat t_second = t_back; + const vbool valid_second = valid_front & valid_back & (t_second <= ray.tfar); + if (unlikely(none(valid_second))) + return is_hit_first; + + /* invoke intersection filter for second hit */ + const Vec3vf Ng_second = td_back * ray_dir - perp; + hit = SphereIntersectorHitM (t_second, Ng_second); + const bool is_hit_second = epilog(valid_second, hit); + + return is_hit_first | is_hit_second; + } + + template + static __forceinline bool intersect( + const vbool& valid_i, Ray& ray, IntersectContext* context, const Points* geom, + const Precalculations& pre, const Vec4vf& v0i, const Epilog& epilog) + { + const Vec3vf ray_org(ray.org.x, ray.org.y, ray.org.z); + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + return intersect(valid_i,ray,pre,v0,epilog); + } + }; + + template + struct SphereIntersectorK + { + typedef CurvePrecalculationsK Precalculations; + + template + static __forceinline bool intersect(const vbool& valid_i, + RayK& ray, size_t k, + IntersectContext* context, + const Points* geom, + const Precalculations& pre, + const Vec4vf& v0i, + const Epilog& epilog) + { + vbool valid = valid_i; + + const Vec3vf ray_org(ray.org.x[k], ray.org.y[k], ray.org.z[k]); + const Vec3vf ray_dir(ray.dir.x[k], ray.dir.y[k], ray.dir.z[k]); + const vfloat rd2 = rcp(dot(ray_dir, ray_dir)); + + const Vec4vf v0 = enlargeRadiusToMinWidth(context,geom,ray_org,v0i); + const Vec3vf center = v0.xyz(); + const vfloat radius = v0.w; + + const Vec3vf c0 = center - ray_org; + const vfloat projC0 = dot(c0, ray_dir) * rd2; + const Vec3vf perp = c0 - projC0 * ray_dir; + const vfloat l2 = dot(perp, perp); + const vfloat r2 = radius * radius; + valid &= (l2 <= r2); + if (unlikely(none(valid))) + return false; + + const vfloat td = sqrt((r2 - l2) * rd2); + const vfloat t_front = projC0 - td; + const vfloat t_back = projC0 + td; + + const vbool valid_front = valid & (ray.tnear()[k] <= t_front) & (t_front <= ray.tfar[k]); + const vbool valid_back = valid & (ray.tnear()[k] <= t_back ) & (t_back <= ray.tfar[k]); + + /* check if there is a first hit */ + const vbool valid_first = valid_front | valid_back; + if (unlikely(none(valid_first))) + return false; + + /* construct first hit */ + const vfloat td_front = -td; + const vfloat td_back = +td; + const vfloat t_first = select(valid_front, t_front, t_back); + const Vec3vf Ng_first = select(valid_front, td_front, td_back) * ray_dir - perp; + SphereIntersectorHitM hit(t_first, Ng_first); + + /* invoke intersection filter for first hit */ + const bool is_hit_first = epilog(valid_first, hit); + + /* check for possible second hits before potentially accepted hit */ + const vfloat t_second = t_back; + const vbool valid_second = valid_front & valid_back & (t_second <= ray.tfar[k]); + if (unlikely(none(valid_second))) 
+ return is_hit_first; + + /* invoke intersection filter for second hit */ + const Vec3vf Ng_second = td_back * ray_dir - perp; + hit = SphereIntersectorHitM (t_second, Ng_second); + const bool is_hit_second = epilog(valid_second, hit); + + return is_hit_first | is_hit_second; + } + }; + } // namespace isa +} // namespace embree diff --git a/thirdparty/embree/kernels/geometry/spherei_intersector.h b/thirdparty/embree/kernels/geometry/spherei_intersector.h new file mode 100644 index 000000000000..11468476022d --- /dev/null +++ b/thirdparty/embree/kernels/geometry/spherei_intersector.h @@ -0,0 +1,156 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "intersector_epilog.h" +#include "pointi.h" +#include "sphere_intersector.h" + +namespace embree +{ + namespace isa + { + template + struct SphereMiIntersector1 + { + typedef PointMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, + RayHit& ray, + IntersectContext* context, + const Primitive& sphere) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom); + const vbool valid = sphere.template valid(); + SphereIntersector1::intersect( + valid, ray, context, geom, pre, v0, Intersect1EpilogM(ray, context, sphere.geomID(), sphere.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, + Ray& ray, + IntersectContext* context, + const Primitive& sphere) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom); + const vbool valid = sphere.template valid(); + return SphereIntersector1::intersect( + valid, ray, context, geom, pre, v0, Occluded1EpilogM(ray, context, sphere.geomID(), sphere.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, + PointQueryContext* context, + const Primitive& sphere) + { + return PrimitivePointQuery1::pointQuery(query, context, sphere); + } + }; + + template + struct SphereMiMBIntersector1 + { + typedef PointMi Primitive; + typedef CurvePrecalculations1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, + RayHit& ray, + IntersectContext* context, + const Primitive& sphere) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom, ray.time()); + const vbool valid = sphere.template valid(); + SphereIntersector1::intersect( + valid, ray, context, geom, pre, v0, Intersect1EpilogM(ray, context, sphere.geomID(), sphere.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, + Ray& ray, + IntersectContext* context, + const Primitive& sphere) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom, ray.time()); + const vbool valid = sphere.template valid(); + return SphereIntersector1::intersect( + valid, ray, context, geom, pre, v0, Occluded1EpilogM(ray, context, sphere.geomID(), sphere.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, + PointQueryContext* context, + const Primitive& sphere) + { + return PrimitivePointQuery1::pointQuery(query, context, sphere); + } + }; + + template + struct SphereMiIntersectorK + { + typedef PointMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect( + 
const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& sphere) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom); + const vbool valid = sphere.template valid(); + SphereIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Intersect1KEpilogM(ray, k, context, sphere.geomID(), sphere.primID())); + } + + static __forceinline bool occluded( + const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& sphere) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom); + const vbool valid = sphere.template valid(); + return SphereIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Occluded1KEpilogM(ray, k, context, sphere.geomID(), sphere.primID())); + } + }; + + template + struct SphereMiMBIntersectorK + { + typedef PointMi Primitive; + typedef CurvePrecalculationsK Precalculations; + + static __forceinline void intersect( + const Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& sphere) + { + STAT3(normal.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom, ray.time()[k]); + const vbool valid = sphere.template valid(); + SphereIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Intersect1KEpilogM(ray, k, context, sphere.geomID(), sphere.primID())); + } + + static __forceinline bool occluded( + const Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& sphere) + { + STAT3(shadow.trav_prims, 1, 1, 1); + const Points* geom = context->scene->get(sphere.geomID()); + Vec4vf v0; sphere.gather(v0, geom, ray.time()[k]); + const vbool valid = sphere.template valid(); + return SphereIntersectorK::intersect( + valid, ray, k, context, geom, pre, v0, + Occluded1KEpilogM(ray, k, context, sphere.geomID(), sphere.primID())); + } + }; + } // namespace isa +} // namespace embree diff --git a/thirdparty/embree/kernels/geometry/subdivpatch1.h b/thirdparty/embree/kernels/geometry/subdivpatch1.h new file mode 100644 index 000000000000..94ad46ad87ad --- /dev/null +++ b/thirdparty/embree/kernels/geometry/subdivpatch1.h @@ -0,0 +1,38 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../geometry/primitive.h" +#include "../subdiv/subdivpatch1base.h" + +namespace embree +{ + + struct __aligned(64) SubdivPatch1 : public SubdivPatch1Base + { + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + + static Type type; + + public: + + /*! 
constructor for cached subdiv patch */ + SubdivPatch1 (const unsigned int gID, + const unsigned int pID, + const unsigned int subPatch, + const SubdivMesh *const mesh, + const size_t time, + const Vec2f uv[4], + const float edge_level[4], + const int subdiv[4], + const int simd_width) + : SubdivPatch1Base(gID,pID,subPatch,mesh,time,uv,edge_level,subdiv,simd_width) {} + }; +} diff --git a/thirdparty/embree/kernels/geometry/subdivpatch1_intersector.h b/thirdparty/embree/kernels/geometry/subdivpatch1_intersector.h new file mode 100644 index 000000000000..74ec1de25868 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/subdivpatch1_intersector.h @@ -0,0 +1,237 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "subdivpatch1.h" +#include "grid_soa.h" +#include "grid_soa_intersector1.h" +#include "grid_soa_intersector_packet.h" +#include "../common/ray.h" + +namespace embree +{ + namespace isa + { + template + class SubdivPatch1Precalculations : public T + { + public: + __forceinline SubdivPatch1Precalculations (const Ray& ray, const void* ptr) + : T(ray,ptr) {} + }; + + template + class SubdivPatch1PrecalculationsK : public T + { + public: + __forceinline SubdivPatch1PrecalculationsK (const vbool& valid, RayK& ray) + : T(valid,ray) {} + }; + + class SubdivPatch1Intersector1 + { + public: + typedef GridSOA Primitive; + typedef SubdivPatch1Precalculations Precalculations; + + static __forceinline bool processLazyNode(Precalculations& pre, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + lazy_node = prim->root(0); + pre.grid = (Primitive*)prim; + return false; + } + + /*! Intersect a ray with the primitive. */ + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) GridSOAIntersector1::intersect(pre,ray,context,prim,lazy_node); + else processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, size_t ty0, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) { + intersect(This,pre,ray,context,prim,ty,tray,lazy_node); + } + + /*! 
Test if the ray is occluded by the primitive */ + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) return GridSOAIntersector1::occluded(pre,ray,context,prim,lazy_node); + else return processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, size_t ty0, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) { + return occluded(This,pre,ray,context,prim,ty,tray,lazy_node); + } + + template + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t ty, const TravPointQuery &tquery, size_t& lazy_node) + { + // TODO: PointQuery implement + assert(false && "not implemented"); + return false; + } + + template + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, size_t ty0, const Primitive* prim, size_t ty, const TravPointQuery &tquery, size_t& lazy_node) { + return pointQuery(This,query,context,prim,ty,tquery,lazy_node); + } + }; + + class SubdivPatch1MBIntersector1 + { + public: + typedef SubdivPatch1 Primitive; + typedef GridSOAMBIntersector1::Precalculations Precalculations; + + static __forceinline bool processLazyNode(Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim_i, size_t& lazy_node) + { + Primitive* prim = (Primitive*) prim_i; + GridSOA* grid = nullptr; + grid = (GridSOA*) prim->root_ref.get(); + pre.itime = getTimeSegment(ray.time(), float(grid->time_steps-1), pre.ftime); + lazy_node = grid->root(pre.itime); + pre.grid = grid; + return false; + } + + /*! Intersect a ray with the primitive. */ + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) GridSOAMBIntersector1::intersect(pre,ray,context,prim,lazy_node); + else processLazyNode(pre,ray,context,prim,lazy_node); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, size_t ty0, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) { + intersect(This,pre,ray,context,prim,ty,tray,lazy_node); + } + + /*! 
Test if the ray is occluded by the primitive */ + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) return GridSOAMBIntersector1::occluded(pre,ray,context,prim,lazy_node); + else return processLazyNode(pre,ray,context,prim,lazy_node); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, size_t ty0, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) { + return occluded(This,pre,ray,context,prim,ty,tray,lazy_node); + } + + template + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t ty, const TravPointQuery &tquery, size_t& lazy_node) + { + // TODO: PointQuery implement + assert(false && "not implemented"); + return false; + } + + template + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, size_t ty0, const Primitive* prim, size_t ty, const TravPointQuery &tquery, size_t& lazy_node) { + return pointQuery(This,query,context,prim,ty,tquery,lazy_node); + } + }; + + template + struct SubdivPatch1IntersectorK + { + typedef GridSOA Primitive; + typedef SubdivPatch1PrecalculationsK::Precalculations> Precalculations; + + static __forceinline bool processLazyNode(Precalculations& pre, IntersectContext* context, const Primitive* prim, size_t& lazy_node) + { + lazy_node = prim->root(0); + pre.grid = (Primitive*)prim; + return false; + } + + template + static __forceinline void intersect(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRayK &tray, size_t& lazy_node) + { + if (likely(ty == 0)) GridSOAIntersectorK::intersect(valid,pre,ray,context,prim,lazy_node); + else processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline vbool occluded(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRayK &tray, size_t& lazy_node) + { + if (likely(ty == 0)) return GridSOAIntersectorK::occluded(valid,pre,ray,context,prim,lazy_node); + else return processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) GridSOAIntersectorK::intersect(pre,ray,k,context,prim,lazy_node); + else processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) return GridSOAIntersectorK::occluded(pre,ray,k,context,prim,lazy_node); + else return processLazyNode(pre,context,prim,lazy_node); + } + }; + + typedef SubdivPatch1IntersectorK<4> SubdivPatch1Intersector4; + typedef SubdivPatch1IntersectorK<8> SubdivPatch1Intersector8; + typedef SubdivPatch1IntersectorK<16> SubdivPatch1Intersector16; + + template + struct SubdivPatch1MBIntersectorK + { + typedef 
SubdivPatch1 Primitive; + //typedef GridSOAMBIntersectorK::Precalculations Precalculations; + typedef SubdivPatch1PrecalculationsK::Precalculations> Precalculations; + + static __forceinline bool processLazyNode(Precalculations& pre, IntersectContext* context, const Primitive* prim_i, size_t& lazy_node) + { + Primitive* prim = (Primitive*) prim_i; + GridSOA* grid = (GridSOA*) prim->root_ref.get(); + lazy_node = grid->troot; + pre.grid = grid; + return false; + } + + template + static __forceinline void intersect(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRayK &tray, size_t& lazy_node) + { + if (likely(ty == 0)) GridSOAMBIntersectorK::intersect(valid,pre,ray,context,prim,lazy_node); + else processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline vbool occluded(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t ty, const TravRayK &tray, size_t& lazy_node) + { + if (likely(ty == 0)) return GridSOAMBIntersectorK::occluded(valid,pre,ray,context,prim,lazy_node); + else return processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) GridSOAMBIntersectorK::intersect(pre,ray,k,context,prim,lazy_node); + else processLazyNode(pre,context,prim,lazy_node); + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t ty, const TravRay &tray, size_t& lazy_node) + { + if (likely(ty == 0)) return GridSOAMBIntersectorK::occluded(pre,ray,k,context,prim,lazy_node); + else return processLazyNode(pre,context,prim,lazy_node); + } + }; + + typedef SubdivPatch1MBIntersectorK<4> SubdivPatch1MBIntersector4; + typedef SubdivPatch1MBIntersectorK<8> SubdivPatch1MBIntersector8; + typedef SubdivPatch1MBIntersectorK<16> SubdivPatch1MBIntersector16; + } +} diff --git a/thirdparty/embree/kernels/geometry/subgrid.h b/thirdparty/embree/kernels/geometry/subgrid.h new file mode 100644 index 000000000000..39fa6fb0f048 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/subgrid.h @@ -0,0 +1,517 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/ray.h" +#include "../common/scene_grid_mesh.h" +#include "../bvh/bvh.h" + +namespace embree +{ + /* Stores M quads from an indexed face set */ + struct SubGrid + { + /* Virtual interface to query information about the quad type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored quads */ + static __forceinline size_t max_size() { return 1; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline SubGrid() { } + + /* Construction from 
vertices and IDs */ + __forceinline SubGrid(const unsigned int x, + const unsigned int y, + const unsigned int geomID, + const unsigned int primID) + : _x(x), _y(y), _geomID(geomID), _primID(primID) + { + } + + __forceinline bool invalid3x3X() const { return (unsigned int)_x & (1<<15); } + __forceinline bool invalid3x3Y() const { return (unsigned int)_y & (1<<15); } + + /* Gather the quads */ + __forceinline void gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const GridMesh* const mesh, + const GridMesh::Grid &g) const + { + /* first quad always valid */ + const size_t vtxID00 = g.startVtxID + x() + y() * g.lineVtxOffset; + const size_t vtxID01 = vtxID00 + 1; + const vfloat4 vtx00 = vfloat4::loadu(mesh->vertexPtr(vtxID00)); + const vfloat4 vtx01 = vfloat4::loadu(mesh->vertexPtr(vtxID01)); + const size_t vtxID10 = vtxID00 + g.lineVtxOffset; + const size_t vtxID11 = vtxID01 + g.lineVtxOffset; + const vfloat4 vtx10 = vfloat4::loadu(mesh->vertexPtr(vtxID10)); + const vfloat4 vtx11 = vfloat4::loadu(mesh->vertexPtr(vtxID11)); + + /* deltaX => vtx02, vtx12 */ + const size_t deltaX = invalid3x3X() ? 0 : 1; + const size_t vtxID02 = vtxID01 + deltaX; + const vfloat4 vtx02 = vfloat4::loadu(mesh->vertexPtr(vtxID02)); + const size_t vtxID12 = vtxID11 + deltaX; + const vfloat4 vtx12 = vfloat4::loadu(mesh->vertexPtr(vtxID12)); + + /* deltaY => vtx20, vtx21 */ + const size_t deltaY = invalid3x3Y() ? 0 : g.lineVtxOffset; + const size_t vtxID20 = vtxID10 + deltaY; + const size_t vtxID21 = vtxID11 + deltaY; + const vfloat4 vtx20 = vfloat4::loadu(mesh->vertexPtr(vtxID20)); + const vfloat4 vtx21 = vfloat4::loadu(mesh->vertexPtr(vtxID21)); + + /* deltaX/deltaY => vtx22 */ + const size_t vtxID22 = vtxID11 + deltaX + deltaY; + const vfloat4 vtx22 = vfloat4::loadu(mesh->vertexPtr(vtxID22)); + + transpose(vtx00,vtx01,vtx11,vtx10,p0.x,p0.y,p0.z); + transpose(vtx01,vtx02,vtx12,vtx11,p1.x,p1.y,p1.z); + transpose(vtx11,vtx12,vtx22,vtx21,p2.x,p2.y,p2.z); + transpose(vtx10,vtx11,vtx21,vtx20,p3.x,p3.y,p3.z); + } + + template + __forceinline vfloat4 getVertexMB(const GridMesh* const mesh, const size_t offset, const size_t itime, const float ftime) const + { + const T v0 = T::loadu(mesh->vertexPtr(offset,itime+0)); + const T v1 = T::loadu(mesh->vertexPtr(offset,itime+1)); + return lerp(v0,v1,ftime); + } + + /* Gather the quads */ + __forceinline void gatherMB(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const GridMesh* const mesh, + const GridMesh::Grid &g, + const size_t itime, + const float ftime) const + { + /* first quad always valid */ + const size_t vtxID00 = g.startVtxID + x() + y() * g.lineVtxOffset; + const size_t vtxID01 = vtxID00 + 1; + const vfloat4 vtx00 = getVertexMB(mesh,vtxID00,itime,ftime); + const vfloat4 vtx01 = getVertexMB(mesh,vtxID01,itime,ftime); + const size_t vtxID10 = vtxID00 + g.lineVtxOffset; + const size_t vtxID11 = vtxID01 + g.lineVtxOffset; + const vfloat4 vtx10 = getVertexMB(mesh,vtxID10,itime,ftime); + const vfloat4 vtx11 = getVertexMB(mesh,vtxID11,itime,ftime); + + /* deltaX => vtx02, vtx12 */ + const size_t deltaX = invalid3x3X() ? 0 : 1; + const size_t vtxID02 = vtxID01 + deltaX; + const vfloat4 vtx02 = getVertexMB(mesh,vtxID02,itime,ftime); + const size_t vtxID12 = vtxID11 + deltaX; + const vfloat4 vtx12 = getVertexMB(mesh,vtxID12,itime,ftime); + + /* deltaY => vtx20, vtx21 */ + const size_t deltaY = invalid3x3Y() ? 
0 : g.lineVtxOffset; + const size_t vtxID20 = vtxID10 + deltaY; + const size_t vtxID21 = vtxID11 + deltaY; + const vfloat4 vtx20 = getVertexMB(mesh,vtxID20,itime,ftime); + const vfloat4 vtx21 = getVertexMB(mesh,vtxID21,itime,ftime); + + /* deltaX/deltaY => vtx22 */ + const size_t vtxID22 = vtxID11 + deltaX + deltaY; + const vfloat4 vtx22 = getVertexMB(mesh,vtxID22,itime,ftime); + + transpose(vtx00,vtx01,vtx11,vtx10,p0.x,p0.y,p0.z); + transpose(vtx01,vtx02,vtx12,vtx11,p1.x,p1.y,p1.z); + transpose(vtx11,vtx12,vtx22,vtx21,p2.x,p2.y,p2.z); + transpose(vtx10,vtx11,vtx21,vtx20,p3.x,p3.y,p3.z); + } + + + + /* Gather the quads */ + __forceinline void gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const Scene *const scene) const + { + const GridMesh* const mesh = scene->get(geomID()); + const GridMesh::Grid &g = mesh->grid(primID()); + gather(p0,p1,p2,p3,mesh,g); + } + + /* Gather the quads in the motion blur case */ + __forceinline void gatherMB(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + Vec3vf4& p3, + const Scene *const scene, + const size_t itime, + const float ftime) const + { + const GridMesh* const mesh = scene->get(geomID()); + const GridMesh::Grid &g = mesh->grid(primID()); + gatherMB(p0,p1,p2,p3,mesh,g,itime,ftime); + } + + /* Gather the quads */ + __forceinline void gather(Vec3fa vtx[16], const Scene *const scene) const + { + const GridMesh* mesh = scene->get(geomID()); + const GridMesh::Grid &g = mesh->grid(primID()); + + /* first quad always valid */ + const size_t vtxID00 = g.startVtxID + x() + y() * g.lineVtxOffset; + const size_t vtxID01 = vtxID00 + 1; + const Vec3fa vtx00 = Vec3fa::loadu(mesh->vertexPtr(vtxID00)); + const Vec3fa vtx01 = Vec3fa::loadu(mesh->vertexPtr(vtxID01)); + const size_t vtxID10 = vtxID00 + g.lineVtxOffset; + const size_t vtxID11 = vtxID01 + g.lineVtxOffset; + const Vec3fa vtx10 = Vec3fa::loadu(mesh->vertexPtr(vtxID10)); + const Vec3fa vtx11 = Vec3fa::loadu(mesh->vertexPtr(vtxID11)); + + /* deltaX => vtx02, vtx12 */ + const size_t deltaX = invalid3x3X() ? 0 : 1; + const size_t vtxID02 = vtxID01 + deltaX; + const Vec3fa vtx02 = Vec3fa::loadu(mesh->vertexPtr(vtxID02)); + const size_t vtxID12 = vtxID11 + deltaX; + const Vec3fa vtx12 = Vec3fa::loadu(mesh->vertexPtr(vtxID12)); + + /* deltaY => vtx20, vtx21 */ + const size_t deltaY = invalid3x3Y() ? 
0 : g.lineVtxOffset; + const size_t vtxID20 = vtxID10 + deltaY; + const size_t vtxID21 = vtxID11 + deltaY; + const Vec3fa vtx20 = Vec3fa::loadu(mesh->vertexPtr(vtxID20)); + const Vec3fa vtx21 = Vec3fa::loadu(mesh->vertexPtr(vtxID21)); + + /* deltaX/deltaY => vtx22 */ + const size_t vtxID22 = vtxID11 + deltaX + deltaY; + const Vec3fa vtx22 = Vec3fa::loadu(mesh->vertexPtr(vtxID22)); + + vtx[ 0] = vtx00; vtx[ 1] = vtx01; vtx[ 2] = vtx11; vtx[ 3] = vtx10; + vtx[ 4] = vtx01; vtx[ 5] = vtx02; vtx[ 6] = vtx12; vtx[ 7] = vtx11; + vtx[ 8] = vtx10; vtx[ 9] = vtx11; vtx[10] = vtx21; vtx[11] = vtx20; + vtx[12] = vtx11; vtx[13] = vtx12; vtx[14] = vtx22; vtx[15] = vtx21; + } + + /* Gather the quads */ + __forceinline void gatherMB(vfloat4 vtx[16], const Scene *const scene, const size_t itime, const float ftime) const + { + const GridMesh* mesh = scene->get(geomID()); + const GridMesh::Grid &g = mesh->grid(primID()); + + /* first quad always valid */ + const size_t vtxID00 = g.startVtxID + x() + y() * g.lineVtxOffset; + const size_t vtxID01 = vtxID00 + 1; + const vfloat4 vtx00 = getVertexMB(mesh,vtxID00,itime,ftime); + const vfloat4 vtx01 = getVertexMB(mesh,vtxID01,itime,ftime); + const size_t vtxID10 = vtxID00 + g.lineVtxOffset; + const size_t vtxID11 = vtxID01 + g.lineVtxOffset; + const vfloat4 vtx10 = getVertexMB(mesh,vtxID10,itime,ftime); + const vfloat4 vtx11 = getVertexMB(mesh,vtxID11,itime,ftime); + + /* deltaX => vtx02, vtx12 */ + const size_t deltaX = invalid3x3X() ? 0 : 1; + const size_t vtxID02 = vtxID01 + deltaX; + const vfloat4 vtx02 = getVertexMB(mesh,vtxID02,itime,ftime); + const size_t vtxID12 = vtxID11 + deltaX; + const vfloat4 vtx12 = getVertexMB(mesh,vtxID12,itime,ftime); + + /* deltaY => vtx20, vtx21 */ + const size_t deltaY = invalid3x3Y() ? 0 : g.lineVtxOffset; + const size_t vtxID20 = vtxID10 + deltaY; + const size_t vtxID21 = vtxID11 + deltaY; + const vfloat4 vtx20 = getVertexMB(mesh,vtxID20,itime,ftime); + const vfloat4 vtx21 = getVertexMB(mesh,vtxID21,itime,ftime); + + /* deltaX/deltaY => vtx22 */ + const size_t vtxID22 = vtxID11 + deltaX + deltaY; + const vfloat4 vtx22 = getVertexMB(mesh,vtxID22,itime,ftime); + + vtx[ 0] = vtx00; vtx[ 1] = vtx01; vtx[ 2] = vtx11; vtx[ 3] = vtx10; + vtx[ 4] = vtx01; vtx[ 5] = vtx02; vtx[ 6] = vtx12; vtx[ 7] = vtx11; + vtx[ 8] = vtx10; vtx[ 9] = vtx11; vtx[10] = vtx21; vtx[11] = vtx20; + vtx[12] = vtx11; vtx[13] = vtx12; vtx[14] = vtx22; vtx[15] = vtx21; + } + + + /* Calculate the bounds of the subgrid */ + __forceinline const BBox3fa bounds(const Scene *const scene, const size_t itime=0) const + { + BBox3fa bounds = empty; + FATAL("not implemented yet"); + return bounds; + } + + /* Calculate the linear bounds of the primitive */ + __forceinline LBBox3fa linearBounds(const Scene* const scene, const size_t itime) + { + return LBBox3fa(bounds(scene,itime+0),bounds(scene,itime+1)); + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, size_t itime, size_t numTimeSteps) + { + LBBox3fa allBounds = empty; + FATAL("not implemented yet"); + return allBounds; + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, const BBox1f time_range) + { + LBBox3fa allBounds = empty; + FATAL("not implemented yet"); + return allBounds; + } + + + friend embree_ostream operator<<(embree_ostream cout, const SubGrid& sg) { + return cout << "SubGrid " << " ( x " << sg.x() << ", y = " << sg.y() << ", geomID = " << sg.geomID() << ", primID = " << sg.primID() << " )"; + } + + __forceinline unsigned int geomID() const { return _geomID; } + 
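+ // Note (editorial comment): x() and y() below mask off bit 15 of the packed _x/_y fields;
+ // that bit is the truncation flag tested by invalid3x3X()/invalid3x3Y() above, so only the
+ // low 15 bits carry the subgrid origin within the grid.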
__forceinline unsigned int primID() const { return _primID; } + __forceinline unsigned int x() const { return (unsigned int)_x & 0x7fff; } + __forceinline unsigned int y() const { return (unsigned int)_y & 0x7fff; } + + private: + unsigned short _x; + unsigned short _y; + unsigned int _geomID; // geometry ID of mesh + unsigned int _primID; // primitive ID of primitive inside mesh + }; + + struct SubGridID { + unsigned short x; + unsigned short y; + unsigned int primID; + + __forceinline SubGridID() {} + __forceinline SubGridID(const unsigned int x, const unsigned int y, const unsigned int primID) : + x(x), y(y), primID(primID) {} + }; + + /* QuantizedBaseNode as large subgrid leaf */ + template + struct SubGridQBVHN + { + /* Virtual interface to query information about the quad type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + __forceinline size_t size() const + { + for (size_t i=0;i::AABBNode node; + node.clear(); + for (size_t i=0;i::QuantizedBaseNode qnode; + + unsigned int _geomID; // geometry ID of mesh + + + friend embree_ostream operator<<(embree_ostream cout, const SubGridQBVHN& sg) { + cout << "SubGridQBVHN " << embree_endl; + for (size_t i=0;i + typename SubGridQBVHN::Type SubGridQBVHN::type; + + typedef SubGridQBVHN<4> SubGridQBVH4; + typedef SubGridQBVHN<8> SubGridQBVH8; + + + /* QuantizedBaseNode as large subgrid leaf */ + template + struct SubGridMBQBVHN + { + /* Virtual interface to query information about the quad type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + __forceinline size_t size() const + { + for (size_t i=0;i::AABBNode node0,node1; + node0.clear(); + node1.clear(); + for (size_t i=0;i + __forceinline vfloat adjustTime(const vfloat &t) const { return time_scale * (t-time_offset); } + + public: + SubGridID subgridIDs[N]; + + typename BVHN::QuantizedBaseNodeMB qnode; + + float time_offset; + float time_scale; + unsigned int _geomID; // geometry ID of mesh + + + friend embree_ostream operator<<(embree_ostream cout, const SubGridMBQBVHN& sg) { + cout << "SubGridMBQBVHN " << embree_endl; + for (size_t i=0;i + struct SubGridIntersector1Moeller + { + typedef SubGridQBVHN Primitive; + typedef SubGridQuadMIntersector1MoellerTrumbore<4,filter> Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(normal.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + pre.intersect(ray,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(shadow.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + return pre.occluded(ray,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const SubGrid& 
subgrid) + { + STAT3(point_query.trav_prims,1,1,1); + AccelSet* accel = (AccelSet*)context->scene->get(subgrid.geomID()); + assert(accel); + context->geomID = subgrid.geomID(); + context->primID = subgrid.primID(); + return accel->pointQuery(query, context); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); +#if defined(__AVX__) + STAT3(normal.trav_hit_boxes[popcnt(mask)],1,1,1); +#endif + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (unlikely(dist[ID] > ray.tfar)) continue; + intersect(pre,ray,context,prim[i].subgrid(ID)); + } + } + } + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (occluded(pre,ray,context,prim[i].subgrid(ID))) + return true; + } + } + return false; + } + + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t num, const TravPointQuery &tquery, size_t& lazy_node) + { + bool changed = false; + for (size_t i=0;i dist; + size_t mask; + if (likely(context->query_type == POINT_QUERY_TYPE_SPHERE)) { + mask = BVHNQuantizedBaseNodePointQuerySphere1::pointQuery(&prim[i].qnode,tquery,dist); + } else { + mask = BVHNQuantizedBaseNodePointQueryAABB1::pointQuery(&prim[i].qnode,tquery,dist); + } + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + changed |= pointQuery(query, context, prim[i].subgrid(ID)); + } + } + return changed; + } + }; + + template + struct SubGridIntersector1Pluecker + { + typedef SubGridQBVHN Primitive; + typedef SubGridQuadMIntersector1Pluecker<4,filter> Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(normal.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + pre.intersect(ray,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(shadow.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + return pre.occluded(ray,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const SubGrid& subgrid) + { + STAT3(point_query.trav_prims,1,1,1); + AccelSet* accel = (AccelSet*)context->scene->get(subgrid.geomID()); + context->geomID = subgrid.geomID(); + context->primID = subgrid.primID(); + 
return accel->pointQuery(query, context); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); +#if defined(__AVX__) + STAT3(normal.trav_hit_boxes[popcnt(mask)],1,1,1); +#endif + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (unlikely(dist[ID] > ray.tfar)) continue; + intersect(pre,ray,context,prim[i].subgrid(ID)); + } + } + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (occluded(pre,ray,context,prim[i].subgrid(ID))) + return true; + } + } + return false; + } + + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t num, const TravPointQuery &tquery, size_t& lazy_node) + { + bool changed = false; + for (size_t i=0;i dist; + size_t mask; + if (likely(context->query_type == POINT_QUERY_TYPE_SPHERE)) { + mask = BVHNQuantizedBaseNodePointQuerySphere1::pointQuery(&prim[i].qnode,tquery,dist); + } else { + mask = BVHNQuantizedBaseNodePointQueryAABB1::pointQuery(&prim[i].qnode,tquery,dist); + } +#if defined(__AVX__) + STAT3(point_query.trav_hit_boxes[popcnt(mask)],1,1,1); +#endif + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + changed |= pointQuery(query, context, prim[i].subgrid(ID)); + } + } + return changed; + } + }; + + template + struct SubGridIntersectorKMoeller + { + typedef SubGridQBVHN Primitive; + typedef SubGridQuadMIntersectorKMoellerTrumbore<4,K,filter> Precalculations; + + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const SubGrid& subgrid) + { + Vec3fa vtx[16]; + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + subgrid.gather(vtx,context->scene); + for (unsigned int i=0; i<4; i++) + { + const Vec3vf p0 = vtx[i*4+0]; + const Vec3vf p1 = vtx[i*4+1]; + const Vec3vf p2 = vtx[i*4+2]; + const Vec3vf p3 = vtx[i*4+3]; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + pre.intersectK(valid_i,ray,p0,p1,p2,p3,g,subgrid,i,IntersectKEpilogM<4,K,filter>(ray,context,subgrid.geomID(),subgrid.primID(),i)); + } + } + + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const SubGrid& subgrid) + { + vbool valid0 = valid_i; + Vec3fa vtx[16]; + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + subgrid.gather(vtx,context->scene); + for (unsigned int i=0; i<4; i++) + { + const Vec3vf p0 = vtx[i*4+0]; + const Vec3vf p1 = vtx[i*4+1]; + const Vec3vf p2 = vtx[i*4+2]; + const Vec3vf p3 = vtx[i*4+3]; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + if 
(pre.intersectK(valid0,ray,p0,p1,p2,p3,g,subgrid,i,OccludedKEpilogM<4,K,filter>(valid0,ray,context,subgrid.geomID(),subgrid.primID(),i))) + break; + } + return !valid0; + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(normal.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + pre.intersect1(ray,k,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(shadow.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,g,subgrid); + } + + template + static __forceinline void intersect(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersectorK isecK; + for (size_t j=0;j dist; + while(m_valid) + { + const size_t i = bscf(m_valid); + if (none(valid & isecK.intersectK(&prim[j].qnode,i,tray,dist))) continue; + intersect(valid,pre,ray,context,prim[j].subgrid(i)); + } + } + } + + template + static __forceinline vbool occluded(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersectorK isecK; + vbool valid0 = valid; + for (size_t j=0;j dist; + while(m_valid) + { + const size_t i = bscf(m_valid); + if (none(valid0 & isecK.intersectK(&prim[j].qnode,i,tray,dist))) continue; + valid0 &= !occluded(valid0,pre,ray,context,prim[j].subgrid(i)); + if (none(valid0)) break; + } + } + return !valid0; + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (unlikely(dist[ID] > ray.tfar[k])) continue; + intersect(pre,ray,k,context,prim[i].subgrid(ID)); + } + } + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (occluded(pre,ray,k,context,prim[i].subgrid(ID))) + return true; + } + } + return false; + } + }; + + + template + struct SubGridIntersectorKPluecker + { + typedef SubGridQBVHN Primitive; + typedef SubGridQuadMIntersectorKPluecker<4,K,filter> Precalculations; + + static __forceinline void intersect(const vbool& valid_i, 
Precalculations& pre, RayHitK& ray, IntersectContext* context, const SubGrid& subgrid) + { + Vec3fa vtx[16]; + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + subgrid.gather(vtx,context->scene); + for (unsigned int i=0; i<4; i++) + { + const Vec3vf p0 = vtx[i*4+0]; + const Vec3vf p1 = vtx[i*4+1]; + const Vec3vf p2 = vtx[i*4+2]; + const Vec3vf p3 = vtx[i*4+3]; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + pre.intersectK(valid_i,ray,p0,p1,p2,p3,g,subgrid,i,IntersectKEpilogM<4,K,filter>(ray,context,subgrid.geomID(),subgrid.primID(),i)); + } + } + + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const SubGrid& subgrid) + { + vbool valid0 = valid_i; + Vec3fa vtx[16]; + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + subgrid.gather(vtx,context->scene); + for (unsigned int i=0; i<4; i++) + { + const Vec3vf p0 = vtx[i*4+0]; + const Vec3vf p1 = vtx[i*4+1]; + const Vec3vf p2 = vtx[i*4+2]; + const Vec3vf p3 = vtx[i*4+3]; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + if (pre.intersectK(valid0,ray,p0,p1,p2,p3,g,subgrid,i,OccludedKEpilogM<4,K,filter>(valid0,ray,context,subgrid.geomID(),subgrid.primID(),i))) + break; + } + return !valid0; + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(normal.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + pre.intersect1(ray,k,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(shadow.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + Vec3vf4 v0,v1,v2,v3; subgrid.gather(v0,v1,v2,v3,context->scene); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,g,subgrid); + } + + template + static __forceinline void intersect(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersectorK isecK; + for (size_t j=0;j dist; + while(m_valid) + { + const size_t i = bscf(m_valid); + if (none(valid & isecK.intersectK(&prim[j].qnode,i,tray,dist))) continue; + intersect(valid,pre,ray,context,prim[j].subgrid(i)); + } + } + } + + template + static __forceinline vbool occluded(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersectorK isecK; + vbool valid0 = valid; + for (size_t j=0;j dist; + while(m_valid) + { + const size_t i = bscf(m_valid); + if (none(valid0 & isecK.intersectK(&prim[j].qnode,i,tray,dist))) continue; + valid0 &= !occluded(valid0,pre,ray,context,prim[j].subgrid(i)); + if (none(valid0)) break; + } + } + return !valid0; + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& 
lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (unlikely(dist[ID] > ray.tfar[k])) continue; + intersect(pre,ray,k,context,prim[i].subgrid(ID)); + } + } + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + assert(((size_t)1 << ID) & movemask(prim[i].qnode.validMask())); + + if (occluded(pre,ray,k,context,prim[i].subgrid(ID))) + return true; + } + } + return false; + } + }; + + + + } +} diff --git a/thirdparty/embree/kernels/geometry/subgrid_intersector_moeller.h b/thirdparty/embree/kernels/geometry/subgrid_intersector_moeller.h new file mode 100644 index 000000000000..f65b4abf61c9 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/subgrid_intersector_moeller.h @@ -0,0 +1,493 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "subgrid.h" +#include "quad_intersector_moeller.h" + +namespace embree +{ + namespace isa + { + + /* ----------------------------- */ + /* -- single ray intersectors -- */ + /* ----------------------------- */ + + template + __forceinline void interpolateUV(MoellerTrumboreHitM &hit,const GridMesh::Grid &g, const SubGrid& subgrid) + { + /* correct U,V interpolation across the entire grid */ + const vint sx((int)subgrid.x()); + const vint sy((int)subgrid.y()); + const vint sxM(sx + vint(0,1,1,0)); + const vint syM(sy + vint(0,0,1,1)); + const float inv_resX = rcp((float)((int)g.resX-1)); + const float inv_resY = rcp((float)((int)g.resY-1)); + hit.U = (hit.U + (vfloat)sxM * hit.absDen) * inv_resX; + hit.V = (hit.V + (vfloat)syM * hit.absDen) * inv_resY; + } + + template + struct SubGridQuadMIntersector1MoellerTrumbore; + + template + struct SubGridQuadMIntersector1MoellerTrumbore + { + __forceinline SubGridQuadMIntersector1MoellerTrumbore() {} + + __forceinline SubGridQuadMIntersector1MoellerTrumbore(const Ray& ray, const void* ptr) {} + + __forceinline void intersect(RayHit& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + MoellerTrumboreHitM hit; + MoellerTrumboreIntersector1 intersector(ray,nullptr); + Intersect1EpilogMU epilog(ray,context,subgrid.geomID(),subgrid.primID()); + + /* intersect first triangle */ + if (intersector.intersect(ray,v0,v1,v3,hit)) + { + interpolateUV(hit,g,subgrid); + epilog(hit.valid,hit); + } + + /* intersect second triangle */ + if (intersector.intersect(ray,v2,v3,v1,hit)) + { + hit.U = hit.absDen - hit.U; + hit.V = hit.absDen - hit.V; + interpolateUV(hit,g,subgrid); + epilog(hit.valid,hit); + } + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + MoellerTrumboreHitM hit; + MoellerTrumboreIntersector1 intersector(ray,nullptr); + Occluded1EpilogMU epilog(ray,context,subgrid.geomID(),subgrid.primID()); + + /* intersect 
first triangle */ + if (intersector.intersect(ray,v0,v1,v3,hit)) + { + interpolateUV(hit,g,subgrid); + if (epilog(hit.valid,hit)) + return true; + } + + /* intersect second triangle */ + if (intersector.intersect(ray,v2,v3,v1,hit)) + { + hit.U = hit.absDen - hit.U; + hit.V = hit.absDen - hit.V; + interpolateUV(hit,g,subgrid); + if (epilog(hit.valid,hit)) + return true; + } + return false; + } + }; + +#if defined (__AVX__) + + /*! Intersects 4 quads with 1 ray using AVX */ + template + struct SubGridQuadMIntersector1MoellerTrumbore<4,filter> + { + __forceinline SubGridQuadMIntersector1MoellerTrumbore() {} + + __forceinline SubGridQuadMIntersector1MoellerTrumbore(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const GridMesh::Grid &g, const SubGrid& subgrid, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + MoellerTrumboreHitM<8> hit; + MoellerTrumboreIntersector1<8> intersector(ray,nullptr); + const vbool8 flags(0,0,0,0,1,1,1,1); + if (unlikely(intersector.intersect(ray,vtx0,vtx1,vtx2,hit))) + { + vfloat8 U = hit.U, V = hit.V, absDen = hit.absDen; + +#if !defined(EMBREE_BACKFACE_CULLING) + hit.U = select(flags,absDen-V,U); + hit.V = select(flags,absDen-U,V); + hit.vNg *= select(flags,vfloat8(-1.0f),vfloat8(1.0f)); +#else + hit.U = select(flags,absDen-U,U); + hit.V = select(flags,absDen-V,V); +#endif + /* correct U,V interpolation across the entire grid */ + const vint8 sx((int)subgrid.x()); + const vint8 sy((int)subgrid.y()); + const vint8 sx8(sx + vint8(0,1,1,0,0,1,1,0)); + const vint8 sy8(sy + vint8(0,0,1,1,0,0,1,1)); + const float inv_resX = rcp((float)((int)g.resX-1)); + const float inv_resY = rcp((float)((int)g.resY-1)); + hit.U = (hit.U + (vfloat8)sx8 * absDen) * inv_resX; + hit.V = (hit.V + (vfloat8)sy8 * absDen) * inv_resY; + + if (unlikely(epilog(hit.valid,hit))) + return true; + } + return false; + } + + __forceinline bool intersect(RayHit& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + return intersect(ray,v0,v1,v2,v3,g,subgrid,Intersect1EpilogMU<8,filter>(ray,context,subgrid.geomID(),subgrid.primID())); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + return intersect(ray,v0,v1,v2,v3,g,subgrid,Occluded1EpilogMU<8,filter>(ray,context,subgrid.geomID(),subgrid.primID())); + } + }; + +#endif + + // ============================================================================================================================ + // ============================================================================================================================ + // ============================================================================================================================ + + + /* ----------------------------- */ + /* -- ray packet intersectors -- */ + /* 
----------------------------- */ + + template + struct SubGridQuadHitK + { + __forceinline SubGridQuadHitK(const vfloat& U, + const vfloat& V, + const vfloat& T, + const vfloat& absDen, + const Vec3vf& Ng, + const vbool& flags, + const GridMesh::Grid &g, + const SubGrid& subgrid, + const unsigned int i) + : U(U), V(V), T(T), absDen(absDen), flags(flags), tri_Ng(Ng), g(g), subgrid(subgrid), i(i) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vfloat rcpAbsDen = rcp(absDen); + const vfloat t = T * rcpAbsDen; + const vfloat u0 = min(U * rcpAbsDen,1.0f); + const vfloat v0 = min(V * rcpAbsDen,1.0f); + const vfloat u1 = vfloat(1.0f) - u0; + const vfloat v1 = vfloat(1.0f) - v0; + const vfloat uu = select(flags,u1,u0); + const vfloat vv = select(flags,v1,v0); + const unsigned int sx = subgrid.x() + (unsigned int)(i % 2); + const unsigned int sy = subgrid.y() + (unsigned int)(i >>1); + const float inv_resX = rcp((float)(int)(g.resX-1)); + const float inv_resY = rcp((float)(int)(g.resY-1)); + const vfloat u = (uu + (float)(int)sx) * inv_resX; + const vfloat v = (vv + (float)(int)sy) * inv_resY; + const Vec3vf Ng(tri_Ng.x,tri_Ng.y,tri_Ng.z); + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat T; + const vfloat absDen; + const vbool flags; + const Vec3vf tri_Ng; + + const GridMesh::Grid &g; + const SubGrid& subgrid; + const size_t i; + }; + + template + struct SubGridQuadMIntersectorKMoellerTrumboreBase + { + __forceinline SubGridQuadMIntersectorKMoellerTrumboreBase(const vbool& valid, const RayK& ray) {} + + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + const vbool& flags, + const GridMesh::Grid &g, + const SubGrid &subgrid, + const unsigned int i, + const Epilog& epilog) const + { + /* calculate denominator */ + vbool valid = valid0; + const Vec3vf C = tri_v0 - ray.org; + const Vec3vf R = cross(C,ray.dir); + const vfloat den = dot(tri_Ng,ray.dir); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* test against edge p2 p0 */ + const vfloat U = dot(R,tri_e2) ^ sgnDen; + valid &= U >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p0 p1 */ + const vfloat V = dot(R,tri_e1) ^ sgnDen; + valid &= V >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p1 p2 */ + const vfloat W = absDen-U-V; + valid &= W >= 0.0f; + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(tri_Ng,C) ^ sgnDen; + valid &= (absDen*ray.tnear() < T) & (T <= absDen*ray.tfar); + if (unlikely(none(valid))) return false; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + valid &= den < vfloat(zero); + if (unlikely(none(valid))) return false; +#else + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; +#endif + + /* calculate hit information */ + SubGridQuadHitK hit(U,V,T,absDen,tri_Ng,flags,g,subgrid,i); + return epilog(valid,hit); + } + + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const vbool& flags, + const GridMesh::Grid &g, + const SubGrid &subgrid, + const unsigned int i, + const Epilog& epilog) const + { + const Vec3vf e1 = tri_v0-tri_v1; + const Vec3vf e2 = tri_v2-tri_v0; + const Vec3vf Ng = cross(e2,e1); + return 
intersectK(valid0,ray,tri_v0,e1,e2,Ng,flags,g,subgrid,i,epilog); + } + + template + __forceinline bool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Vec3vf& v3, + const GridMesh::Grid &g, + const SubGrid &subgrid, + const unsigned int i, + const Epilog& epilog) const + { + intersectK(valid0,ray,v0,v1,v3,vbool(false),g,subgrid,i,epilog); + if (none(valid0)) return true; + intersectK(valid0,ray,v2,v3,v1,vbool(true ),g,subgrid,i,epilog); + return none(valid0); + } + + static __forceinline bool intersect1(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + MoellerTrumboreHitM &hit) + { + /* calculate denominator */ + const Vec3vf O = broadcast>(ray.org,k); + const Vec3vf D = broadcast>(ray.dir,k); + const Vec3vf C = Vec3vf(tri_v0) - O; + const Vec3vf R = cross(C,D); + const vfloat den = dot(Vec3vf(tri_Ng),D); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* perform edge tests */ + const vfloat U = dot(R,Vec3vf(tri_e2)) ^ sgnDen; + const vfloat V = dot(R,Vec3vf(tri_e1)) ^ sgnDen; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = (den < vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#else + vbool valid = (den != vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#endif + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(Vec3vf(tri_Ng),C) ^ sgnDen; + valid &= (absDen*vfloat(ray.tnear()[k]) < T) & (T <= absDen*vfloat(ray.tfar[k])); + if (likely(none(valid))) return false; + + /* calculate hit information */ + new (&hit) MoellerTrumboreHitM(valid,U,V,T,absDen,tri_Ng); + return true; + } + + static __forceinline bool intersect1(RayK& ray, + size_t k, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + MoellerTrumboreHitM &hit) + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + const Vec3vf Ng = cross(e2,e1); + return intersect1(ray,k,v0,e1,e2,Ng,hit); + } + + }; + + template + struct SubGridQuadMIntersectorKMoellerTrumbore : public SubGridQuadMIntersectorKMoellerTrumboreBase + { + __forceinline SubGridQuadMIntersectorKMoellerTrumbore(const vbool& valid, const RayK& ray) + : SubGridQuadMIntersectorKMoellerTrumboreBase(valid,ray) {} + + __forceinline void intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, const GridMesh::Grid &g, const SubGrid &subgrid) const + { + Intersect1KEpilogMU epilog(ray,k,context,subgrid.geomID(),subgrid.primID()); + + MoellerTrumboreHitM<4> hit; + if (SubGridQuadMIntersectorKMoellerTrumboreBase<4,K,filter>::intersect1(ray,k,v0,v1,v3,hit)) + { + interpolateUV(hit,g,subgrid); + epilog(hit.valid,hit); + } + + if (SubGridQuadMIntersectorKMoellerTrumboreBase<4,K,filter>::intersect1(ray,k,v2,v3,v1,hit)) + { + hit.U = hit.absDen - hit.U; + hit.V = hit.absDen - hit.V; + interpolateUV(hit,g,subgrid); + epilog(hit.valid,hit); + } + + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, const GridMesh::Grid &g, const SubGrid &subgrid) const + { + Occluded1KEpilogMU epilog(ray,k,context,subgrid.geomID(),subgrid.primID()); + + MoellerTrumboreHitM<4> hit; + if (SubGridQuadMIntersectorKMoellerTrumboreBase<4,K,filter>::intersect1(ray,k,v0,v1,v3,hit)) + { + interpolateUV(hit,g,subgrid); + if (epilog(hit.valid,hit)) return 
true; + } + + if (SubGridQuadMIntersectorKMoellerTrumboreBase<4,K,filter>::intersect1(ray,k,v2,v3,v1,hit)) + { + hit.U = hit.absDen - hit.U; + hit.V = hit.absDen - hit.V; + interpolateUV(hit,g,subgrid); + if (epilog(hit.valid,hit)) return true; + } + return false; + } + }; + + +#if defined (__AVX__) + + /*! Intersects 4 quads with 1 ray using AVX */ + template + struct SubGridQuadMIntersectorKMoellerTrumbore<4,K,filter> : public SubGridQuadMIntersectorKMoellerTrumboreBase<4,K,filter> + { + __forceinline SubGridQuadMIntersectorKMoellerTrumbore(const vbool& valid, const RayK& ray) + : SubGridQuadMIntersectorKMoellerTrumboreBase<4,K,filter>(valid,ray) {} + + template + __forceinline bool intersect1(RayK& ray, size_t k,const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const GridMesh::Grid &g, const SubGrid &subgrid, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + const vbool8 flags(0,0,0,0,1,1,1,1); + + MoellerTrumboreHitM<8> hit; + if (SubGridQuadMIntersectorKMoellerTrumboreBase<8,K,filter>::intersect1(ray,k,vtx0,vtx1,vtx2,hit)) + { + vfloat8 U = hit.U, V = hit.V, absDen = hit.absDen; +#if !defined(EMBREE_BACKFACE_CULLING) + hit.U = select(flags,absDen-V,U); + hit.V = select(flags,absDen-U,V); + hit.vNg *= select(flags,vfloat8(-1.0f),vfloat8(1.0f)); +#else + hit.U = select(flags,absDen-U,U); + hit.V = select(flags,absDen-V,V); +#endif + + /* correct U,V interpolation across the entire grid */ + const vint8 sx((int)subgrid.x()); + const vint8 sy((int)subgrid.y()); + const vint8 sx8(sx + vint8(0,1,1,0,0,1,1,0)); + const vint8 sy8(sy + vint8(0,0,1,1,0,0,1,1)); + const float inv_resX = rcp((float)((int)g.resX-1)); + const float inv_resY = rcp((float)((int)g.resY-1)); + hit.U = (hit.U + (vfloat8)sx8 * absDen) * inv_resX; + hit.V = (hit.V + (vfloat8)sy8 * absDen) * inv_resY; + if (unlikely(epilog(hit.valid,hit))) + return true; + + } + return false; + } + + __forceinline bool intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const GridMesh::Grid &g, const SubGrid &subgrid) const + { + return intersect1(ray,k,v0,v1,v2,v3,g,subgrid,Intersect1KEpilogMU<8,K,filter>(ray,k,context,subgrid.geomID(),subgrid.primID())); + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const GridMesh::Grid &g, const SubGrid &subgrid) const + { + return intersect1(ray,k,v0,v1,v2,v3,g,subgrid,Occluded1KEpilogMU<8,K,filter>(ray,k,context,subgrid.geomID(),subgrid.primID())); + } + }; + +#endif + + + + } +} diff --git a/thirdparty/embree/kernels/geometry/subgrid_intersector_pluecker.h b/thirdparty/embree/kernels/geometry/subgrid_intersector_pluecker.h new file mode 100644 index 000000000000..1cd88aa799e7 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/subgrid_intersector_pluecker.h @@ -0,0 +1,508 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "subgrid.h" +#include "quad_intersector_moeller.h" +#include 
"quad_intersector_pluecker.h" + +namespace embree +{ + namespace isa + { + + template + struct SubGridQuadHitPlueckerM + { + __forceinline SubGridQuadHitPlueckerM() {} + + __forceinline SubGridQuadHitPlueckerM(const vbool& valid, + const vfloat& U, + const vfloat& V, + const vfloat& UVW, + const vfloat& t, + const Vec3vf& Ng, + const vbool& flags) : valid(valid), vt(t) + { + const vbool invalid = abs(UVW) < min_rcp_input; + const vfloat rcpUVW = select(invalid,vfloat(0.0f),rcp(UVW)); + const vfloat u = min(U * rcpUVW,1.0f); + const vfloat v = min(V * rcpUVW,1.0f); + const vfloat u1 = vfloat(1.0f) - u; + const vfloat v1 = vfloat(1.0f) - v; +#if !defined(__AVX__) || defined(EMBREE_BACKFACE_CULLING) + vu = select(flags,u1,u); + vv = select(flags,v1,v); + vNg = Vec3vf(Ng.x,Ng.y,Ng.z); +#else + const vfloat flip = select(flags,vfloat(-1.0f),vfloat(1.0f)); + vv = select(flags,u1,v); + vu = select(flags,v1,u); + vNg = Vec3vf(flip*Ng.x,flip*Ng.y,flip*Ng.z); +#endif + } + + __forceinline void finalize() + { + } + + __forceinline Vec2f uv(const size_t i) + { + const float u = vu[i]; + const float v = vv[i]; + return Vec2f(u,v); + } + + __forceinline float t(const size_t i) { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + public: + vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + __forceinline void interpolateUV(SubGridQuadHitPlueckerM &hit,const GridMesh::Grid &g, const SubGrid& subgrid, const vint &stepX, const vint &stepY) + { + /* correct U,V interpolation across the entire grid */ + const vint sx((int)subgrid.x()); + const vint sy((int)subgrid.y()); + const vint sxM(sx + stepX); + const vint syM(sy + stepY); + const float inv_resX = rcp((float)((int)g.resX-1)); + const float inv_resY = rcp((float)((int)g.resY-1)); + hit.vu = (hit.vu + vfloat(sxM)) * inv_resX; + hit.vv = (hit.vv + vfloat(syM)) * inv_resY; + } + + template + __forceinline static bool intersectPluecker(Ray& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const vbool& flags, + SubGridQuadHitPlueckerM &hit) + { + /* calculate vertices relative to ray origin */ + const Vec3vf O = Vec3vf((Vec3fa)ray.org); + const Vec3vf D = Vec3vf((Vec3fa)ray.dir); + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(cross(e0,v2+v0),D); + const vfloat V = dot(cross(e1,v0+v1),D); + const vfloat W = dot(cross(e2,v1+v2),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = max(U,V,W) <= eps; +#else + vbool valid = (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Ng,D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Ng)); + const vfloat t = rcp(den)*T; + valid &= vfloat(ray.tnear()) <= t & t <= vfloat(ray.tfar); + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + new (&hit) SubGridQuadHitPlueckerM(valid,U,V,UVW,t,Ng,flags); + return true; + } + + template + struct SubGridQuadMIntersector1Pluecker; + + template + struct SubGridQuadMIntersector1Pluecker + { + __forceinline 
SubGridQuadMIntersector1Pluecker() {} + + __forceinline SubGridQuadMIntersector1Pluecker(const Ray& ray, const void* ptr) {} + + __forceinline void intersect(RayHit& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + SubGridQuadHitPlueckerM hit; + Intersect1EpilogMU epilog(ray,context,subgrid.geomID(),subgrid.primID()); + + /* intersect first triangle */ + if (intersectPluecker(ray,v0,v1,v3,vbool(false),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + epilog(hit.valid,hit); + } + + /* intersect second triangle */ + if (intersectPluecker(ray,v2,v3,v1,vbool(true),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + epilog(hit.valid,hit); + } + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + SubGridQuadHitPlueckerM hit; + Occluded1EpilogMU epilog(ray,context,subgrid.geomID(),subgrid.primID()); + + /* intersect first triangle */ + if (intersectPluecker(ray,v0,v1,v3,vbool(false),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + if (epilog(hit.valid,hit)) + return true; + } + + /* intersect second triangle */ + if (intersectPluecker(ray,v2,v3,v1,vbool(true),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + if (epilog(hit.valid,hit)) + return true; + } + + return false; + } + }; + +#if defined (__AVX__) + + /*! Intersects 4 quads with 1 ray using AVX */ + template + struct SubGridQuadMIntersector1Pluecker<4,filter> + { + __forceinline SubGridQuadMIntersector1Pluecker() {} + + __forceinline SubGridQuadMIntersector1Pluecker(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, const GridMesh::Grid &g, const SubGrid& subgrid, const Epilog& epilog) const + { + const Vec3vf8 vtx0(vfloat8(v0.x,v2.x),vfloat8(v0.y,v2.y),vfloat8(v0.z,v2.z)); +#if !defined(EMBREE_BACKFACE_CULLING) + const Vec3vf8 vtx1(vfloat8(v1.x),vfloat8(v1.y),vfloat8(v1.z)); + const Vec3vf8 vtx2(vfloat8(v3.x),vfloat8(v3.y),vfloat8(v3.z)); +#else + const Vec3vf8 vtx1(vfloat8(v1.x,v3.x),vfloat8(v1.y,v3.y),vfloat8(v1.z,v3.z)); + const Vec3vf8 vtx2(vfloat8(v3.x,v1.x),vfloat8(v3.y,v1.y),vfloat8(v3.z,v1.z)); +#endif + SubGridQuadHitPlueckerM<8> hit; + const vbool8 flags(0,0,0,0,1,1,1,1); + if (unlikely(intersectPluecker(ray,vtx0,vtx1,vtx2,flags,hit))) + { + /* correct U,V interpolation across the entire grid */ + interpolateUV<8>(hit,g,subgrid,vint<8>(0,1,1,0,0,1,1,0),vint<8>(0,0,1,1,0,0,1,1)); + if (unlikely(epilog(hit.valid,hit))) + return true; + } + return false; + } + + __forceinline bool intersect(RayHit& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + return intersect(ray,v0,v1,v2,v3,g,subgrid,Intersect1EpilogMU<8,filter>(ray,context,subgrid.geomID(),subgrid.primID())); + } + + __forceinline bool occluded(Ray& ray, IntersectContext* context, + const Vec3vf4& v0, const Vec3vf4& v1, const Vec3vf4& v2, const Vec3vf4& v3, + const GridMesh::Grid &g, const SubGrid& subgrid) const + { + return intersect(ray,v0,v1,v2,v3,g,subgrid,Occluded1EpilogMU<8,filter>(ray,context,subgrid.geomID(),subgrid.primID())); + } + }; + +#endif + + + /* 
----------------------------- */ + /* -- ray packet intersectors -- */ + /* ----------------------------- */ + + template + struct SubGridQuadHitPlueckerK + { + __forceinline SubGridQuadHitPlueckerK(const vfloat& U, + const vfloat& V, + const vfloat& UVW, + const vfloat& t, + const Vec3vf& Ng, + const vbool& flags, + const GridMesh::Grid &g, + const SubGrid& subgrid, + const unsigned int i) + : U(U), V(V), UVW(UVW), t(t), flags(flags), tri_Ng(Ng), g(g), subgrid(subgrid), i(i) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vbool invalid = abs(UVW) < min_rcp_input; + const vfloat rcpUVW = select(invalid,vfloat(0.0f),rcp(UVW)); + const vfloat u0 = min(U * rcpUVW,1.0f); + const vfloat v0 = min(V * rcpUVW,1.0f); + const vfloat u1 = vfloat(1.0f) - u0; + const vfloat v1 = vfloat(1.0f) - v0; + const vfloat uu = select(flags,u1,u0); + const vfloat vv = select(flags,v1,v0); + const unsigned int sx = subgrid.x() + (unsigned int)(i % 2); + const unsigned int sy = subgrid.y() + (unsigned int)(i >>1); + const float inv_resX = rcp((float)(int)(g.resX-1)); + const float inv_resY = rcp((float)(int)(g.resY-1)); + const vfloat u = (uu + (float)(int)sx) * inv_resX; + const vfloat v = (vv + (float)(int)sy) * inv_resY; + const Vec3vf Ng(tri_Ng.x,tri_Ng.y,tri_Ng.z); + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat UVW; + const vfloat t; + const vfloat absDen; + const vbool flags; + const Vec3vf tri_Ng; + + const GridMesh::Grid &g; + const SubGrid& subgrid; + const size_t i; + }; + + + template + struct SubGridQuadMIntersectorKPlueckerBase + { + __forceinline SubGridQuadMIntersectorKPlueckerBase(const vbool& valid, const RayK& ray) {} + + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const Vec3vf& tri_Ng, + const vbool& flags, + const GridMesh::Grid &g, + const SubGrid &subgrid, + const unsigned int i, + const Epilog& epilog) const + { + /* calculate denominator */ + /* calculate vertices relative to ray origin */ + vbool valid = valid0; + const Vec3vf O = ray.org; + const Vec3vf D = ray.dir; + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(Vec3vf(cross(e0,v2+v0)),D); + const vfloat V = dot(Vec3vf(cross(e1,v0+v1)),D); + const vfloat W = dot(Vec3vf(cross(e2,v1+v2)),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + valid &= max(U,V,W) <= eps; +#else + valid &= (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Vec3vf(Ng),D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Vec3vf(Ng))); + const vfloat t = rcp(den)*T; + valid &= ray.tnear() <= t & t <= ray.tfar; + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* calculate hit information */ + SubGridQuadHitPlueckerK hit(U,V,UVW,t,tri_Ng,flags,g,subgrid,i); + return epilog(valid,hit); + } + + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const vbool& flags, + const GridMesh::Grid &g, + const 
SubGrid &subgrid, + const unsigned int i, + const Epilog& epilog) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + const Vec3vf Ng = cross(e2,e1); + return intersectK(valid0,ray,v0,v1,v2,Ng,flags,g,subgrid,i,epilog); + } + + template + __forceinline bool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Vec3vf& v3, + const GridMesh::Grid &g, + const SubGrid &subgrid, + const unsigned int i, + const Epilog& epilog) const + { + intersectK(valid0,ray,v0,v1,v3,vbool(false),g,subgrid,i,epilog); + if (none(valid0)) return true; + intersectK(valid0,ray,v2,v3,v1,vbool(true ),g,subgrid,i,epilog); + return none(valid0); + } + + static __forceinline bool intersect1(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const Vec3vf& tri_Ng, + const vbool& flags, + SubGridQuadHitPlueckerM &hit) + { + /* calculate vertices relative to ray origin */ + const Vec3vf O = broadcast>(ray.org,k); + const Vec3vf D = broadcast>(ray.dir,k); + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(cross(e0,v2+v0),D); + const vfloat V = dot(cross(e1,v0+v1),D); + const vfloat W = dot(cross(e2,v1+v2),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = max(U,V,W) <= eps ; +#else + vbool valid = (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Ng,D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Ng)); + const vfloat t = rcp(den)*T; + valid &= vfloat(ray.tnear()[k]) <= t & t <= vfloat(ray.tfar[k]); + if (unlikely(none(valid))) return false; + + /* avoid division by 0 */ + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + new (&hit) SubGridQuadHitPlueckerM(valid,U,V,UVW,t,tri_Ng,flags); + return true; + } + + static __forceinline bool intersect1(RayK& ray, + size_t k, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const vbool& flags, + SubGridQuadHitPlueckerM &hit) + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + const Vec3vf Ng = cross(e2,e1); // FIXME: optimize!!! 
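The grid-cell intersectors above split each quad into the two triangles (v0,v1,v3) and (v2,v3,v1) and then remap each triangle's local UV into coordinates that are continuous across the whole grid (the vint(0,1,1,0)/vint(0,0,1,1) corner offsets plus the subgrid position, scaled by the grid resolution). As a rough scalar illustration of that remapping only — hypothetical names, no SIMD lanes, no clamping — a sketch could look like this:

// Illustrative sketch only: scalar version of the quad-split + grid UV
// remapping idea used above. CellHit and remap_cell_uv are hypothetical
// names, not part of the Embree sources.
#include <cstdio>

struct CellHit { float u, v; };   // barycentric-style UV inside one triangle

// Remap a hit on one triangle of cell (sx, sy) of a (resX x resY) vertex grid
// into a UV that is continuous over the whole grid, mirroring the idea behind
// interpolateUV() above.
CellHit remap_cell_uv(CellHit hit, bool secondTriangle,
                      unsigned sx, unsigned sy, unsigned resX, unsigned resY)
{
  // the second triangle (v2,v3,v1) is parameterized "backwards", so flip
  float uu = secondTriangle ? 1.0f - hit.u : hit.u;
  float vv = secondTriangle ? 1.0f - hit.v : hit.v;
  // shift by the cell position and normalize by the number of cells
  float u = (uu + float(sx)) / float(resX - 1);
  float v = (vv + float(sy)) / float(resY - 1);
  return { u, v };
}

int main()
{
  // a hit on the second triangle of cell (2,1) in a 5x4 vertex grid
  CellHit h = remap_cell_uv({0.25f, 0.25f}, true, 2, 1, 5, 4);
  std::printf("grid uv = (%f, %f)\n", h.u, h.v);   // ~ (0.6875, 0.5833)
  return 0;
}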
+ return intersect1(ray,k,v0,v1,v2,Ng,flags,hit); + } + + }; + + template + struct SubGridQuadMIntersectorKPluecker : public SubGridQuadMIntersectorKPlueckerBase + { + __forceinline SubGridQuadMIntersectorKPluecker(const vbool& valid, const RayK& ray) + : SubGridQuadMIntersectorKPlueckerBase(valid,ray) {} + + __forceinline void intersect1(RayHitK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, const GridMesh::Grid &g, const SubGrid &subgrid) const + { + Intersect1KEpilogMU epilog(ray,k,context,subgrid.geomID(),subgrid.primID()); + + SubGridQuadHitPlueckerM<4> hit; + if (SubGridQuadMIntersectorKPlueckerBase<4,K,filter>::intersect1(ray,k,v0,v1,v3,vboolf4(false),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + epilog(hit.valid,hit); + } + + if (SubGridQuadMIntersectorKPlueckerBase<4,K,filter>::intersect1(ray,k,v2,v3,v1,vboolf4(true),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + epilog(hit.valid,hit); + } + + } + + __forceinline bool occluded1(RayK& ray, size_t k, IntersectContext* context, + const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const Vec3vf& v3, const GridMesh::Grid &g, const SubGrid &subgrid) const + { + Occluded1KEpilogMU epilog(ray,k,context,subgrid.geomID(),subgrid.primID()); + + SubGridQuadHitPlueckerM<4> hit; + if (SubGridQuadMIntersectorKPlueckerBase<4,K,filter>::intersect1(ray,k,v0,v1,v3,vboolf4(false),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + if (epilog(hit.valid,hit)) return true; + } + + if (SubGridQuadMIntersectorKPlueckerBase<4,K,filter>::intersect1(ray,k,v2,v3,v1,vboolf4(true),hit)) + { + interpolateUV(hit,g,subgrid,vint(0,1,1,0),vint(0,0,1,1)); + if (epilog(hit.valid,hit)) return true; + } + return false; + } + }; + + } +} diff --git a/thirdparty/embree/kernels/geometry/subgrid_mb_intersector.h b/thirdparty/embree/kernels/geometry/subgrid_mb_intersector.h new file mode 100644 index 000000000000..400a88b985a7 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/subgrid_mb_intersector.h @@ -0,0 +1,236 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "subgrid_intersector.h" + +namespace embree +{ + namespace isa + { + template + struct SubGridMBIntersector1Pluecker + { + typedef SubGridMBQBVHN Primitive; + typedef SubGridQuadMIntersector1Pluecker<4,filter> Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(normal.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + float ftime; + const int itime = mesh->timeSegment(ray.time(), ftime); + Vec3vf4 v0,v1,v2,v3; subgrid.gatherMB(v0,v1,v2,v3,context->scene,itime,ftime); + pre.intersect(ray,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(shadow.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + float ftime; + const int itime = mesh->timeSegment(ray.time(), ftime); + + Vec3vf4 v0,v1,v2,v3; subgrid.gatherMB(v0,v1,v2,v3,context->scene,itime,ftime); + return pre.occluded(ray,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, 
const SubGrid& subgrid) + { + return PrimitivePointQuery1::pointQuery(query, context, subgrid); + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + for (size_t i=0;i dist; + const float time = prim[i].adjustTime(ray.time()); + + assert(time <= 1.0f); + size_t mask = isec1.intersect(&prim[i].qnode,tray,time,dist); +#if defined(__AVX__) + STAT3(normal.trav_hit_boxes[popcnt(mask)],1,1,1); +#endif + while(mask != 0) + { + const size_t ID = bscf(mask); + if (unlikely(dist[ID] > ray.tfar)) continue; + intersect(pre,ray,context,prim[i].subgrid(ID)); + } + } + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + for (size_t i=0;i dist; + size_t mask = isec1.intersect(&prim[i].qnode,tray,time,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + if (occluded(pre,ray,context,prim[i].subgrid(ID))) + return true; + } + } + return false; + } + + static __forceinline bool pointQuery(const Accel::Intersectors* This, PointQuery* query, PointQueryContext* context, const Primitive* prim, size_t num, const TravPointQuery &tquery, size_t& lazy_node) + { + assert(false && "not implemented"); + return false; + } + }; + + + template + struct SubGridMBIntersectorKPluecker + { + typedef SubGridMBQBVHN Primitive; + typedef SubGridQuadMIntersectorKPluecker<4,K,filter> Precalculations; + + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const SubGrid& subgrid) + { + size_t m_valid = movemask(valid_i); + while(m_valid) + { + size_t ID = bscf(m_valid); + intersect(pre,ray,ID,context,subgrid); + } + } + + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const SubGrid& subgrid) + { + vbool valid0 = valid_i; + size_t m_valid = movemask(valid_i); + while(m_valid) + { + size_t ID = bscf(m_valid); + if (occluded(pre,ray,ID,context,subgrid)) + clear(valid0,ID); + } + return !valid0; + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(normal.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + vfloat ftime; + const vint itime = mesh->timeSegment(ray.time(), ftime); + Vec3vf4 v0,v1,v2,v3; subgrid.gatherMB(v0,v1,v2,v3,context->scene,itime[k],ftime[k]); + pre.intersect1(ray,k,context,v0,v1,v2,v3,g,subgrid); + } + + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const SubGrid& subgrid) + { + STAT3(shadow.trav_prims,1,1,1); + const GridMesh* mesh = context->scene->get(subgrid.geomID()); + const GridMesh::Grid &g = mesh->grid(subgrid.primID()); + + vfloat ftime; + const vint itime = mesh->timeSegment(ray.time(), ftime); + Vec3vf4 v0,v1,v2,v3; subgrid.gatherMB(v0,v1,v2,v3,context->scene,itime[k],ftime[k]); + return pre.occluded1(ray,k,context,v0,v1,v2,v3,g,subgrid); + } + + template + static __forceinline void intersect(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, 
IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersectorK isecK; + for (size_t j=0;j time = prim[j].adjustTime(ray.time()); + + vfloat dist; + while(m_valid) + { + const size_t i = bscf(m_valid); + if (none(valid & isecK.intersectK(&prim[j].qnode,i,tray,time,dist))) continue; + intersect(valid,pre,ray,context,prim[j].subgrid(i)); + } + } + } + + template + static __forceinline vbool occluded(const vbool& valid, const Accel::Intersectors* This, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive* prim, size_t num, const TravRayK &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersectorK isecK; + + vbool valid0 = valid; + for (size_t j=0;j time = prim[j].adjustTime(ray.time()); + vfloat dist; + while(m_valid) + { + const size_t i = bscf(m_valid); + if (none(valid0 & isecK.intersectK(&prim[j].qnode,i,tray,time,dist))) continue; + valid0 &= !occluded(valid0,pre,ray,context,prim[j].subgrid(i)); + if (none(valid0)) break; + } + } + return !valid0; + } + + template + static __forceinline void intersect(const Accel::Intersectors* This, Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + for (size_t i=0;i dist; + const float time = prim[i].adjustTime(ray.time()[k]); + assert(time <= 1.0f); + + size_t mask = isec1.intersect(&prim[i].qnode,tray,time,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + if (unlikely(dist[ID] > ray.tfar[k])) continue; + intersect(pre,ray,k,context,prim[i].subgrid(ID)); + } + } + } + + template + static __forceinline bool occluded(const Accel::Intersectors* This, Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive* prim, size_t num, const TravRay &tray, size_t& lazy_node) + { + BVHNQuantizedBaseNodeIntersector1 isec1; + + for (size_t i=0;i dist; + const float time = prim[i].adjustTime(ray.time()[k]); + assert(time <= 1.0f); + + size_t mask = isec1.intersect(&prim[i].qnode,tray,time,dist); + while(mask != 0) + { + const size_t ID = bscf(mask); + if (occluded(pre,ray,k,context,prim[i].subgrid(ID))) + return true; + } + } + return false; + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/triangle.h b/thirdparty/embree/kernels/geometry/triangle.h new file mode 100644 index 000000000000..0dedf6dc4c13 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/triangle.h @@ -0,0 +1,162 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + /* Precalculated representation for M triangles. 
Stores for each + triangle a base vertex, two edges, and the geometry normal to + speed up intersection calculations */ + template + struct TriangleM + { + public: + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* Returns maximum number of stored triangles */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline TriangleM() {} + + /* Construction from vertices and IDs */ + __forceinline TriangleM(const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const vuint& geomIDs, const vuint& primIDs) + : v0(v0), e1(v0-v1), e2(v2-v0), geomIDs(geomIDs), primIDs(primIDs) {} + + /* Returns a mask that tells which triangles are valid */ + __forceinline vbool valid() const { return geomIDs != vuint(-1); } + + /* Returns true if the specified triangle is valid */ + __forceinline bool valid(const size_t i) const { assert(i& geomID() { return geomIDs; } + __forceinline const vuint& geomID() const { return geomIDs; } + __forceinline unsigned int geomID(const size_t i) const { assert(i& primID() { return primIDs; } + __forceinline const vuint& primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(i p0 = v0; + Vec3vf p1 = v0-e1; + Vec3vf p2 = v0+e2; + Vec3vf lower = min(p0,p1,p2); + Vec3vf upper = max(p0,p1,p2); + vbool mask = valid(); + lower.x = select(mask,lower.x,vfloat(pos_inf)); + lower.y = select(mask,lower.y,vfloat(pos_inf)); + lower.z = select(mask,lower.z,vfloat(pos_inf)); + upper.x = select(mask,upper.x,vfloat(neg_inf)); + upper.y = select(mask,upper.y,vfloat(neg_inf)); + upper.z = select(mask,upper.z,vfloat(neg_inf)); + return BBox3fa(Vec3fa(reduce_min(lower.x),reduce_min(lower.y),reduce_min(lower.z)), + Vec3fa(reduce_max(upper.x),reduce_max(upper.y),reduce_max(upper.z))); + } + + /* Non temporal store */ + __forceinline static void store_nt(TriangleM* dst, const TriangleM& src) + { + vfloat::store_nt(&dst->v0.x,src.v0.x); + vfloat::store_nt(&dst->v0.y,src.v0.y); + vfloat::store_nt(&dst->v0.z,src.v0.z); + vfloat::store_nt(&dst->e1.x,src.e1.x); + vfloat::store_nt(&dst->e1.y,src.e1.y); + vfloat::store_nt(&dst->e1.z,src.e1.z); + vfloat::store_nt(&dst->e2.x,src.e2.x); + vfloat::store_nt(&dst->e2.y,src.e2.y); + vfloat::store_nt(&dst->e2.z,src.e2.z); + vuint::store_nt(&dst->geomIDs,src.geomIDs); + vuint::store_nt(&dst->primIDs,src.primIDs); + } + + /* Fill triangle from triangle list */ + __forceinline void fill(const PrimRef* prims, size_t& begin, size_t end, Scene* scene) + { + vuint vgeomID = -1, vprimID = -1; + Vec3vf v0 = zero, v1 = zero, v2 = zero; + + for (size_t i=0; iget(geomID); + const TriangleMesh::Triangle& tri = mesh->triangle(primID); + const Vec3fa& p0 = mesh->vertex(tri.v[0]); + const Vec3fa& p1 = mesh->vertex(tri.v[1]); + const Vec3fa& p2 = mesh->vertex(tri.v[2]); + vgeomID [i] = geomID; + vprimID [i] = primID; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + } + TriangleM::store_nt(this,TriangleM(v0,v1,v2,vgeomID,vprimID)); + } + + /* Updates the primitive */ + __forceinline BBox3fa update(TriangleMesh* mesh) + { + BBox3fa bounds = 
empty; + vuint vgeomID = -1, vprimID = -1; + Vec3vf v0 = zero, v1 = zero, v2 = zero; + + for (size_t i=0; itriangle(primId); + const Vec3fa p0 = mesh->vertex(tri.v[0]); + const Vec3fa p1 = mesh->vertex(tri.v[1]); + const Vec3fa p2 = mesh->vertex(tri.v[2]); + bounds.extend(merge(BBox3fa(p0),BBox3fa(p1),BBox3fa(p2))); + vgeomID [i] = geomId; + vprimID [i] = primId; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + } + TriangleM::store_nt(this,TriangleM(v0,v1,v2,vgeomID,vprimID)); + return bounds; + } + + public: + Vec3vf v0; // base vertex of the triangles + Vec3vf e1; // 1st edge of the triangles (v0-v1) + Vec3vf e2; // 2nd edge of the triangles (v2-v0) + private: + vuint geomIDs; // geometry IDs + vuint primIDs; // primitive IDs + }; + + template + typename TriangleM::Type TriangleM::type; + + typedef TriangleM<4> Triangle4; +} diff --git a/thirdparty/embree/kernels/geometry/triangle_intersector.h b/thirdparty/embree/kernels/geometry/triangle_intersector.h new file mode 100644 index 000000000000..125a42c5fee0 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/triangle_intersector.h @@ -0,0 +1,96 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "triangle.h" +#include "triangle_intersector_moeller.h" + +namespace embree +{ + namespace isa + { + /*! Intersects M triangles with 1 ray */ + template + struct TriangleMIntersector1Moeller + { + typedef TriangleM Primitive; + typedef MoellerTrumboreIntersector1 Precalculations; + + /*! Intersect a ray with the M triangles and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const TriangleM& tri) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersectEdge(ray,tri.v0,tri.e1,tri.e2,Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const TriangleM& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.intersectEdge(ray,tri.v0,tri.e1,tri.e2,Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + + }; + + /*! Intersects M triangles with K rays. */ + template + struct TriangleMIntersectorKMoeller + { + typedef TriangleM Primitive; + typedef MoellerTrumboreIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const TriangleM& tri) + { + STAT_USER(0,TriangleM::max_size()); + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf p0 = broadcast>(tri.v0,i); + const Vec3vf e1 = broadcast>(tri.e1,i); + const Vec3vf e2 = broadcast>(tri.e2,i); + pre.intersectEdgeK(valid_i,ray,p0,e1,e2,IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. 
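TriangleMIntersectorKMoeller above combines two SIMD widths: a primitive block stores M triangles, and during packet traversal each stored triangle is broadcast in turn across all K ray lanes. The following is only a plain-array sketch of that double loop; TriBlock, RayPacket and hit_triangle are hypothetical stand-ins, not Embree types:

// Illustrative sketch only: the "M triangles x K rays" strategy of
// TriangleMIntersectorKMoeller, written with plain arrays instead of SIMD.
#include <array>
#include <cstddef>

constexpr std::size_t M = 4;   // triangles per block
constexpr std::size_t K = 8;   // rays per packet

struct Vec3     { float x, y, z; };
struct TriBlock { std::array<Vec3, M> v0, e1, e2; std::array<bool, M> valid; };
struct RayPacket{ std::array<Vec3, K> org, dir; std::array<float, K> tfar; };

// stub standing in for the real ray/triangle test (see the Moeller-Trumbore
// sketch further below); always misses, just to keep the example self-contained
static bool hit_triangle(const Vec3&, const Vec3&, const Vec3&, const Vec3&,
                         const Vec3&, float& t) { t = 0.0f; return false; }

void intersect_block(const TriBlock& tri, RayPacket& ray)
{
  for (std::size_t i = 0; i < M; ++i)        // one stored triangle at a time
  {
    if (!tri.valid[i]) break;                // blocks are filled front to back
    // "broadcast" triangle i: every ray lane tests the same v0/e1/e2
    for (std::size_t k = 0; k < K; ++k)
    {
      float t;
      if (hit_triangle(ray.org[k], ray.dir[k], tri.v0[i], tri.e1[i], tri.e2[i], t)
          && t < ray.tfar[k])
        ray.tfar[k] = t;                     // the epilog would also record u,v,Ng,IDs
    }
  }
}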
*/ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const TriangleM& tri) + { + vbool valid0 = valid_i; + + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf p0 = broadcast>(tri.v0,i); + const Vec3vf e1 = broadcast>(tri.e1,i); + const Vec3vf e2 = broadcast>(tri.e2,i); + pre.intersectEdgeK(valid0,ray,p0,e1,e2,OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const TriangleM& tri) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersectEdge(ray,k,tri.v0,tri.e1,tri.e2,Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const TriangleM& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.intersectEdge(ray,k,tri.v0,tri.e1,tri.e2,Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/triangle_intersector_moeller.h b/thirdparty/embree/kernels/geometry/triangle_intersector_moeller.h new file mode 100644 index 000000000000..b5a8519236da --- /dev/null +++ b/thirdparty/embree/kernels/geometry/triangle_intersector_moeller.h @@ -0,0 +1,403 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "triangle.h" +#include "intersector_epilog.h" + +/*! This intersector implements a modified version of the Moeller + * Trumbore intersector from the paper "Fast, Minimum Storage + * Ray-Triangle Intersection". In contrast to the paper we + * precalculate some factors and factor the calculations differently + * to allow precalculating the cross product e1 x e2. The resulting + * algorithm is similar to the fastest one of the paper "Optimizing + * Ray-Triangle Intersection via Automated Search". 
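As a concrete reference for the variant described above, here is a plain scalar sketch of the same test with the precomputed representation v0, e1 = v0 - v1, e2 = v2 - v0 (hypothetical helper types; the SIMD sign-mask trick is replaced by an explicit sign factor, and the single division is deferred to the end as in finalize()):

// Illustrative scalar sketch of the precomputed-edge Moeller-Trumbore test,
// not the vectorized Embree code itself.
#include <cmath>

struct V3 { float x, y, z; };
static V3    sub  (V3 a, V3 b) { return { a.x-b.x, a.y-b.y, a.z-b.z }; }
static float dot  (V3 a, V3 b) { return a.x*b.x + a.y*b.y + a.z*b.z; }
static V3    cross(V3 a, V3 b) { return { a.y*b.z-a.z*b.y, a.z*b.x-a.x*b.z, a.x*b.y-a.y*b.x }; }

struct Hit { float t, u, v; V3 Ng; };

bool intersect_moeller(V3 org, V3 dir, float tnear, float tfar,
                       V3 v0, V3 e1, V3 e2, Hit& hit)
{
  const V3    Ng     = cross(e2, e1);              // precomputable per triangle
  const V3    C      = sub(v0, org);
  const V3    R      = cross(C, dir);
  const float den    = dot(Ng, dir);
  const float absDen = std::fabs(den);
  const float sgn    = den < 0.0f ? -1.0f : 1.0f;  // stands in for signmsk(den)

  // edge tests, scaled by |den| so no division is needed to reject
  const float U = dot(R, e2) * sgn;
  const float V = dot(R, e1) * sgn;
  if (den == 0.0f || U < 0.0f || V < 0.0f || U + V > absDen) return false;

  // depth test, still in the scaled space
  const float T = dot(Ng, C) * sgn;
  if (!(absDen * tnear < T && T <= absDen * tfar)) return false;

  // only now divide once ("finalize" step)
  const float rcpAbsDen = 1.0f / absDen;
  hit = { T * rcpAbsDen, U * rcpAbsDen, V * rcpAbsDen, Ng };
  return true;
}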
*/ + +namespace embree +{ + namespace isa + { + template + struct MoellerTrumboreHitM + { + __forceinline MoellerTrumboreHitM() {} + + __forceinline MoellerTrumboreHitM(const vbool& valid, const vfloat& U, const vfloat& V, const vfloat& T, const vfloat& absDen, const Vec3vf& Ng) + : U(U), V(V), T(T), absDen(absDen), valid(valid), vNg(Ng) {} + + __forceinline void finalize() + { + const vfloat rcpAbsDen = rcp(absDen); + vt = T * rcpAbsDen; + vu = U * rcpAbsDen; + vv = V * rcpAbsDen; + } + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + public: + vfloat U; + vfloat V; + vfloat T; + vfloat absDen; + + public: + vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + struct MoellerTrumboreIntersector1 + { + __forceinline MoellerTrumboreIntersector1() {} + + __forceinline MoellerTrumboreIntersector1(const Ray& ray, const void* ptr) {} + + __forceinline bool intersect(const vbool& valid0, + Ray& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + MoellerTrumboreHitM& hit) const + { + /* calculate denominator */ + vbool valid = valid0; + const Vec3vf O = Vec3vf((Vec3fa)ray.org); + const Vec3vf D = Vec3vf((Vec3fa)ray.dir); + const Vec3vf C = Vec3vf(tri_v0) - O; + const Vec3vf R = cross(C,D); + const vfloat den = dot(Vec3vf(tri_Ng),D); + + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* perform edge tests */ + const vfloat U = dot(R,Vec3vf(tri_e2)) ^ sgnDen; + const vfloat V = dot(R,Vec3vf(tri_e1)) ^ sgnDen; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + valid &= (den < vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#else + valid &= (den != vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#endif + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(Vec3vf(tri_Ng),C) ^ sgnDen; + valid &= (absDen*vfloat(ray.tnear()) < T) & (T <= absDen*vfloat(ray.tfar)); + if (likely(none(valid))) return false; + + + /* update hit information */ + new (&hit) MoellerTrumboreHitM(valid,U,V,T,absDen,tri_Ng); + + return true; + } + + __forceinline bool intersectEdge(Ray& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + MoellerTrumboreHitM& hit) const + { + vbool valid = true; + const Vec3> tri_Ng = cross(tri_e2,tri_e1); + return intersect(valid,ray,tri_v0,tri_e1,tri_e2,tri_Ng,hit); + } + + __forceinline bool intersect(Ray& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + MoellerTrumboreHitM& hit) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + return intersectEdge(ray,v0,e1,e2,hit); + } + + __forceinline bool intersect(const vbool& valid, + Ray& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + MoellerTrumboreHitM& hit) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + return intersectEdge(valid,ray,v0,e1,e2,hit); + } + + template + __forceinline bool intersectEdge(Ray& ray, + const Vec3vf& v0, + const Vec3vf& e1, + const Vec3vf& e2, + const Epilog& epilog) const + { + MoellerTrumboreHitM hit; + if (likely(intersectEdge(ray,v0,e1,e2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + __forceinline bool intersect(Ray& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) const + { + 
MoellerTrumboreHitM hit; + if (likely(intersect(ray,v0,v1,v2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + __forceinline bool intersect(const vbool& valid, + Ray& ray, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) const + { + MoellerTrumboreHitM hit; + if (likely(intersect(valid,ray,v0,v1,v2,hit))) return epilog(hit.valid,hit); + return false; + } + }; + + template + struct MoellerTrumboreHitK + { + __forceinline MoellerTrumboreHitK(const vfloat& U, const vfloat& V, const vfloat& T, const vfloat& absDen, const Vec3vf& Ng) + : U(U), V(V), T(T), absDen(absDen), Ng(Ng) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vfloat rcpAbsDen = rcp(absDen); + const vfloat t = T * rcpAbsDen; + const vfloat u = U * rcpAbsDen; + const vfloat v = V * rcpAbsDen; + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat T; + const vfloat absDen; + const Vec3vf Ng; + }; + + template + struct MoellerTrumboreIntersectorK + { + __forceinline MoellerTrumboreIntersectorK(const vbool& valid, const RayK& ray) {} + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectK(const vbool& valid0, + //RayK& ray, + const Vec3vf& ray_org, + const Vec3vf& ray_dir, + const vfloat& ray_tnear, + const vfloat& ray_tfar, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + const Epilog& epilog) const + { + /* calculate denominator */ + vbool valid = valid0; + const Vec3vf C = tri_v0 - ray_org; + const Vec3vf R = cross(C,ray_dir); + const vfloat den = dot(tri_Ng,ray_dir); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* test against edge p2 p0 */ + const vfloat U = dot(tri_e2,R) ^ sgnDen; + valid &= U >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p0 p1 */ + const vfloat V = dot(tri_e1,R) ^ sgnDen; + valid &= V >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p1 p2 */ + const vfloat W = absDen-U-V; + valid &= W >= 0.0f; + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(tri_Ng,C) ^ sgnDen; + valid &= (absDen*ray_tnear < T) & (T <= absDen*ray_tfar); + if (unlikely(none(valid))) return false; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + valid &= den < vfloat(zero); + if (unlikely(none(valid))) return false; +#else + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; +#endif + + /* calculate hit information */ + MoellerTrumboreHitK hit(U,V,T,absDen,tri_Ng); + return epilog(valid,hit); + } + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const Epilog& epilog) const + { + const Vec3vf e1 = tri_v0-tri_v1; + const Vec3vf e2 = tri_v2-tri_v0; + const Vec3vf Ng = cross(e2,e1); + return intersectK(valid0,ray.org,ray.dir,ray.tnear(),ray.tfar,tri_v0,e1,e2,Ng,epilog); + } + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectEdgeK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Epilog& epilog) const + { + const Vec3vf tri_Ng = cross(tri_e2,tri_e1); + return intersectK(valid0,ray.org,ray.dir,ray.tnear(),ray.tfar,tri_v0,tri_e1,tri_e2,tri_Ng,epilog); + } + + /*! 
Intersect k'th ray from ray packet of size K with M triangles. */ + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + MoellerTrumboreHitM& hit) const + { + /* calculate denominator */ + typedef Vec3vf Vec3vfM; + const Vec3vf tri_Ng = cross(tri_e2,tri_e1); + + const Vec3vfM O = broadcast>(ray.org,k); + const Vec3vfM D = broadcast>(ray.dir,k); + const Vec3vfM C = Vec3vfM(tri_v0) - O; + const Vec3vfM R = cross(C,D); + const vfloat den = dot(Vec3vfM(tri_Ng),D); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* perform edge tests */ + const vfloat U = dot(Vec3vf(tri_e2),R) ^ sgnDen; + const vfloat V = dot(Vec3vf(tri_e1),R) ^ sgnDen; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = (den < vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#else + vbool valid = (den != vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#endif + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(Vec3vf(tri_Ng),C) ^ sgnDen; + valid &= (absDen*vfloat(ray.tnear()[k]) < T) & (T <= absDen*vfloat(ray.tfar[k])); + if (likely(none(valid))) return false; + + /* calculate hit information */ + new (&hit) MoellerTrumboreHitM(valid,U,V,T,absDen,tri_Ng); + return true; + } + + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const BBox>& time_range, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + MoellerTrumboreHitM& hit) const + { + if (likely(intersect(ray,k,tri_v0,tri_e1,tri_e2,hit))) + { + hit.valid &= time_range.lower <= vfloat(ray.time[k]); + hit.valid &= vfloat(ray.time[k]) < time_range.upper; + return any(hit.valid); + } + return false; + } + + template + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Epilog& epilog) const + { + MoellerTrumboreHitM hit; + if (likely(intersectEdge(ray,k,tri_v0,tri_e1,tri_e2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const BBox>& time_range, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Epilog& epilog) const + { + MoellerTrumboreHitM hit; + if (likely(intersectEdge(ray,k,time_range,tri_v0,tri_e1,tri_e2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + __forceinline bool intersect(RayK& ray, + size_t k, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + return intersectEdge(ray,k,v0,e1,e2,epilog); + } + + template + __forceinline bool intersect(RayK& ray, + size_t k, + const BBox>& time_range, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + return intersectEdge(ray,k,time_range,v0,e1,e2,epilog); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/triangle_intersector_pluecker.h b/thirdparty/embree/kernels/geometry/triangle_intersector_pluecker.h new file mode 100644 index 000000000000..f1de99d2087c --- /dev/null +++ b/thirdparty/embree/kernels/geometry/triangle_intersector_pluecker.h @@ -0,0 +1,247 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "triangle.h" +#include "trianglev.h" +#include "trianglev_mb.h" +#include "intersector_epilog.h" + +/*! 
Modified Pluecker ray/triangle intersector. The test first shifts + * the ray origin into the origin of the coordinate system and then + * uses Pluecker coordinates for the intersection. Due to the shift, + * the Pluecker coordinate calculation simplifies and the tests get + * numerically stable. The edge equations are watertight along the + * edge for neighboring triangles. */ + +namespace embree +{ + namespace isa + { + template + struct PlueckerHitM + { + __forceinline PlueckerHitM(const vfloat& U, const vfloat& V, const vfloat& UVW, const vfloat& t, const Vec3vf& Ng, const UVMapper& mapUV) + : U(U), V(V), UVW(UVW), mapUV(mapUV), vt(t), vNg(Ng) {} + + __forceinline void finalize() + { + const vbool invalid = abs(UVW) < min_rcp_input; + const vfloat rcpUVW = select(invalid,vfloat(0.0f),rcp(UVW)); + vu = U * rcpUVW; + vv = V * rcpUVW; + mapUV(vu,vv); + } + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + private: + const vfloat U; + const vfloat V; + const vfloat UVW; + const UVMapper& mapUV; + + public: + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + struct PlueckerIntersector1 + { + __forceinline PlueckerIntersector1() {} + + __forceinline PlueckerIntersector1(const Ray& ray, const void* ptr) {} + + template + __forceinline bool intersect(Ray& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const UVMapper& mapUV, + const Epilog& epilog) const + { + /* calculate vertices relative to ray origin */ + const Vec3vf O = Vec3vf((Vec3fa)ray.org); + const Vec3vf D = Vec3vf((Vec3fa)ray.dir); + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(cross(e0,v2+v0),D); + const vfloat V = dot(cross(e1,v0+v1),D); + const vfloat W = dot(cross(e2,v1+v2),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = max(U,V,W) <= eps; +#else + vbool valid = (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Ng,D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Ng)); + const vfloat t = rcp(den)*T; + valid &= vfloat(ray.tnear()) <= t & t <= vfloat(ray.tfar); + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + PlueckerHitM hit(U,V,UVW,t,Ng,mapUV); + return epilog(valid,hit); + } + }; + + template + struct PlueckerHitK + { + __forceinline PlueckerHitK(const vfloat& U, const vfloat& V, const vfloat& UVW, const vfloat& t, const Vec3vf& Ng, const UVMapper& mapUV) + : U(U), V(V), UVW(UVW), t(t), Ng(Ng), mapUV(mapUV) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vbool invalid = abs(UVW) < min_rcp_input; + const vfloat rcpUVW = select(invalid,vfloat(0.0f),rcp(UVW)); + vfloat u = U * rcpUVW; + vfloat v = V * rcpUVW; + mapUV(u,v); + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat UVW; + const vfloat t; + const Vec3vf Ng; + const UVMapper& mapUV; + }; + + template + 
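A scalar transcription of the edge and depth tests described in the comment above may make the structure easier to follow. The helper types below are hypothetical, and stable_triangle_normal is replaced by a plain cross product, whose orientation cancels in the t ratio:

// Illustrative scalar sketch of the Pluecker-style test above: vertices are
// shifted by the ray origin, each edge contributes one Pluecker inner product
// with the ray direction, and a hit requires all three to agree in sign
// (within an epsilon), which is what keeps shared edges watertight between
// neighboring triangles.
#include <algorithm>
#include <cmath>
#include <limits>

struct V3 { float x, y, z; };
static V3    sub  (V3 a, V3 b) { return { a.x-b.x, a.y-b.y, a.z-b.z }; }
static V3    add  (V3 a, V3 b) { return { a.x+b.x, a.y+b.y, a.z+b.z }; }
static float dot  (V3 a, V3 b) { return a.x*b.x + a.y*b.y + a.z*b.z; }
static V3    cross(V3 a, V3 b) { return { a.y*b.z-a.z*b.y, a.z*b.x-a.x*b.z, a.x*b.y-a.y*b.x }; }

bool intersect_pluecker(V3 org, V3 dir, float tnear, float tfar,
                        V3 a, V3 b, V3 c, float& t, float& u, float& v)
{
  // shift the triangle so the ray origin becomes the coordinate origin
  const V3 v0 = sub(a, org), v1 = sub(b, org), v2 = sub(c, org);
  const V3 e0 = sub(v2, v0), e1 = sub(v0, v1), e2 = sub(v1, v2);

  // one scaled barycentric per edge
  const float U   = dot(cross(e0, add(v2, v0)), dir);
  const float V   = dot(cross(e1, add(v0, v1)), dir);
  const float W   = dot(cross(e2, add(v1, v2)), dir);
  const float UVW = U + V + W;
  const float eps = std::numeric_limits<float>::epsilon() * std::fabs(UVW);

  // accept only if U, V, W agree in sign (up to eps); no backface culling here
  if (!(std::min({U, V, W}) >= -eps || std::max({U, V, W}) <= eps)) return false;

  // depth test; a plain cross product replaces stable_triangle_normal, and its
  // orientation does not matter because the sign cancels in the t ratio
  const V3    Ng  = cross(e1, e0);
  const float den = 2.0f * dot(Ng, dir);
  if (den == 0.0f) return false;
  t = (2.0f * dot(v0, Ng)) / den;
  if (!(tnear <= t && t <= tfar)) return false;

  // guarded division, in the spirit of the min_rcp_input check above
  const float rcpUVW = (std::fabs(UVW) < std::numeric_limits<float>::min())
                         ? 0.0f : 1.0f / UVW;
  u = U * rcpUVW;
  v = V * rcpUVW;
  return true;
}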
struct PlueckerIntersectorK + { + __forceinline PlueckerIntersectorK(const vbool& valid, const RayK& ray) {} + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const UVMapper& mapUV, + const Epilog& epilog) const + { + /* calculate vertices relative to ray origin */ + vbool valid = valid0; + const Vec3vf O = ray.org; + const Vec3vf D = ray.dir; + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(Vec3vf(cross(e0,v2+v0)),D); + const vfloat V = dot(Vec3vf(cross(e1,v0+v1)),D); + const vfloat W = dot(Vec3vf(cross(e2,v1+v2)),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + valid &= max(U,V,W) <= eps; +#else + valid &= (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Vec3vf(Ng),D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Vec3vf(Ng))); + const vfloat t = rcp(den)*T; + valid &= ray.tnear() <= t & t <= ray.tfar; + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* calculate hit information */ + PlueckerHitK hit(U,V,UVW,t,Ng,mapUV); + return epilog(valid,hit); + } + + /*! Intersect k'th ray from ray packet of size K with M triangles. */ + template + __forceinline bool intersect(RayK& ray, size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const UVMapper& mapUV, + const Epilog& epilog) const + { + /* calculate vertices relative to ray origin */ + const Vec3vf O = broadcast>(ray.org,k); + const Vec3vf D = broadcast>(ray.dir,k); + const Vec3vf v0 = tri_v0-O; + const Vec3vf v1 = tri_v1-O; + const Vec3vf v2 = tri_v2-O; + + /* calculate triangle edges */ + const Vec3vf e0 = v2-v0; + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v1-v2; + + /* perform edge tests */ + const vfloat U = dot(cross(e0,v2+v0),D); + const vfloat V = dot(cross(e1,v0+v1),D); + const vfloat W = dot(cross(e2,v1+v2),D); + const vfloat UVW = U+V+W; + const vfloat eps = float(ulp)*abs(UVW); +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = max(U,V,W) <= eps; +#else + vbool valid = (min(U,V,W) >= -eps) | (max(U,V,W) <= eps); +#endif + if (unlikely(none(valid))) return false; + + /* calculate geometry normal and denominator */ + const Vec3vf Ng = stable_triangle_normal(e0,e1,e2); + const vfloat den = twice(dot(Ng,D)); + + /* perform depth test */ + const vfloat T = twice(dot(v0,Ng)); + const vfloat t = rcp(den)*T; + valid &= vfloat(ray.tnear()[k]) <= t & t <= vfloat(ray.tfar[k]); + if (unlikely(none(valid))) return false; + + /* avoid division by 0 */ + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; + + /* update hit information */ + PlueckerHitM hit(U,V,UVW,t,Ng,mapUV); + return epilog(valid,hit); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/triangle_intersector_woop.h b/thirdparty/embree/kernels/geometry/triangle_intersector_woop.h new file mode 100644 index 000000000000..63e649d8fb6e --- /dev/null +++ b/thirdparty/embree/kernels/geometry/triangle_intersector_woop.h @@ -0,0 +1,418 @@ +// Copyright 2009-2020 
Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "triangle.h" +#include "intersector_epilog.h" + +/*! This intersector implements a modified version of the Woop's ray-triangle intersection test */ + +namespace embree +{ + namespace isa + { + template + struct WoopHitM + { + __forceinline WoopHitM() {} + + __forceinline WoopHitM(const vbool& valid, + const vfloat& U, + const vfloat& V, + const vfloat& T, + const vfloat& inv_det, + const Vec3vf& Ng) + : U(U), V(V), T(T), inv_det(inv_det), valid(valid), vNg(Ng) {} + + __forceinline void finalize() + { + vt = T; + vu = U*inv_det; + vv = V*inv_det; + } + + __forceinline Vec2f uv (const size_t i) const { return Vec2f(vu[i],vv[i]); } + __forceinline float t (const size_t i) const { return vt[i]; } + __forceinline Vec3fa Ng(const size_t i) const { return Vec3fa(vNg.x[i],vNg.y[i],vNg.z[i]); } + + private: + const vfloat U; + const vfloat V; + const vfloat T; + const vfloat inv_det; + + public: + const vbool valid; + vfloat vu; + vfloat vv; + vfloat vt; + Vec3vf vNg; + }; + + template + struct WoopPrecalculations1 + { + unsigned int kx,ky,kz; + Vec3vf org; + Vec3fa S; + __forceinline WoopPrecalculations1() {} + + __forceinline WoopPrecalculations1(const Ray& ray, const void* ptr) + { + kz = maxDim(abs(ray.dir)); + kx = (kz+1) % 3; + ky = (kx+1) % 3; + const float inv_dir_kz = rcp(ray.dir[kz]); + if (ray.dir[kz]) std::swap(kx,ky); + S.x = ray.dir[kx] * inv_dir_kz; + S.y = ray.dir[ky] * inv_dir_kz; + S.z = inv_dir_kz; + org = Vec3vf(ray.org[kx],ray.org[ky],ray.org[kz]); + } + }; + + + template + struct WoopIntersector1 + { + + typedef WoopPrecalculations1 Precalculations; + + __forceinline WoopIntersector1() {} + + __forceinline WoopIntersector1(const Ray& ray, const void* ptr) {} + + static __forceinline bool intersect(const vbool& valid0, + Ray& ray, + const Precalculations& pre, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + WoopHitM& hit) + { + vbool valid = valid0; + + /* vertices relative to ray origin */ + const Vec3vf org = Vec3vf(pre.org.x,pre.org.y,pre.org.z); + const Vec3vf A = Vec3vf(tri_v0[pre.kx],tri_v0[pre.ky],tri_v0[pre.kz]) - org; + const Vec3vf B = Vec3vf(tri_v1[pre.kx],tri_v1[pre.ky],tri_v1[pre.kz]) - org; + const Vec3vf C = Vec3vf(tri_v2[pre.kx],tri_v2[pre.ky],tri_v2[pre.kz]) - org; + + /* shear and scale vertices */ + const vfloat Ax = nmadd(A.z,pre.S.x,A.x); + const vfloat Ay = nmadd(A.z,pre.S.y,A.y); + const vfloat Bx = nmadd(B.z,pre.S.x,B.x); + const vfloat By = nmadd(B.z,pre.S.y,B.y); + const vfloat Cx = nmadd(C.z,pre.S.x,C.x); + const vfloat Cy = nmadd(C.z,pre.S.y,C.y); + + /* scaled barycentric */ + const vfloat U0 = Cx*By; + const vfloat U1 = Cy*Bx; + const vfloat V0 = Ax*Cy; + const vfloat V1 = Ay*Cx; + const vfloat W0 = Bx*Ay; + const vfloat W1 = By*Ax; +#if !defined(__AVX512F__) + valid &= (U0 >= U1) & (V0 >= V1) & (W0 >= W1) | + (U0 <= U1) & (V0 <= V1) & (W0 <= W1); +#else + valid &= ge(ge(U0 >= U1,V0,V1),W0,W1) | le(le(U0 <= U1,V0,V1),W0,W1); +#endif + + if (likely(none(valid))) return false; + const vfloat U = U0-U1; + const vfloat V = V0-V1; + const vfloat W = W0-W1; + + const vfloat det = U+V+W; + + valid &= det != 0.0f; + const vfloat inv_det = rcp(det); + + const vfloat Az = pre.S.z * A.z; + const vfloat Bz = pre.S.z * B.z; + const vfloat Cz = pre.S.z * C.z; + const vfloat T = madd(U,Az,madd(V,Bz,W*Cz)); + const vfloat t = T * inv_det; + /* perform depth test */ + valid &= (vfloat(ray.tnear()) < t) & (t <= vfloat(ray.tfar)); + if (likely(none(valid))) 
return false; + + const Vec3vf tri_Ng = cross(tri_v2-tri_v0,tri_v0-tri_v1); + + /* update hit information */ + new (&hit) WoopHitM(valid,U,V,t,inv_det,tri_Ng); + return true; + } + + static __forceinline bool intersect(Ray& ray, + const Precalculations& pre, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + WoopHitM& hit) + { + vbool valid = true; + return intersect(valid,ray,pre,v0,v1,v2,hit); + } + + + template + static __forceinline bool intersect(Ray& ray, + const Precalculations& pre, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) + { + WoopHitM hit; + if (likely(intersect(ray,pre,v0,v1,v2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + static __forceinline bool intersect(const vbool& valid, + Ray& ray, + const Precalculations& pre, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) + { + WoopHitM hit; + if (likely(intersect(valid,ray,pre,v0,v1,v2,hit))) return epilog(hit.valid,hit); + return false; + } + }; + +#if 0 + template + struct WoopHitK + { + __forceinline WoopHitK(const vfloat& U, const vfloat& V, const vfloat& T, const vfloat& absDen, const Vec3vf& Ng) + : U(U), V(V), T(T), absDen(absDen), Ng(Ng) {} + + __forceinline std::tuple,vfloat,vfloat,Vec3vf> operator() () const + { + const vfloat rcpAbsDen = rcp(absDen); + const vfloat t = T * rcpAbsDen; + const vfloat u = U * rcpAbsDen; + const vfloat v = V * rcpAbsDen; + return std::make_tuple(u,v,t,Ng); + } + + private: + const vfloat U; + const vfloat V; + const vfloat T; + const vfloat absDen; + const Vec3vf Ng; + }; + + template + struct WoopIntersectorK + { + __forceinline WoopIntersectorK(const vbool& valid, const RayK& ray) {} + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectK(const vbool& valid0, + //RayK& ray, + const Vec3vf& ray_org, + const Vec3vf& ray_dir, + const vfloat& ray_tnear, + const vfloat& ray_tfar, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Vec3vf& tri_Ng, + const Epilog& epilog) const + { + /* calculate denominator */ + vbool valid = valid0; + const Vec3vf C = tri_v0 - ray_org; + const Vec3vf R = cross(C,ray_dir); + const vfloat den = dot(tri_Ng,ray_dir); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* test against edge p2 p0 */ + const vfloat U = dot(tri_e2,R) ^ sgnDen; + valid &= U >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p0 p1 */ + const vfloat V = dot(tri_e1,R) ^ sgnDen; + valid &= V >= 0.0f; + if (likely(none(valid))) return false; + + /* test against edge p1 p2 */ + const vfloat W = absDen-U-V; + valid &= W >= 0.0f; + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(tri_Ng,C) ^ sgnDen; + valid &= (absDen*ray_tnear < T) & (T <= absDen*ray_tfar); + if (unlikely(none(valid))) return false; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + valid &= den < vfloat(zero); + if (unlikely(none(valid))) return false; +#else + valid &= den != vfloat(zero); + if (unlikely(none(valid))) return false; +#endif + + /* calculate hit information */ + WoopHitK hit(U,V,T,absDen,tri_Ng); + return epilog(valid,hit); + } + + /*! Intersects K rays with one of M triangles. 
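The Woop-style test above works in a sheared, ray-aligned coordinate frame before evaluating the scaled barycentrics U0-U1, V0-V1, W0-W1. A small scalar sketch of just that precomputation, following the common watertight formulation (names are hypothetical, the swap-on-negative-direction rule is taken from that formulation):

// Illustrative sketch of the shear/scale precomputation for a Woop-style test:
// pick the dominant ray axis as kz, keep triangle winding by swapping kx/ky
// when the ray points backwards along kz, and store shear factors S so that
// the transformed ray travels along +z.
#include <cmath>
#include <utility>

struct V3 { float v[3]; };

struct WoopFrame {
  int   kx, ky, kz;   // permuted axes, kz = dominant direction axis
  float Sx, Sy, Sz;   // shear (x,y) and scale (z) factors
};

WoopFrame make_woop_frame(const V3& dir)
{
  WoopFrame f;
  // dominant axis of the ray direction
  f.kz = 0;
  if (std::fabs(dir.v[1]) > std::fabs(dir.v[f.kz])) f.kz = 1;
  if (std::fabs(dir.v[2]) > std::fabs(dir.v[f.kz])) f.kz = 2;
  f.kx = (f.kz + 1) % 3;
  f.ky = (f.kx + 1) % 3;
  // swap kx/ky when dir[kz] is negative so triangle winding is preserved
  if (dir.v[f.kz] < 0.0f) std::swap(f.kx, f.ky);
  // shear/scale so the transformed ray direction becomes (0,0,1)
  const float inv = 1.0f / dir.v[f.kz];
  f.Sx = dir.v[f.kx] * inv;
  f.Sy = dir.v[f.ky] * inv;
  f.Sz = inv;
  return f;
}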
*/ + template + __forceinline vbool intersectK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_v1, + const Vec3vf& tri_v2, + const Epilog& epilog) const + { + const Vec3vf e1 = tri_v0-tri_v1; + const Vec3vf e2 = tri_v2-tri_v0; + const Vec3vf Ng = cross(e2,e1); + return intersectK(valid0,ray.org,ray.dir,ray.tnear(),ray.tfar,tri_v0,e1,e2,Ng,epilog); + } + + /*! Intersects K rays with one of M triangles. */ + template + __forceinline vbool intersectEdgeK(const vbool& valid0, + RayK& ray, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Epilog& epilog) const + { + const Vec3vf tri_Ng = cross(tri_e2,tri_e1); + return intersectK(valid0,ray.org,ray.dir,ray.tnear(),ray.tfar,tri_v0,tri_e1,tri_e2,tri_Ng,epilog); + } + + /*! Intersect k'th ray from ray packet of size K with M triangles. */ + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + WoopHitM& hit) const + { + /* calculate denominator */ + typedef Vec3vf Vec3vfM; + const Vec3vf tri_Ng = cross(tri_e2,tri_e1); + + const Vec3vfM O = broadcast>(ray.org,k); + const Vec3vfM D = broadcast>(ray.dir,k); + const Vec3vfM C = Vec3vfM(tri_v0) - O; + const Vec3vfM R = cross(C,D); + const vfloat den = dot(Vec3vfM(tri_Ng),D); + const vfloat absDen = abs(den); + const vfloat sgnDen = signmsk(den); + + /* perform edge tests */ + const vfloat U = dot(Vec3vf(tri_e2),R) ^ sgnDen; + const vfloat V = dot(Vec3vf(tri_e1),R) ^ sgnDen; + + /* perform backface culling */ +#if defined(EMBREE_BACKFACE_CULLING) + vbool valid = (den < vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#else + vbool valid = (den != vfloat(zero)) & (U >= 0.0f) & (V >= 0.0f) & (U+V<=absDen); +#endif + if (likely(none(valid))) return false; + + /* perform depth test */ + const vfloat T = dot(Vec3vf(tri_Ng),C) ^ sgnDen; + valid &= (absDen*vfloat(ray.tnear()[k]) < T) & (T <= absDen*vfloat(ray.tfar[k])); + if (likely(none(valid))) return false; + + /* calculate hit information */ + new (&hit) WoopHitM(valid,U,V,T,absDen,tri_Ng); + return true; + } + + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const BBox>& time_range, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + WoopHitM& hit) const + { + if (likely(intersect(ray,k,tri_v0,tri_e1,tri_e2,hit))) + { + hit.valid &= time_range.lower <= vfloat(ray.time[k]); + hit.valid &= vfloat(ray.time[k]) < time_range.upper; + return any(hit.valid); + } + return false; + } + + template + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Epilog& epilog) const + { + WoopHitM hit; + if (likely(intersectEdge(ray,k,tri_v0,tri_e1,tri_e2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + __forceinline bool intersectEdge(RayK& ray, + size_t k, + const BBox>& time_range, + const Vec3vf& tri_v0, + const Vec3vf& tri_e1, + const Vec3vf& tri_e2, + const Epilog& epilog) const + { + WoopHitM hit; + if (likely(intersectEdge(ray,k,time_range,tri_v0,tri_e1,tri_e2,hit))) return epilog(hit.valid,hit); + return false; + } + + template + __forceinline bool intersect(RayK& ray, + size_t k, + const Vec3vf& v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + return intersectEdge(ray,k,v0,e1,e2,epilog); + } + + template + __forceinline bool intersect(RayK& ray, + size_t k, + const BBox>& time_range, + const Vec3vf& 
v0, + const Vec3vf& v1, + const Vec3vf& v2, + const Epilog& epilog) const + { + const Vec3vf e1 = v0-v1; + const Vec3vf e2 = v2-v0; + return intersectEdge(ray,k,time_range,v0,e1,e2,epilog); + } + }; +#endif + } +} diff --git a/thirdparty/embree/kernels/geometry/triangle_triangle_intersector.h b/thirdparty/embree/kernels/geometry/triangle_triangle_intersector.h new file mode 100644 index 000000000000..91b35c36f396 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/triangle_triangle_intersector.h @@ -0,0 +1,132 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "primitive.h" + +namespace embree +{ + namespace isa + { + struct TriangleTriangleIntersector + { + __forceinline static float T(float pa0, float pa1, float da0, float da1) { + return pa0 + (pa1-pa0)*da0/(da0-da1); + } + + __forceinline static bool point_line_side(const Vec2f& p, const Vec2f& a0, const Vec2f& a1) { + return det(p-a0,a0-a1) >= 0.0f; + } + + __forceinline static bool point_inside_triangle(const Vec2f& p, const Vec2f& a, const Vec2f& b, const Vec2f& c) + { + const bool pab = point_line_side(p,a,b); + const bool pbc = point_line_side(p,b,c); + const bool pca = point_line_side(p,c,a); + return pab == pbc && pab == pca; + } + + __forceinline static bool intersect_line_line(const Vec2f& a0, const Vec2f& a1, const Vec2f& b0, const Vec2f& b1) + { + const bool different_sides0 = point_line_side(b0,a0,a1) != point_line_side(b1,a0,a1); + const bool different_sides1 = point_line_side(a0,b0,b1) != point_line_side(a1,b0,b1); + return different_sides0 && different_sides1; + } + + __forceinline static bool intersect_triangle_triangle (const Vec2f& a0, const Vec2f& a1, const Vec2f& a2, + const Vec2f& b0, const Vec2f& b1, const Vec2f& b2) + { + const bool a01_b01 = intersect_line_line(a0,a1,b0,b1); + if (a01_b01) return true; + const bool a01_b12 = intersect_line_line(a0,a1,b1,b2); + if (a01_b12) return true; + const bool a01_b20 = intersect_line_line(a0,a1,b2,b0); + if (a01_b20) return true; + const bool a12_b01 = intersect_line_line(a1,a2,b0,b1); + if (a12_b01) return true; + const bool a12_b12 = intersect_line_line(a1,a2,b1,b2); + if (a12_b12) return true; + const bool a12_b20 = intersect_line_line(a1,a2,b2,b0); + if (a12_b20) return true; + const bool a20_b01 = intersect_line_line(a2,a0,b0,b1); + if (a20_b01) return true; + const bool a20_b12 = intersect_line_line(a2,a0,b1,b2); + if (a20_b12) return true; + const bool a20_b20 = intersect_line_line(a2,a0,b2,b0); + if (a20_b20) return true; + + bool a_in_b = point_inside_triangle(a0,b0,b1,b2) && point_inside_triangle(a1,b0,b1,b2) && point_inside_triangle(a2,b0,b1,b2); + if (a_in_b) return true; + + bool b_in_a = point_inside_triangle(b0,a0,a1,a2) && point_inside_triangle(b1,a0,a1,a2) && point_inside_triangle(b2,a0,a1,a2); + if (b_in_a) return true; + + return false; + } + + static bool intersect_triangle_triangle (const Vec3fa& a0, const Vec3fa& a1, const Vec3fa& a2, + const Vec3fa& b0, const Vec3fa& b1, const Vec3fa& b2) + { + const float eps = 1E-5f; + + /* calculate triangle planes */ + const Vec3fa Na = cross(a1-a0,a2-a0); + const float Ca = dot(Na,a0); + const Vec3fa Nb = cross(b1-b0,b2-b0); + const float Cb = dot(Nb,b0); + + /* project triangle A onto plane B */ + const float da0 = dot(Nb,a0)-Cb; + const float da1 = dot(Nb,a1)-Cb; + const float da2 = dot(Nb,a2)-Cb; + if (max(da0,da1,da2) < -eps) return false; + if (min(da0,da1,da2) > +eps) return false; + //CSTAT(bvh_collide_prim_intersections4++); + + /* project triangle B onto 
plane A */ + const float db0 = dot(Na,b0)-Ca; + const float db1 = dot(Na,b1)-Ca; + const float db2 = dot(Na,b2)-Ca; + if (max(db0,db1,db2) < -eps) return false; + if (min(db0,db1,db2) > +eps) return false; + //CSTAT(bvh_collide_prim_intersections5++); + + if (unlikely((std::fabs(da0) < eps && std::fabs(da1) < eps && std::fabs(da2) < eps) || + (std::fabs(db0) < eps && std::fabs(db1) < eps && std::fabs(db2) < eps))) + { + const size_t dz = maxDim(Na); + const size_t dx = (dz+1)%3; + const size_t dy = (dx+1)%3; + const Vec2f A0(a0[dx],a0[dy]); + const Vec2f A1(a1[dx],a1[dy]); + const Vec2f A2(a2[dx],a2[dy]); + const Vec2f B0(b0[dx],b0[dy]); + const Vec2f B1(b1[dx],b1[dy]); + const Vec2f B2(b2[dx],b2[dy]); + return intersect_triangle_triangle(A0,A1,A2,B0,B1,B2); + } + + const Vec3fa D = cross(Na,Nb); + const float pa0 = dot(D,a0); + const float pa1 = dot(D,a1); + const float pa2 = dot(D,a2); + const float pb0 = dot(D,b0); + const float pb1 = dot(D,b1); + const float pb2 = dot(D,b2); + + BBox1f ba = empty; + if (min(da0,da1) <= 0.0f && max(da0,da1) >= 0.0f && abs(da0-da1) > 0.0f) ba.extend(T(pa0,pa1,da0,da1)); + if (min(da1,da2) <= 0.0f && max(da1,da2) >= 0.0f && abs(da1-da2) > 0.0f) ba.extend(T(pa1,pa2,da1,da2)); + if (min(da2,da0) <= 0.0f && max(da2,da0) >= 0.0f && abs(da2-da0) > 0.0f) ba.extend(T(pa2,pa0,da2,da0)); + + BBox1f bb = empty; + if (min(db0,db1) <= 0.0f && max(db0,db1) >= 0.0f && abs(db0-db1) > 0.0f) bb.extend(T(pb0,pb1,db0,db1)); + if (min(db1,db2) <= 0.0f && max(db1,db2) >= 0.0f && abs(db1-db2) > 0.0f) bb.extend(T(pb1,pb2,db1,db2)); + if (min(db2,db0) <= 0.0f && max(db2,db0) >= 0.0f && abs(db2-db0) > 0.0f) bb.extend(T(pb2,pb0,db2,db0)); + + return conjoint(ba,bb); + } + }; + } +} + + diff --git a/thirdparty/embree/kernels/geometry/trianglei.h b/thirdparty/embree/kernels/geometry/trianglei.h new file mode 100644 index 000000000000..4f3118cc0ccc --- /dev/null +++ b/thirdparty/embree/kernels/geometry/trianglei.h @@ -0,0 +1,442 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" +#include "../common/scene.h" + +namespace embree +{ + /* Stores M triangles from an indexed face set */ + template + struct TriangleMi + { + /* Virtual interface to query information about the triangle type */ + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* primitive supports multiple time segments */ + static const bool singleTimeSegment = false; + + /* Returns maximum number of stored triangles */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline TriangleMi() { } + + /* Construction from vertices and IDs */ + __forceinline TriangleMi(const vuint& v0, + const vuint& v1, + const vuint& v2, + const vuint& geomIDs, + const vuint& primIDs) +#if defined(EMBREE_COMPACT_POLYS) + : geomIDs(geomIDs), primIDs(primIDs) {} +#else + : v0_(v0), v1_(v1), v2_(v2), geomIDs(geomIDs), primIDs(primIDs) {} +#endif + + /* Returns a mask that tells which triangles are valid */ + __forceinline vbool valid() const { return primIDs != vuint(-1); } + + /* Returns if the specified triangle is valid */ + __forceinline bool valid(const size_t i) const 
{ assert(i geomID() const { return geomIDs; } + __forceinline unsigned int geomID(const size_t i) const { assert(i primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(iget(geomID(i)); + bounds.extend(mesh->bounds(primID(i),itime)); + } + return bounds; + } + + /* Calculate the linear bounds of the primitive */ + __forceinline LBBox3fa linearBounds(const Scene *const scene, size_t itime) { + return LBBox3fa(bounds(scene,itime+0),bounds(scene,itime+1)); + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, size_t itime, size_t numTimeSteps) + { + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID(i)); + allBounds.extend(mesh->linearBounds(primID(i), itime, numTimeSteps)); + } + return allBounds; + } + + __forceinline LBBox3fa linearBounds(const Scene *const scene, const BBox1f time_range) + { + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID(i)); + allBounds.extend(mesh->linearBounds(primID(i), time_range)); + } + return allBounds; + } + + /* Non-temporal store */ + __forceinline static void store_nt(TriangleMi* dst, const TriangleMi& src) + { +#if !defined(EMBREE_COMPACT_POLYS) + vuint::store_nt(&dst->v0_,src.v0_); + vuint::store_nt(&dst->v1_,src.v1_); + vuint::store_nt(&dst->v2_,src.v2_); +#endif + vuint::store_nt(&dst->geomIDs,src.geomIDs); + vuint::store_nt(&dst->primIDs,src.primIDs); + } + + /* Fill triangle from triangle list */ + template + __forceinline void fill(const PrimRefT* prims, size_t& begin, size_t end, Scene* scene) + { + vuint v0 = zero, v1 = zero, v2 = zero; + vuint geomID = -1, primID = -1; + const PrimRefT* prim = &prims[begin]; + + for (size_t i=0; igeomID(); + primID[i] = prim->primID(); +#if !defined(EMBREE_COMPACT_POLYS) + const TriangleMesh* mesh = scene->get(prim->geomID()); + const TriangleMesh::Triangle& tri = mesh->triangle(prim->primID()); + unsigned int int_stride = mesh->vertices0.getStride()/4; + v0[i] = tri.v[0] * int_stride; + v1[i] = tri.v[1] * int_stride; + v2[i] = tri.v[2] * int_stride; +#endif + begin++; + } else { + assert(i); + if (likely(i > 0)) { + geomID[i] = geomID[0]; + primID[i] = -1; + v0[i] = v0[0]; + v1[i] = v0[0]; + v2[i] = v0[0]; + } + } + if (begintriangle(primId); + const Vec3fa p0 = mesh->vertex(tri.v[0]); + const Vec3fa p1 = mesh->vertex(tri.v[1]); + const Vec3fa p2 = mesh->vertex(tri.v[2]); + bounds.extend(merge(BBox3fa(p0),BBox3fa(p1),BBox3fa(p2))); + } + return bounds; + } + + protected: +#if !defined(EMBREE_COMPACT_POLYS) + vuint v0_; // 4 byte offset of 1st vertex + vuint v1_; // 4 byte offset of 2nd vertex + vuint v2_; // 4 byte offset of 3rd vertex +#endif + vuint geomIDs; // geometry ID of mesh + vuint primIDs; // primitive ID of primitive inside mesh + }; + + namespace isa + { + + template + struct TriangleMi : public embree::TriangleMi + { +#if !defined(EMBREE_COMPACT_POLYS) + using embree::TriangleMi::v0_; + using embree::TriangleMi::v1_; + using embree::TriangleMi::v2_; +#endif + using embree::TriangleMi::geomIDs; + using embree::TriangleMi::primIDs; + using embree::TriangleMi::geomID; + using embree::TriangleMi::primID; + using embree::TriangleMi::valid; + + /* loads a single vertex */ + template + __forceinline Vec3f getVertex(const size_t index, const Scene *const scene) const + { +#if defined(EMBREE_COMPACT_POLYS) + const TriangleMesh* mesh = scene->get(geomID(index)); + const TriangleMesh::Triangle& tri = mesh->triangle(primID(index)); + return (Vec3f) mesh->vertices[0][tri.v[vid]]; +#else + const vuint& v = getVertexOffset(); + const 
float* vertices = scene->vertices[geomID(index)]; + return (Vec3f&) vertices[v[index]]; +#endif + } + + template + __forceinline Vec3 getVertex(const size_t index, const Scene *const scene, const size_t itime, const T& ftime) const + { +#if defined(EMBREE_COMPACT_POLYS) + const TriangleMesh* mesh = scene->get(geomID(index)); + const TriangleMesh::Triangle& tri = mesh->triangle(primID(index)); + const Vec3fa v0 = mesh->vertices[itime+0][tri.v[vid]]; + const Vec3fa v1 = mesh->vertices[itime+1][tri.v[vid]]; +#else + const vuint& v = getVertexOffset(); + const TriangleMesh* mesh = scene->get(geomID(index)); + const float* vertices0 = (const float*) mesh->vertexPtr(0,itime+0); + const float* vertices1 = (const float*) mesh->vertexPtr(0,itime+1); + const Vec3fa v0 = Vec3fa::loadu(vertices0+v[index]); + const Vec3fa v1 = Vec3fa::loadu(vertices1+v[index]); +#endif + const Vec3 p0(v0.x,v0.y,v0.z); + const Vec3 p1(v1.x,v1.y,v1.z); + return lerp(p0,p1,ftime); + } + + template + __forceinline Vec3 getVertex(const vbool& valid, const size_t index, const Scene *const scene, const vint& itime, const T& ftime) const + { + Vec3 p0, p1; + const TriangleMesh* mesh = scene->get(geomID(index)); + + for (size_t mask=movemask(valid), i=bsf(mask); mask; mask=btc(mask,i), i=bsf(mask)) + { +#if defined(EMBREE_COMPACT_POLYS) + const TriangleMesh::Triangle& tri = mesh->triangle(primID(index)); + const Vec3fa v0 = mesh->vertices[itime[i]+0][tri.v[vid]]; + const Vec3fa v1 = mesh->vertices[itime[i]+1][tri.v[vid]]; +#else + const vuint& v = getVertexOffset(); + const float* vertices0 = (const float*) mesh->vertexPtr(0,itime[i]+0); + const float* vertices1 = (const float*) mesh->vertexPtr(0,itime[i]+1); + const Vec3fa v0 = Vec3fa::loadu(vertices0+v[index]); + const Vec3fa v1 = Vec3fa::loadu(vertices1+v[index]); +#endif + p0.x[i] = v0.x; p0.y[i] = v0.y; p0.z[i] = v0.z; + p1.x[i] = v1.x; p1.y[i] = v1.y; p1.z[i] = v1.z; + } + return (T(one)-ftime)*p0 + ftime*p1; + } + + struct Triangle { + vfloat4 v0,v1,v2; + }; + +#if defined(EMBREE_COMPACT_POLYS) + + __forceinline Triangle loadTriangle(const int i, const Scene* const scene) const + { + const unsigned int geomID = geomIDs[i]; + const unsigned int primID = primIDs[i]; + if (unlikely(primID == -1)) return { zero, zero, zero }; + const TriangleMesh* mesh = scene->get(geomID); + const TriangleMesh::Triangle& tri = mesh->triangle(primID); + const vfloat4 v0 = (vfloat4) mesh->vertices0[tri.v[0]]; + const vfloat4 v1 = (vfloat4) mesh->vertices0[tri.v[1]]; + const vfloat4 v2 = (vfloat4) mesh->vertices0[tri.v[2]]; + return { v0, v1, v2 }; + } + + __forceinline Triangle loadTriangle(const int i, const int itime, const TriangleMesh* const mesh) const + { + const unsigned int primID = primIDs[i]; + if (unlikely(primID == -1)) return { zero, zero, zero }; + const TriangleMesh::Triangle& tri = mesh->triangle(primID); + const vfloat4 v0 = (vfloat4) mesh->vertices[itime][tri.v[0]]; + const vfloat4 v1 = (vfloat4) mesh->vertices[itime][tri.v[1]]; + const vfloat4 v2 = (vfloat4) mesh->vertices[itime][tri.v[2]]; + return { v0, v1, v2 }; + } + +#else + + __forceinline Triangle loadTriangle(const int i, const Scene* const scene) const + { + const float* vertices = scene->vertices[geomID(i)]; + const vfloat4 v0 = vfloat4::loadu(vertices + v0_[i]); + const vfloat4 v1 = vfloat4::loadu(vertices + v1_[i]); + const vfloat4 v2 = vfloat4::loadu(vertices + v2_[i]); + return { v0, v1, v2 }; + } + + __forceinline Triangle loadTriangle(const int i, const int itime, const TriangleMesh* const mesh) const + { + 
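// Gather the three vertices of triangle i from the vertex buffer of time step itime, using the precomputed per-triangle offsets v0_/v1_/v2_ (stored in 4-byte units). +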
const float* vertices = (const float*) mesh->vertexPtr(0,itime); + const vfloat4 v0 = vfloat4::loadu(vertices + v0_[i]); + const vfloat4 v1 = vfloat4::loadu(vertices + v1_[i]); + const vfloat4 v2 = vfloat4::loadu(vertices + v2_[i]); + return { v0, v1, v2 }; + } + +#endif + + /* Gather the triangles */ + __forceinline void gather(Vec3vf& p0, Vec3vf& p1, Vec3vf& p2, const Scene* const scene) const; + + template +#if defined(__INTEL_COMPILER) && (__INTEL_COMPILER < 2000) // workaround for compiler bug in ICC 2019 + __noinline +#else + __forceinline +#endif + void gather(const vbool& valid, + Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + const size_t index, + const Scene* const scene, + const vfloat& time) const + { + const TriangleMesh* mesh = scene->get(geomID(index)); + + vfloat ftime; + const vint itime = mesh->timeSegment(time, ftime); + + const size_t first = bsf(movemask(valid)); + if (likely(all(valid,itime[first] == itime))) + { + p0 = getVertex<0>(index, scene, itime[first], ftime); + p1 = getVertex<1>(index, scene, itime[first], ftime); + p2 = getVertex<2>(index, scene, itime[first], ftime); + } else { + p0 = getVertex<0>(valid, index, scene, itime, ftime); + p1 = getVertex<1>(valid, index, scene, itime, ftime); + p2 = getVertex<2>(valid, index, scene, itime, ftime); + } + } + + __forceinline void gather(Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + const TriangleMesh* mesh, + const Scene *const scene, + const int itime) const; + + __forceinline void gather(Vec3vf& p0, + Vec3vf& p1, + Vec3vf& p2, + const Scene *const scene, + const float time) const; + + +#if !defined(EMBREE_COMPACT_POLYS) + template const vuint& getVertexOffset() const; +#endif + }; + +#if !defined(EMBREE_COMPACT_POLYS) + template<> template<> __forceinline const vuint<4>& TriangleMi<4>::getVertexOffset<0>() const { return v0_; } + template<> template<> __forceinline const vuint<4>& TriangleMi<4>::getVertexOffset<1>() const { return v1_; } + template<> template<> __forceinline const vuint<4>& TriangleMi<4>::getVertexOffset<2>() const { return v2_; } +#endif + + template<> + __forceinline void TriangleMi<4>::gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + const Scene* const scene) const + { + const Triangle tri0 = loadTriangle(0,scene); + const Triangle tri1 = loadTriangle(1,scene); + const Triangle tri2 = loadTriangle(2,scene); + const Triangle tri3 = loadTriangle(3,scene); + transpose(tri0.v0,tri1.v0,tri2.v0,tri3.v0,p0.x,p0.y,p0.z); + transpose(tri0.v1,tri1.v1,tri2.v1,tri3.v1,p1.x,p1.y,p1.z); + transpose(tri0.v2,tri1.v2,tri2.v2,tri3.v2,p2.x,p2.y,p2.z); + } + + template<> + __forceinline void TriangleMi<4>::gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + const TriangleMesh* mesh, + const Scene *const scene, + const int itime) const + { + const Triangle tri0 = loadTriangle(0,itime,mesh); + const Triangle tri1 = loadTriangle(1,itime,mesh); + const Triangle tri2 = loadTriangle(2,itime,mesh); + const Triangle tri3 = loadTriangle(3,itime,mesh); + transpose(tri0.v0,tri1.v0,tri2.v0,tri3.v0,p0.x,p0.y,p0.z); + transpose(tri0.v1,tri1.v1,tri2.v1,tri3.v1,p1.x,p1.y,p1.z); + transpose(tri0.v2,tri1.v2,tri2.v2,tri3.v2,p2.x,p2.y,p2.z); + } + + template<> + __forceinline void TriangleMi<4>::gather(Vec3vf4& p0, + Vec3vf4& p1, + Vec3vf4& p2, + const Scene *const scene, + const float time) const + { + const TriangleMesh* mesh = scene->get(geomID(0)); // in mblur mode all geometries are identical + + float ftime; + const int itime = mesh->timeSegment(time, ftime); + + Vec3vf4 a0,a1,a2; gather(a0,a1,a2,mesh,scene,itime); + Vec3vf4 
b0,b1,b2; gather(b0,b1,b2,mesh,scene,itime+1); + p0 = lerp(a0,b0,vfloat4(ftime)); + p1 = lerp(a1,b1,vfloat4(ftime)); + p2 = lerp(a2,b2,vfloat4(ftime)); + } + } + + template + typename TriangleMi::Type TriangleMi::type; + + typedef TriangleMi<4> Triangle4i; +} diff --git a/thirdparty/embree/kernels/geometry/trianglei_intersector.h b/thirdparty/embree/kernels/geometry/trianglei_intersector.h new file mode 100644 index 000000000000..e2f106a62cd8 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/trianglei_intersector.h @@ -0,0 +1,336 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "trianglei.h" +#include "triangle_intersector_moeller.h" +#include "triangle_intersector_pluecker.h" + +namespace embree +{ + namespace isa + { + /*! Intersects M triangles with 1 ray */ + template + struct TriangleMiIntersector1Moeller + { + typedef TriangleMi Primitive; + typedef MoellerTrumboreIntersector1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + pre.intersect(ray,v0,v1,v2,/*UVIdentity(),*/Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + return pre.intersect(ray,v0,v1,v2,/*UVIdentity(),*/Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! 
Intersects M triangles with K rays */ + template + struct TriangleMiIntersectorKMoeller + { + typedef TriangleMi Primitive; + typedef MoellerTrumboreIntersectorK Precalculations; + + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& tri) + { + const Scene* scene = context->scene; + for (size_t i=0; i::size()); + const Vec3vf v0 = tri.template getVertex<0>(i,scene); + const Vec3vf v1 = tri.template getVertex<1>(i,scene); + const Vec3vf v2 = tri.template getVertex<2>(i,scene); + pre.intersectK(valid_i,ray,v0,v1,v2,/*UVIdentity(),*/IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& tri) + { + vbool valid0 = valid_i; + const Scene* scene = context->scene; + + for (size_t i=0; i::size()); + const Vec3vf v0 = tri.template getVertex<0>(i,scene); + const Vec3vf v1 = tri.template getVertex<1>(i,scene); + const Vec3vf v2 = tri.template getVertex<2>(i,scene); + pre.intersectK(valid0,ray,v0,v1,v2,/*UVIdentity(),*/OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + pre.intersect(ray,k,v0,v1,v2,/*UVIdentity(),*/Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + return pre.intersect(ray,k,v0,v1,v2,/*UVIdentity(),*/Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + + /*! Intersects M triangles with 1 ray */ + template + struct TriangleMiIntersector1Pluecker + { + typedef TriangleMi Primitive; + typedef PlueckerIntersector1 Precalculations; + + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + pre.intersect(ray,v0,v1,v2,UVIdentity(),Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + return pre.intersect(ray,v0,v1,v2,UVIdentity(),Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! 
Intersects M triangles with K rays */ + template + struct TriangleMiIntersectorKPluecker + { + typedef TriangleMi Primitive; + typedef PlueckerIntersectorK Precalculations; + + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& tri) + { + const Scene* scene = context->scene; + for (size_t i=0; i::size()); + const Vec3vf v0 = tri.template getVertex<0>(i,scene); + const Vec3vf v1 = tri.template getVertex<1>(i,scene); + const Vec3vf v2 = tri.template getVertex<2>(i,scene); + pre.intersectK(valid_i,ray,v0,v1,v2,UVIdentity(),IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& tri) + { + vbool valid0 = valid_i; + const Scene* scene = context->scene; + + for (size_t i=0; i::size()); + const Vec3vf v0 = tri.template getVertex<0>(i,scene); + const Vec3vf v1 = tri.template getVertex<1>(i,scene); + const Vec3vf v2 = tri.template getVertex<2>(i,scene); + pre.intersectK(valid0,ray,v0,v1,v2,UVIdentity(),OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + pre.intersect(ray,k,v0,v1,v2,UVIdentity(),Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0, v1, v2; tri.gather(v0,v1,v2,context->scene); + return pre.intersect(ray,k,v0,v1,v2,UVIdentity(),Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + + /*! Intersects M motion blur triangles with 1 ray */ + template + struct TriangleMiMBIntersector1Moeller + { + typedef TriangleMi Primitive; + typedef MoellerTrumboreIntersector1 Precalculations; + + /*! Intersect a ray with the M triangles and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()); + pre.intersect(ray,v0,v1,v2,/*UVIdentity(),*/Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()); + return pre.intersect(ray,v0,v1,v2,/*UVIdentity(),*/Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! Intersects M motion blur triangles with K rays. */ + template + struct TriangleMiMBIntersectorKMoeller + { + typedef TriangleMi Primitive; + typedef MoellerTrumboreIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. 
*/ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const TriangleMi& tri) + { + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + Vec3vf v0,v1,v2; tri.gather(valid_i,v0,v1,v2,i,context->scene,ray.time()); + pre.intersectK(valid_i,ray,v0,v1,v2,/*UVIdentity(),*/IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const TriangleMi& tri) + { + vbool valid0 = valid_i; + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + Vec3vf v0,v1,v2; tri.gather(valid_i,v0,v1,v2,i,context->scene,ray.time()); + pre.intersectK(valid0,ray,v0,v1,v2,/*UVIdentity(),*/OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const TriangleMi& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()[k]); + pre.intersect(ray,k,v0,v1,v2,/*UVIdentity(),*/Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const TriangleMi& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()[k]); + return pre.intersect(ray,k,v0,v1,v2,/*UVIdentity(),*/Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + + /*! Intersects M motion blur triangles with 1 ray */ + template + struct TriangleMiMBIntersector1Pluecker + { + typedef TriangleMi Primitive; + typedef PlueckerIntersector1 Precalculations; + + /*! Intersect a ray with the M triangles and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()); + pre.intersect(ray,v0,v1,v2,UVIdentity(),Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()); + return pre.intersect(ray,v0,v1,v2,UVIdentity(),Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! Intersects M motion blur triangles with K rays. */ + template + struct TriangleMiMBIntersectorKPluecker + { + typedef TriangleMi Primitive; + typedef PlueckerIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. 
*/ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const TriangleMi& tri) + { + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + Vec3vf v0,v1,v2; tri.gather(valid_i,v0,v1,v2,i,context->scene,ray.time()); + pre.intersectK(valid_i,ray,v0,v1,v2,UVIdentity(),IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const TriangleMi& tri) + { + vbool valid0 = valid_i; + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + Vec3vf v0,v1,v2; tri.gather(valid_i,v0,v1,v2,i,context->scene,ray.time()); + pre.intersectK(valid0,ray,v0,v1,v2,UVIdentity(),OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const TriangleMi& tri) + { + STAT3(normal.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()[k]); + pre.intersect(ray,k,v0,v1,v2,UVIdentity(),Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const TriangleMi& tri) + { + STAT3(shadow.trav_prims,1,1,1); + Vec3vf v0,v1,v2; tri.gather(v0,v1,v2,context->scene,ray.time()[k]); + return pre.intersect(ray,k,v0,v1,v2,UVIdentity(),Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/trianglev.h b/thirdparty/embree/kernels/geometry/trianglev.h new file mode 100644 index 000000000000..19af389e73e4 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/trianglev.h @@ -0,0 +1,157 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + /* Stores the vertices of M triangles in struct of array layout */ + template + struct TriangleMv + { + public: + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + static Type type; + + public: + + /* Returns maximum number of stored triangles */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline TriangleMv() {} + + /* Construction from vertices and IDs */ + __forceinline TriangleMv(const Vec3vf& v0, const Vec3vf& v1, const Vec3vf& v2, const vuint& geomIDs, const vuint& primIDs) + : v0(v0), v1(v1), v2(v2), geomIDs(geomIDs), primIDs(primIDs) {} + + /* Returns a mask that tells which triangles are valid */ + __forceinline vbool valid() const { return geomIDs != vuint(-1); } + + /* Returns true if the specified triangle is valid */ + __forceinline bool valid(const size_t i) const { assert(i& geomID() { return geomIDs; } + __forceinline const 
vuint& geomID() const { return geomIDs; } + __forceinline unsigned int geomID(const size_t i) const { assert(i& primID() { return primIDs; } + __forceinline const vuint& primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(i lower = min(v0,v1,v2); + Vec3vf upper = max(v0,v1,v2); + vbool mask = valid(); + lower.x = select(mask,lower.x,vfloat(pos_inf)); + lower.y = select(mask,lower.y,vfloat(pos_inf)); + lower.z = select(mask,lower.z,vfloat(pos_inf)); + upper.x = select(mask,upper.x,vfloat(neg_inf)); + upper.y = select(mask,upper.y,vfloat(neg_inf)); + upper.z = select(mask,upper.z,vfloat(neg_inf)); + return BBox3fa(Vec3fa(reduce_min(lower.x),reduce_min(lower.y),reduce_min(lower.z)), + Vec3fa(reduce_max(upper.x),reduce_max(upper.y),reduce_max(upper.z))); + } + + /* Non temporal store */ + __forceinline static void store_nt(TriangleMv* dst, const TriangleMv& src) + { + vfloat::store_nt(&dst->v0.x,src.v0.x); + vfloat::store_nt(&dst->v0.y,src.v0.y); + vfloat::store_nt(&dst->v0.z,src.v0.z); + vfloat::store_nt(&dst->v1.x,src.v1.x); + vfloat::store_nt(&dst->v1.y,src.v1.y); + vfloat::store_nt(&dst->v1.z,src.v1.z); + vfloat::store_nt(&dst->v2.x,src.v2.x); + vfloat::store_nt(&dst->v2.y,src.v2.y); + vfloat::store_nt(&dst->v2.z,src.v2.z); + vuint::store_nt(&dst->geomIDs,src.geomIDs); + vuint::store_nt(&dst->primIDs,src.primIDs); + } + + /* Fill triangle from triangle list */ + __forceinline void fill(const PrimRef* prims, size_t& begin, size_t end, Scene* scene) + { + vuint vgeomID = -1, vprimID = -1; + Vec3vf v0 = zero, v1 = zero, v2 = zero; + + for (size_t i=0; iget(geomID); + const TriangleMesh::Triangle& tri = mesh->triangle(primID); + const Vec3fa& p0 = mesh->vertex(tri.v[0]); + const Vec3fa& p1 = mesh->vertex(tri.v[1]); + const Vec3fa& p2 = mesh->vertex(tri.v[2]); + vgeomID [i] = geomID; + vprimID [i] = primID; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + } + TriangleMv::store_nt(this,TriangleMv(v0,v1,v2,vgeomID,vprimID)); + } + + /* Updates the primitive */ + __forceinline BBox3fa update(TriangleMesh* mesh) + { + BBox3fa bounds = empty; + vuint vgeomID = -1, vprimID = -1; + Vec3vf v0 = zero, v1 = zero, v2 = zero; + + for (size_t i=0; itriangle(primId); + const Vec3fa p0 = mesh->vertex(tri.v[0]); + const Vec3fa p1 = mesh->vertex(tri.v[1]); + const Vec3fa p2 = mesh->vertex(tri.v[2]); + bounds.extend(merge(BBox3fa(p0),BBox3fa(p1),BBox3fa(p2))); + vgeomID [i] = geomId; + vprimID [i] = primId; + v0.x[i] = p0.x; v0.y[i] = p0.y; v0.z[i] = p0.z; + v1.x[i] = p1.x; v1.y[i] = p1.y; v1.z[i] = p1.z; + v2.x[i] = p2.x; v2.y[i] = p2.y; v2.z[i] = p2.z; + } + new (this) TriangleMv(v0,v1,v2,vgeomID,vprimID); + return bounds; + } + + public: + Vec3vf v0; // 1st vertex of the triangles + Vec3vf v1; // 2nd vertex of the triangles + Vec3vf v2; // 3rd vertex of the triangles + private: + vuint geomIDs; // geometry ID + vuint primIDs; // primitive ID + }; + + template + typename TriangleMv::Type TriangleMv::type; + + typedef TriangleMv<4> Triangle4v; +} diff --git a/thirdparty/embree/kernels/geometry/trianglev_intersector.h b/thirdparty/embree/kernels/geometry/trianglev_intersector.h new file mode 100644 index 000000000000..6af0d5a11c09 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/trianglev_intersector.h @@ -0,0 +1,206 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "triangle.h" +#include 
"triangle_intersector_pluecker.h" +#include "triangle_intersector_moeller.h" +#include "triangle_intersector_woop.h" + +namespace embree +{ + namespace isa + { + /*! Intersects M triangles with 1 ray */ + template + struct TriangleMvIntersector1Moeller + { + typedef TriangleMv Primitive; + typedef MoellerTrumboreIntersector1 Precalculations; + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect(ray,tri.v0,tri.v1,tri.v2,/*UVIdentity(),*/Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.intersect(ray,tri.v0,tri.v1,tri.v2,/*UVIdentity(),*/Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + + template + struct TriangleMvIntersector1Woop + { + typedef TriangleMv Primitive; + typedef WoopIntersector1 intersec; + typedef WoopPrecalculations1 Precalculations; + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + intersec::intersect(ray,pre,tri.v0,tri.v1,tri.v2,Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return intersec::intersect(ray,pre,tri.v0,tri.v1,tri.v2,Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + + /*! Intersects M triangles with K rays */ + template + struct TriangleMvIntersectorKMoeller + { + typedef TriangleMv Primitive; + typedef MoellerTrumboreIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& tri) + { + for (size_t i=0; i v0 = broadcast>(tri.v0,i); + const Vec3vf v1 = broadcast>(tri.v1,i); + const Vec3vf v2 = broadcast>(tri.v2,i); + pre.intersectK(valid_i,ray,v0,v1,v2,/*UVIdentity(),*/IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& tri) + { + vbool valid0 = valid_i; + + for (size_t i=0; i v0 = broadcast>(tri.v0,i); + const Vec3vf v1 = broadcast>(tri.v1,i); + const Vec3vf v2 = broadcast>(tri.v2,i); + pre.intersectK(valid0,ray,v0,v1,v2,/*UVIdentity(),*/OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. 
*/ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect(ray,k,tri.v0,tri.v1,tri.v2,/*UVIdentity(),*/Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); //FIXME: M,Mx + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.intersect(ray,k,tri.v0,tri.v1,tri.v2,/*UVIdentity(),*/Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); //FIXME: M,Mx + } + }; + + /*! Intersects M triangles with 1 ray */ + template + struct TriangleMvIntersector1Pluecker + { + typedef TriangleMv Primitive; + typedef PlueckerIntersector1 Precalculations; + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHit& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect(ray,tri.v0,tri.v1,tri.v2,UVIdentity(),Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.intersect(ray,tri.v0,tri.v1,tri.v2,UVIdentity(),Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! Intersects M triangles with K rays */ + template + struct TriangleMvIntersectorKPluecker + { + typedef TriangleMv Primitive; + typedef PlueckerIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const Primitive& tri) + { + for (size_t i=0; i v0 = broadcast>(tri.v0,i); + const Vec3vf v1 = broadcast>(tri.v1,i); + const Vec3vf v2 = broadcast>(tri.v2,i); + pre.intersectK(valid_i,ray,v0,v1,v2,UVIdentity(),IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const Primitive& tri) + { + vbool valid0 = valid_i; + + for (size_t i=0; i v0 = broadcast>(tri.v0,i); + const Vec3vf v1 = broadcast>(tri.v1,i); + const Vec3vf v2 = broadcast>(tri.v2,i); + pre.intersectK(valid0,ray,v0,v1,v2,UVIdentity(),OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(normal.trav_prims,1,1,1); + pre.intersect(ray,k,tri.v0,tri.v1,tri.v2,UVIdentity(),Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); //FIXME: M,Mx + } + + /*! Test if the ray is occluded by one of the M triangles. 
*/ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const Primitive& tri) + { + STAT3(shadow.trav_prims,1,1,1); + return pre.intersect(ray,k,tri.v0,tri.v1,tri.v2,UVIdentity(),Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); //FIXME: M,Mx + } + }; + } +} diff --git a/thirdparty/embree/kernels/geometry/trianglev_mb.h b/thirdparty/embree/kernels/geometry/trianglev_mb.h new file mode 100644 index 000000000000..63137aee166c --- /dev/null +++ b/thirdparty/embree/kernels/geometry/trianglev_mb.h @@ -0,0 +1,201 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "primitive.h" + +namespace embree +{ + /* Stores the vertices of M triangles in struct of array layout */ + template + struct TriangleMvMB + { + public: + struct Type : public PrimitiveType + { + const char* name() const; + size_t sizeActive(const char* This) const; + size_t sizeTotal(const char* This) const; + size_t getBytes(const char* This) const; + }; + + static Type type; + + public: + + /* primitive supports single time segments */ + static const bool singleTimeSegment = true; + + /* Returns maximum number of stored triangles */ + static __forceinline size_t max_size() { return M; } + + /* Returns required number of primitive blocks for N primitives */ + static __forceinline size_t blocks(size_t N) { return (N+max_size()-1)/max_size(); } + + public: + + /* Default constructor */ + __forceinline TriangleMvMB() {} + + /* Construction from vertices and IDs */ + __forceinline TriangleMvMB(const Vec3vf& a0, const Vec3vf& a1, + const Vec3vf& b0, const Vec3vf& b1, + const Vec3vf& c0, const Vec3vf& c1, + const vuint& geomIDs, const vuint& primIDs) + : v0(a0), v1(b0), v2(c0), dv0(a1-a0), dv1(b1-b0), dv2(c1-c0), geomIDs(geomIDs), primIDs(primIDs) {} + + /* Returns a mask that tells which triangles are valid */ + __forceinline vbool valid() const { return geomIDs != vuint(-1); } + + /* Returns if the specified triangle is valid */ + __forceinline bool valid(const size_t i) const { assert(i& geomID() { return geomIDs; } + __forceinline const vuint& geomID() const { return geomIDs; } + __forceinline unsigned int geomID(const size_t i) const { assert(i& primID() { return primIDs; } + __forceinline const vuint& primID() const { return primIDs; } + __forceinline unsigned int primID(const size_t i) const { assert(i lower = min(v0,v1,v2); + Vec3vf upper = max(v0,v1,v2); + const vbool mask = valid(); + lower.x = select(mask,lower.x,vfloat(pos_inf)); + lower.y = select(mask,lower.y,vfloat(pos_inf)); + lower.z = select(mask,lower.z,vfloat(pos_inf)); + upper.x = select(mask,upper.x,vfloat(neg_inf)); + upper.y = select(mask,upper.y,vfloat(neg_inf)); + upper.z = select(mask,upper.z,vfloat(neg_inf)); + return BBox3fa(Vec3fa(reduce_min(lower.x),reduce_min(lower.y),reduce_min(lower.z)), + Vec3fa(reduce_max(upper.x),reduce_max(upper.y),reduce_max(upper.z))); + } + + /* Calculate the bounds of the triangles at t1 */ + __forceinline BBox3fa bounds1() const + { + const Vec3vf p0 = v0+dv0; + const Vec3vf p1 = v1+dv1; + const Vec3vf p2 = v2+dv2; + Vec3vf lower = min(p0,p1,p2); + Vec3vf upper = max(p0,p1,p2); + const vbool mask = valid(); + lower.x = select(mask,lower.x,vfloat(pos_inf)); + lower.y = select(mask,lower.y,vfloat(pos_inf)); + lower.z = select(mask,lower.z,vfloat(pos_inf)); + upper.x = select(mask,upper.x,vfloat(neg_inf)); + upper.y = select(mask,upper.y,vfloat(neg_inf)); + upper.z = select(mask,upper.z,vfloat(neg_inf)); + 
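// Reduce the per-lane bounds of the time-t1 positions (v + dv) into a single box; inactive lanes were pushed to +/-inf by the selects above. linearBounds() below simply pairs bounds0() and bounds1(). +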
return BBox3fa(Vec3fa(reduce_min(lower.x),reduce_min(lower.y),reduce_min(lower.z)), + Vec3fa(reduce_max(upper.x),reduce_max(upper.y),reduce_max(upper.z))); + } + + /* Calculate the linear bounds of the primitive */ + __forceinline LBBox3fa linearBounds() const { + return LBBox3fa(bounds0(),bounds1()); + } + + /* Fill triangle from triangle list */ + __forceinline LBBox3fa fillMB(const PrimRef* prims, size_t& begin, size_t end, Scene* scene, size_t itime) + { + vuint vgeomID = -1, vprimID = -1; + Vec3vf va0 = zero, vb0 = zero, vc0 = zero; + Vec3vf va1 = zero, vb1 = zero, vc1 = zero; + + BBox3fa bounds0 = empty; + BBox3fa bounds1 = empty; + + for (size_t i=0; iget(geomID); + const TriangleMesh::Triangle& tri = mesh->triangle(primID); + const Vec3fa& a0 = mesh->vertex(tri.v[0],itime+0); bounds0.extend(a0); + const Vec3fa& a1 = mesh->vertex(tri.v[0],itime+1); bounds1.extend(a1); + const Vec3fa& b0 = mesh->vertex(tri.v[1],itime+0); bounds0.extend(b0); + const Vec3fa& b1 = mesh->vertex(tri.v[1],itime+1); bounds1.extend(b1); + const Vec3fa& c0 = mesh->vertex(tri.v[2],itime+0); bounds0.extend(c0); + const Vec3fa& c1 = mesh->vertex(tri.v[2],itime+1); bounds1.extend(c1); + vgeomID [i] = geomID; + vprimID [i] = primID; + va0.x[i] = a0.x; va0.y[i] = a0.y; va0.z[i] = a0.z; + va1.x[i] = a1.x; va1.y[i] = a1.y; va1.z[i] = a1.z; + vb0.x[i] = b0.x; vb0.y[i] = b0.y; vb0.z[i] = b0.z; + vb1.x[i] = b1.x; vb1.y[i] = b1.y; vb1.z[i] = b1.z; + vc0.x[i] = c0.x; vc0.y[i] = c0.y; vc0.z[i] = c0.z; + vc1.x[i] = c1.x; vc1.y[i] = c1.y; vc1.z[i] = c1.z; + } + new (this) TriangleMvMB(va0,va1,vb0,vb1,vc0,vc1,vgeomID,vprimID); + return LBBox3fa(bounds0,bounds1); + } + + /* Fill triangle from triangle list */ + __forceinline LBBox3fa fillMB(const PrimRefMB* prims, size_t& begin, size_t end, Scene* scene, const BBox1f time_range) + { + vuint vgeomID = -1, vprimID = -1; + Vec3vf va0 = zero, vb0 = zero, vc0 = zero; + Vec3vf va1 = zero, vb1 = zero, vc1 = zero; + + LBBox3fa allBounds = empty; + for (size_t i=0; iget(geomID); + const range itime_range = mesh->timeSegmentRange(time_range); + assert(itime_range.size() == 1); + const int ilower = itime_range.begin(); + const TriangleMesh::Triangle& tri = mesh->triangle(primID); + allBounds.extend(mesh->linearBounds(primID, time_range)); + const Vec3fa& a0 = mesh->vertex(tri.v[0],ilower+0); + const Vec3fa& a1 = mesh->vertex(tri.v[0],ilower+1); + const Vec3fa& b0 = mesh->vertex(tri.v[1],ilower+0); + const Vec3fa& b1 = mesh->vertex(tri.v[1],ilower+1); + const Vec3fa& c0 = mesh->vertex(tri.v[2],ilower+0); + const Vec3fa& c1 = mesh->vertex(tri.v[2],ilower+1); + const BBox1f time_range_v(mesh->timeStep(ilower+0),mesh->timeStep(ilower+1)); + auto a01 = globalLinear(std::make_pair(a0,a1),time_range_v); + auto b01 = globalLinear(std::make_pair(b0,b1),time_range_v); + auto c01 = globalLinear(std::make_pair(c0,c1),time_range_v); + vgeomID [i] = geomID; + vprimID [i] = primID; + va0.x[i] = a01.first .x; va0.y[i] = a01.first .y; va0.z[i] = a01.first .z; + va1.x[i] = a01.second.x; va1.y[i] = a01.second.y; va1.z[i] = a01.second.z; + vb0.x[i] = b01.first .x; vb0.y[i] = b01.first .y; vb0.z[i] = b01.first .z; + vb1.x[i] = b01.second.x; vb1.y[i] = b01.second.y; vb1.z[i] = b01.second.z; + vc0.x[i] = c01.first .x; vc0.y[i] = c01.first .y; vc0.z[i] = c01.first .z; + vc1.x[i] = c01.second.x; vc1.y[i] = c01.second.y; vc1.z[i] = c01.second.z; + } + new (this) TriangleMvMB(va0,va1,vb0,vb1,vc0,vc1,vgeomID,vprimID); + return allBounds; + } + + public: + Vec3vf v0; // 1st vertex of the triangles + Vec3vf v1; // 
2nd vertex of the triangles + Vec3vf v2; // 3rd vertex of the triangles + Vec3vf dv0; // difference vector between time steps t0 and t1 for first vertex + Vec3vf dv1; // difference vector between time steps t0 and t1 for second vertex + Vec3vf dv2; // difference vector between time steps t0 and t1 for third vertex + private: + vuint geomIDs; // geometry ID + vuint primIDs; // primitive ID + }; + + template + typename TriangleMvMB::Type TriangleMvMB::type; + + typedef TriangleMvMB<4> Triangle4vMB; +} diff --git a/thirdparty/embree/kernels/geometry/trianglev_mb_intersector.h b/thirdparty/embree/kernels/geometry/trianglev_mb_intersector.h new file mode 100644 index 000000000000..35a260d826e9 --- /dev/null +++ b/thirdparty/embree/kernels/geometry/trianglev_mb_intersector.h @@ -0,0 +1,211 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "triangle.h" +#include "intersector_epilog.h" + +namespace embree +{ + namespace isa + { + /*! Intersects M motion blur triangles with 1 ray */ + template + struct TriangleMvMBIntersector1Moeller + { + typedef TriangleMvMB Primitive; + typedef MoellerTrumboreIntersector1 Precalculations; + + /*! Intersect a ray with the M triangles and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(normal.trav_prims,1,1,1); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + pre.intersect(ray,v0,v1,v2,Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(shadow.trav_prims,1,1,1); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + return pre.intersect(ray,v0,v1,v2,Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! Intersects M motion blur triangles with K rays. */ + template + struct TriangleMvMBIntersectorKMoeller + { + typedef TriangleMvMB Primitive; + typedef MoellerTrumboreIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const TriangleMvMB& tri) + { + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,broadcast>(tri.dv0,i),broadcast>(tri.v0,i)); + const Vec3vf v1 = madd(time,broadcast>(tri.dv1,i),broadcast>(tri.v1,i)); + const Vec3vf v2 = madd(time,broadcast>(tri.dv2,i),broadcast>(tri.v2,i)); + pre.intersectK(valid_i,ray,v0,v1,v2,IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. 
*/ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const TriangleMvMB& tri) + { + vbool valid0 = valid_i; + + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,broadcast>(tri.dv0,i),broadcast>(tri.v0,i)); + const Vec3vf v1 = madd(time,broadcast>(tri.dv1,i),broadcast>(tri.v1,i)); + const Vec3vf v2 = madd(time,broadcast>(tri.dv2,i),broadcast>(tri.v2,i)); + pre.intersectK(valid0,ray,v0,v1,v2,OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(normal.trav_prims,1,1,1); + const Vec3vf time(ray.time()[k]); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + pre.intersect(ray,k,v0,v1,v2,Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. */ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(shadow.trav_prims,1,1,1); + const Vec3vf time(ray.time()[k]); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + return pre.intersect(ray,k,v0,v1,v2,Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + + /*! Intersects M motion blur triangles with 1 ray */ + template + struct TriangleMvMBIntersector1Pluecker + { + typedef TriangleMvMB Primitive; + typedef PlueckerIntersector1 Precalculations; + + /*! Intersect a ray with the M triangles and updates the hit. */ + static __forceinline void intersect(const Precalculations& pre, RayHit& ray, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(normal.trav_prims,1,1,1); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + pre.intersect(ray,v0,v1,v2,UVIdentity(),Intersect1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of M triangles. */ + static __forceinline bool occluded(const Precalculations& pre, Ray& ray, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(shadow.trav_prims,1,1,1); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + return pre.intersect(ray,v0,v1,v2,UVIdentity(),Occluded1EpilogM(ray,context,tri.geomID(),tri.primID())); + } + + static __forceinline bool pointQuery(PointQuery* query, PointQueryContext* context, const Primitive& tri) + { + return PrimitivePointQuery1::pointQuery(query, context, tri); + } + }; + + /*! Intersects M motion blur triangles with K rays. 
*/ + template + struct TriangleMvMBIntersectorKPluecker + { + typedef TriangleMvMB Primitive; + typedef PlueckerIntersectorK Precalculations; + + /*! Intersects K rays with M triangles. */ + static __forceinline void intersect(const vbool& valid_i, Precalculations& pre, RayHitK& ray, IntersectContext* context, const TriangleMvMB& tri) + { + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(normal.trav_prims,1,popcnt(valid_i),K); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,broadcast>(tri.dv0,i),broadcast>(tri.v0,i)); + const Vec3vf v1 = madd(time,broadcast>(tri.dv1,i),broadcast>(tri.v1,i)); + const Vec3vf v2 = madd(time,broadcast>(tri.dv2,i),broadcast>(tri.v2,i)); + pre.intersectK(valid_i,ray,v0,v1,v2,UVIdentity(),IntersectKEpilogM(ray,context,tri.geomID(),tri.primID(),i)); + } + } + + /*! Test for K rays if they are occluded by any of the M triangles. */ + static __forceinline vbool occluded(const vbool& valid_i, Precalculations& pre, RayK& ray, IntersectContext* context, const TriangleMvMB& tri) + { + vbool valid0 = valid_i; + + for (size_t i=0; i::max_size(); i++) + { + if (!tri.valid(i)) break; + STAT3(shadow.trav_prims,1,popcnt(valid0),K); + const Vec3vf time(ray.time()); + const Vec3vf v0 = madd(time,broadcast>(tri.dv0,i),broadcast>(tri.v0,i)); + const Vec3vf v1 = madd(time,broadcast>(tri.dv1,i),broadcast>(tri.v1,i)); + const Vec3vf v2 = madd(time,broadcast>(tri.dv2,i),broadcast>(tri.v2,i)); + pre.intersectK(valid0,ray,v0,v1,v2,UVIdentity(),OccludedKEpilogM(valid0,ray,context,tri.geomID(),tri.primID(),i)); + if (none(valid0)) break; + } + return !valid0; + } + + /*! Intersect a ray with M triangles and updates the hit. */ + static __forceinline void intersect(Precalculations& pre, RayHitK& ray, size_t k, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(normal.trav_prims,1,1,1); + const Vec3vf time(ray.time()[k]); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + pre.intersect(ray,k,v0,v1,v2,UVIdentity(),Intersect1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + + /*! Test if the ray is occluded by one of the M triangles. 
*/ + static __forceinline bool occluded(Precalculations& pre, RayK& ray, size_t k, IntersectContext* context, const TriangleMvMB& tri) + { + STAT3(shadow.trav_prims,1,1,1); + const Vec3vf time(ray.time()[k]); + const Vec3vf v0 = madd(time,Vec3vf(tri.dv0),Vec3vf(tri.v0)); + const Vec3vf v1 = madd(time,Vec3vf(tri.dv1),Vec3vf(tri.v1)); + const Vec3vf v2 = madd(time,Vec3vf(tri.dv2),Vec3vf(tri.v2)); + return pre.intersect(ray,k,v0,v1,v2,UVIdentity(),Occluded1KEpilogM(ray,k,context,tri.geomID(),tri.primID())); + } + }; + } +} diff --git a/thirdparty/embree/kernels/hash.h b/thirdparty/embree/kernels/hash.h new file mode 100644 index 000000000000..f11598b7ab89 --- /dev/null +++ b/thirdparty/embree/kernels/hash.h @@ -0,0 +1,5 @@ + +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#define RTC_HASH "69bd4c272f1ed608494f233ecfff3feec516880b" diff --git a/thirdparty/embree/kernels/subdiv/bezier_curve.h b/thirdparty/embree/kernels/subdiv/bezier_curve.h new file mode 100644 index 000000000000..c0e78820f8ec --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/bezier_curve.h @@ -0,0 +1,669 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "../common/scene_curves.h" + +namespace embree +{ + class BezierBasis + { + public: + + template + static __forceinline Vec4 eval(const T& u) + { + const T t1 = u; + const T t0 = 1.0f-t1; + const T B0 = t0 * t0 * t0; + const T B1 = 3.0f * t1 * (t0 * t0); + const T B2 = 3.0f * (t1 * t1) * t0; + const T B3 = t1 * t1 * t1; + return Vec4(B0,B1,B2,B3); + } + + template + static __forceinline Vec4 derivative(const T& u) + { + const T t1 = u; + const T t0 = 1.0f-t1; + const T B0 = -(t0*t0); + const T B1 = madd(-2.0f,t0*t1,t0*t0); + const T B2 = msub(+2.0f,t0*t1,t1*t1); + const T B3 = +(t1*t1); + return T(3.0f)*Vec4(B0,B1,B2,B3); + } + + template + static __forceinline Vec4 derivative2(const T& u) + { + const T t1 = u; + const T t0 = 1.0f-t1; + const T B0 = t0; + const T B1 = madd(-2.0f,t0,t1); + const T B2 = madd(-2.0f,t1,t0); + const T B3 = t1; + return T(6.0f)*Vec4(B0,B1,B2,B3); + } + }; + + struct PrecomputedBezierBasis + { + enum { N = 16 }; + public: + PrecomputedBezierBasis() {} + PrecomputedBezierBasis(int shift); + + /* basis for bezier evaluation */ + public: + float c0[N+1][N+1]; + float c1[N+1][N+1]; + float c2[N+1][N+1]; + float c3[N+1][N+1]; + + /* basis for bezier derivative evaluation */ + public: + float d0[N+1][N+1]; + float d1[N+1][N+1]; + float d2[N+1][N+1]; + float d3[N+1][N+1]; + }; + extern PrecomputedBezierBasis bezier_basis0; + extern PrecomputedBezierBasis bezier_basis1; + + + template + struct LinearBezierCurve + { + V v0,v1; + + __forceinline LinearBezierCurve () {} + + __forceinline LinearBezierCurve (const LinearBezierCurve& other) + : v0(other.v0), v1(other.v1) {} + + __forceinline LinearBezierCurve& operator= (const LinearBezierCurve& other) { + v0 = other.v0; v1 = other.v1; return *this; + } + + __forceinline LinearBezierCurve (const V& v0, const V& v1) + : v0(v0), v1(v1) {} + + __forceinline V begin() const { return v0; } + __forceinline V end () const { return v1; } + + bool hasRoot() const; + + friend embree_ostream operator<<(embree_ostream cout, const LinearBezierCurve& a) { + return cout << "LinearBezierCurve (" << a.v0 << ", " << a.v1 << ")"; + } + }; + + template<> __forceinline bool LinearBezierCurve::hasRoot() const { + return numRoots(v0,v1); + } + + template + struct QuadraticBezierCurve + { + V v0,v1,v2; + 
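+ // Control points of the degree-2 curve; produced for example as the derivative of a CubicBezierCurve (see CubicBezierCurve::derivative further below).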
+ __forceinline QuadraticBezierCurve () {} + + __forceinline QuadraticBezierCurve (const QuadraticBezierCurve& other) + : v0(other.v0), v1(other.v1), v2(other.v2) {} + + __forceinline QuadraticBezierCurve& operator= (const QuadraticBezierCurve& other) { + v0 = other.v0; v1 = other.v1; v2 = other.v2; return *this; + } + + __forceinline QuadraticBezierCurve (const V& v0, const V& v1, const V& v2) + : v0(v0), v1(v1), v2(v2) {} + + __forceinline V begin() const { return v0; } + __forceinline V end () const { return v2; } + + __forceinline V interval() const { + return merge(v0,v1,v2); + } + + __forceinline BBox bounds() const { + return merge(BBox(v0),BBox(v1),BBox(v2)); + } + + friend embree_ostream operator<<(embree_ostream cout, const QuadraticBezierCurve& a) { + return cout << "QuadraticBezierCurve ( (" << a.u.lower << ", " << a.u.upper << "), " << a.v0 << ", " << a.v1 << ", " << a.v2 << ")"; + } + }; + + + typedef QuadraticBezierCurve QuadraticBezierCurve1f; + typedef QuadraticBezierCurve QuadraticBezierCurve2fa; + typedef QuadraticBezierCurve QuadraticBezierCurve3fa; + + template + struct CubicBezierCurve + { + Vertex v0,v1,v2,v3; + + __forceinline CubicBezierCurve() {} + + template + __forceinline CubicBezierCurve (const CubicBezierCurve& other) + : v0(other.v0), v1(other.v1), v2(other.v2), v3(other.v3) {} + + __forceinline CubicBezierCurve& operator= (const CubicBezierCurve& other) { + v0 = other.v0; v1 = other.v1; v2 = other.v2; v3 = other.v3; return *this; + } + + __forceinline CubicBezierCurve(const Vertex& v0, const Vertex& v1, const Vertex& v2, const Vertex& v3) + : v0(v0), v1(v1), v2(v2), v3(v3) {} + + __forceinline Vertex begin() const { + return v0; + } + + __forceinline Vertex end() const { + return v3; + } + + __forceinline Vertex center() const { + return 0.25f*(v0+v1+v2+v3); + } + + __forceinline Vertex begin_direction() const { + return v1-v0; + } + + __forceinline Vertex end_direction() const { + return v3-v2; + } + + __forceinline CubicBezierCurve xfm(const Vertex& dx) const { + return CubicBezierCurve(dot(v0,dx),dot(v1,dx),dot(v2,dx),dot(v3,dx)); + } + + __forceinline CubicBezierCurve vxfm(const Vertex& dx) const { + return CubicBezierCurve(dot(v0,dx),dot(v1,dx),dot(v2,dx),dot(v3,dx)); + } + + __forceinline CubicBezierCurve xfm(const Vertex& dx, const Vertex& p) const { + return CubicBezierCurve(dot(v0-p,dx),dot(v1-p,dx),dot(v2-p,dx),dot(v3-p,dx)); + } + + __forceinline CubicBezierCurve xfm(const LinearSpace3fa& space) const + { + const Vec3fa q0 = xfmVector(space,v0); + const Vec3fa q1 = xfmVector(space,v1); + const Vec3fa q2 = xfmVector(space,v2); + const Vec3fa q3 = xfmVector(space,v3); + return CubicBezierCurve(q0,q1,q2,q3); + } + + __forceinline CubicBezierCurve xfm(const LinearSpace3fa& space, const Vec3fa& p) const + { + const Vec3fa q0 = xfmVector(space,v0-p); + const Vec3fa q1 = xfmVector(space,v1-p); + const Vec3fa q2 = xfmVector(space,v2-p); + const Vec3fa q3 = xfmVector(space,v3-p); + return CubicBezierCurve(q0,q1,q2,q3); + } + + __forceinline CubicBezierCurve xfm_pr(const LinearSpace3fa& space, const Vec3fa& p) const + { + const Vec3ff q0(xfmVector(space,(Vec3fa)v0-p), v0.w); + const Vec3ff q1(xfmVector(space,(Vec3fa)v1-p), v1.w); + const Vec3ff q2(xfmVector(space,(Vec3fa)v2-p), v2.w); + const Vec3ff q3(xfmVector(space,(Vec3fa)v3-p), v3.w); + return CubicBezierCurve(q0,q1,q2,q3); + } + + __forceinline CubicBezierCurve xfm(const LinearSpace3fa& space, const Vec3fa& p, const float s) const + { + const Vec3fa q0 = xfmVector(space,s*(v0-p)); + const Vec3fa q1 = 
xfmVector(space,s*(v1-p)); + const Vec3fa q2 = xfmVector(space,s*(v2-p)); + const Vec3fa q3 = xfmVector(space,s*(v3-p)); + return CubicBezierCurve(q0,q1,q2,q3); + } + + __forceinline int maxRoots() const; + + __forceinline BBox bounds() const { + return merge(BBox(v0),BBox(v1),BBox(v2),BBox(v3)); + } + + __forceinline friend CubicBezierCurve operator +( const CubicBezierCurve& a, const CubicBezierCurve& b ) { + return CubicBezierCurve(a.v0+b.v0,a.v1+b.v1,a.v2+b.v2,a.v3+b.v3); + } + + __forceinline friend CubicBezierCurve operator -( const CubicBezierCurve& a, const CubicBezierCurve& b ) { + return CubicBezierCurve(a.v0-b.v0,a.v1-b.v1,a.v2-b.v2,a.v3-b.v3); + } + + __forceinline friend CubicBezierCurve operator -( const CubicBezierCurve& a, const Vertex& b ) { + return CubicBezierCurve(a.v0-b,a.v1-b,a.v2-b,a.v3-b); + } + + __forceinline friend CubicBezierCurve operator *( const Vertex& a, const CubicBezierCurve& b ) { + return CubicBezierCurve(a*b.v0,a*b.v1,a*b.v2,a*b.v3); + } + + __forceinline friend CubicBezierCurve cmadd( const Vertex& a, const CubicBezierCurve& b, const CubicBezierCurve& c) { + return CubicBezierCurve(madd(a,b.v0,c.v0),madd(a,b.v1,c.v1),madd(a,b.v2,c.v2),madd(a,b.v3,c.v3)); + } + + __forceinline friend CubicBezierCurve clerp ( const CubicBezierCurve& a, const CubicBezierCurve& b, const Vertex& t ) { + return cmadd((Vertex(1.0f)-t),a,t*b); + } + + __forceinline friend CubicBezierCurve merge ( const CubicBezierCurve& a, const CubicBezierCurve& b ) { + return CubicBezierCurve(merge(a.v0,b.v0),merge(a.v1,b.v1),merge(a.v2,b.v2),merge(a.v3,b.v3)); + } + + __forceinline void split(CubicBezierCurve& left, CubicBezierCurve& right, const float t = 0.5f) const + { + const Vertex p00 = v0; + const Vertex p01 = v1; + const Vertex p02 = v2; + const Vertex p03 = v3; + + const Vertex p10 = lerp(p00,p01,t); + const Vertex p11 = lerp(p01,p02,t); + const Vertex p12 = lerp(p02,p03,t); + const Vertex p20 = lerp(p10,p11,t); + const Vertex p21 = lerp(p11,p12,t); + const Vertex p30 = lerp(p20,p21,t); + + new (&left ) CubicBezierCurve(p00,p10,p20,p30); + new (&right) CubicBezierCurve(p30,p21,p12,p03); + } + + __forceinline CubicBezierCurve split() const + { + const float u0 = 0.0f, u1 = 1.0f; + const float dscale = (u1-u0)*(1.0f/(3.0f*(VSIZEX-1))); + const vfloatx vu0 = lerp(u0,u1,vfloatx(step)*(1.0f/(VSIZEX-1))); + Vec2vfx P0, dP0du; evalN(vu0,P0,dP0du); dP0du = dP0du * Vec2vfx(dscale); + const Vec2vfx P3 = shift_right_1(P0); + const Vec2vfx dP3du = shift_right_1(dP0du); + const Vec2vfx P1 = P0 + dP0du; + const Vec2vfx P2 = P3 - dP3du; + return CubicBezierCurve(P0,P1,P2,P3); + } + + __forceinline CubicBezierCurve split(const BBox1f& u) const + { + const float u0 = u.lower, u1 = u.upper; + const float dscale = (u1-u0)*(1.0f/(3.0f*(VSIZEX-1))); + const vfloatx vu0 = lerp(u0,u1,vfloatx(step)*(1.0f/(VSIZEX-1))); + Vec2vfx P0, dP0du; evalN(vu0,P0,dP0du); dP0du = dP0du * Vec2vfx(dscale); + const Vec2vfx P3 = shift_right_1(P0); + const Vec2vfx dP3du = shift_right_1(dP0du); + const Vec2vfx P1 = P0 + dP0du; + const Vec2vfx P2 = P3 - dP3du; + return CubicBezierCurve(P0,P1,P2,P3); + } + + __forceinline void eval(float t, Vertex& p, Vertex& dp) const + { + const Vertex p00 = v0; + const Vertex p01 = v1; + const Vertex p02 = v2; + const Vertex p03 = v3; + + const Vertex p10 = lerp(p00,p01,t); + const Vertex p11 = lerp(p01,p02,t); + const Vertex p12 = lerp(p02,p03,t); + const Vertex p20 = lerp(p10,p11,t); + const Vertex p21 = lerp(p11,p12,t); + const Vertex p30 = lerp(p20,p21,t); + + p = p30; + dp = 
Vertex(3.0f)*(p21-p20); + } + +#if 0 + __forceinline Vertex eval(float t) const + { + const Vertex p00 = v0; + const Vertex p01 = v1; + const Vertex p02 = v2; + const Vertex p03 = v3; + + const Vertex p10 = lerp(p00,p01,t); + const Vertex p11 = lerp(p01,p02,t); + const Vertex p12 = lerp(p02,p03,t); + const Vertex p20 = lerp(p10,p11,t); + const Vertex p21 = lerp(p11,p12,t); + const Vertex p30 = lerp(p20,p21,t); + + return p30; + } +#else + __forceinline Vertex eval(const float t) const + { + const Vec4 b = BezierBasis::eval(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } +#endif + + __forceinline Vertex eval_dt(float t) const + { + const Vertex p00 = v1-v0; + const Vertex p01 = v2-v1; + const Vertex p02 = v3-v2; + const Vertex p10 = lerp(p00,p01,t); + const Vertex p11 = lerp(p01,p02,t); + const Vertex p20 = lerp(p10,p11,t); + return Vertex(3.0f)*p20; + } + + __forceinline Vertex eval_du(const float t) const + { + const Vec4 b = BezierBasis::derivative(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline Vertex eval_dudu(const float t) const + { + const Vec4 b = BezierBasis::derivative2(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline void evalN(const vfloatx& t, Vec2vfx& p, Vec2vfx& dp) const + { + const Vec2vfx p00 = v0; + const Vec2vfx p01 = v1; + const Vec2vfx p02 = v2; + const Vec2vfx p03 = v3; + + const Vec2vfx p10 = lerp(p00,p01,t); + const Vec2vfx p11 = lerp(p01,p02,t); + const Vec2vfx p12 = lerp(p02,p03,t); + + const Vec2vfx p20 = lerp(p10,p11,t); + const Vec2vfx p21 = lerp(p11,p12,t); + + const Vec2vfx p30 = lerp(p20,p21,t); + + p = p30; + dp = vfloatx(3.0f)*(p21-p20); + } + + __forceinline void eval(const float t, Vertex& p, Vertex& dp, Vertex& ddp) const + { + const Vertex p00 = v0; + const Vertex p01 = v1; + const Vertex p02 = v2; + const Vertex p03 = v3; + const Vertex p10 = lerp(p00,p01,t); + const Vertex p11 = lerp(p01,p02,t); + const Vertex p12 = lerp(p02,p03,t); + const Vertex p20 = lerp(p10,p11,t); + const Vertex p21 = lerp(p11,p12,t); + const Vertex p30 = lerp(p20,p21,t); + p = p30; + dp = 3.0f*(p21-p20); + ddp = eval_dudu(t); + } + + __forceinline CubicBezierCurve clip(const Interval1f& u1) const + { + Vertex f0,df0; eval(u1.lower,f0,df0); + Vertex f1,df1; eval(u1.upper,f1,df1); + float s = u1.upper-u1.lower; + return CubicBezierCurve(f0,f0+s*(1.0f/3.0f)*df0,f1-s*(1.0f/3.0f)*df1,f1); + } + + __forceinline QuadraticBezierCurve derivative() const + { + const Vertex q0 = 3.0f*(v1-v0); + const Vertex q1 = 3.0f*(v2-v1); + const Vertex q2 = 3.0f*(v3-v2); + return QuadraticBezierCurve(q0,q1,q2); + } + + __forceinline BBox derivative_bounds(const Interval1f& u1) const + { + Vertex f0,df0; eval(u1.lower,f0,df0); + Vertex f3,df3; eval(u1.upper,f3,df3); + const float s = u1.upper-u1.lower; + const Vertex f1 = f0+s*(1.0f/3.0f)*df0; + const Vertex f2 = f3-s*(1.0f/3.0f)*df3; + const Vertex q0 = s*df0; + const Vertex q1 = 3.0f*(f2-f1); + const Vertex q2 = s*df3; + return merge(BBox(q0),BBox(q1),BBox(q2)); + } + + template + __forceinline Vec4vf veval(const vfloat& t) const + { + const Vec4vf b = BezierBasis::eval(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf veval_du(const vfloat& t) const + { + const Vec4vf b = BezierBasis::derivative(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf veval_dudu(const vfloat& t) const + { + 
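+      // Second derivative with respect to the curve parameter: the four control points
+      // are weighted by the second derivatives of the cubic Bernstein basis polynomials.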
const Vec4vf b = BezierBasis::derivative2(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline void veval(const vfloat& t, Vec4vf& p, Vec4vf& dp) const + { + const Vec4vf p00 = v0; + const Vec4vf p01 = v1; + const Vec4vf p02 = v2; + const Vec4vf p03 = v3; + + const Vec4vf p10 = lerp(p00,p01,t); + const Vec4vf p11 = lerp(p01,p02,t); + const Vec4vf p12 = lerp(p02,p03,t); + const Vec4vf p20 = lerp(p10,p11,t); + const Vec4vf p21 = lerp(p11,p12,t); + const Vec4vf p30 = lerp(p20,p21,t); + + p = p30; + dp = vfloat(3.0f)*(p21-p20); + } + + template> + __forceinline Vec eval0(const int ofs, const int size) const + { + assert(size <= PrecomputedBezierBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bezier_basis0.c0[size][ofs]), Vec(v0), + madd(vfloat::loadu(&bezier_basis0.c1[size][ofs]), Vec(v1), + madd(vfloat::loadu(&bezier_basis0.c2[size][ofs]), Vec(v2), + vfloat::loadu(&bezier_basis0.c3[size][ofs]) * Vec(v3)))); + } + + template> + __forceinline Vec eval1(const int ofs, const int size) const + { + assert(size <= PrecomputedBezierBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bezier_basis1.c0[size][ofs]), Vec(v0), + madd(vfloat::loadu(&bezier_basis1.c1[size][ofs]), Vec(v1), + madd(vfloat::loadu(&bezier_basis1.c2[size][ofs]), Vec(v2), + vfloat::loadu(&bezier_basis1.c3[size][ofs]) * Vec(v3)))); + } + + template> + __forceinline Vec derivative0(const int ofs, const int size) const + { + assert(size <= PrecomputedBezierBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bezier_basis0.d0[size][ofs]), Vec(v0), + madd(vfloat::loadu(&bezier_basis0.d1[size][ofs]), Vec(v1), + madd(vfloat::loadu(&bezier_basis0.d2[size][ofs]), Vec(v2), + vfloat::loadu(&bezier_basis0.d3[size][ofs]) * Vec(v3)))); + } + + template> + __forceinline Vec derivative1(const int ofs, const int size) const + { + assert(size <= PrecomputedBezierBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bezier_basis1.d0[size][ofs]), Vec(v0), + madd(vfloat::loadu(&bezier_basis1.d1[size][ofs]), Vec(v1), + madd(vfloat::loadu(&bezier_basis1.d2[size][ofs]), Vec(v2), + vfloat::loadu(&bezier_basis1.d3[size][ofs]) * Vec(v3)))); + } + + /* calculates bounds of bezier curve geometry */ + __forceinline BBox3fa accurateBounds() const + { + const int N = 7; + const float scale = 1.0f/(3.0f*(N-1)); + Vec3vfx pl(pos_inf), pu(neg_inf); + for (int i=0; i<=N; i+=VSIZEX) + { + vintx vi = vintx(i)+vintx(step); + vboolx valid = vi <= vintx(N); + const Vec3vfx p = eval0>(i,N); + const Vec3vfx dp = derivative0>(i,N); + const Vec3vfx pm = p-Vec3vfx(scale)*select(vi!=vintx(0),dp,Vec3vfx(zero)); + const Vec3vfx pp = p+Vec3vfx(scale)*select(vi!=vintx(N),dp,Vec3vfx(zero)); + pl = select(valid,min(pl,p,pm,pp),pl); // FIXME: use masked min + pu = select(valid,max(pu,p,pm,pp),pu); // FIXME: use masked min + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + return BBox3fa(lower,upper); + } + + /* calculates bounds of bezier curve geometry */ + __forceinline BBox3fa accurateRoundBounds() const + { + const int N = 7; + const float scale = 1.0f/(3.0f*(N-1)); + Vec4vfx pl(pos_inf), pu(neg_inf); + for (int i=0; i<=N; i+=VSIZEX) + { + vintx vi = vintx(i)+vintx(step); + vboolx valid = vi <= vintx(N); + const Vec4vfx p = eval0(i,N); + const Vec4vfx dp = derivative0(i,N); + const Vec4vfx pm = p-Vec4vfx(scale)*select(vi!=vintx(0),dp,Vec4vfx(zero)); + const 
Vec4vfx pp = p+Vec4vfx(scale)*select(vi!=vintx(N),dp,Vec4vfx(zero)); + pl = select(valid,min(pl,p,pm,pp),pl); // FIXME: use masked min + pu = select(valid,max(pu,p,pm,pp),pu); // FIXME: use masked min + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + const float r_min = reduce_min(pl.w); + const float r_max = reduce_max(pu.w); + const Vec3fa upper_r = Vec3fa(max(abs(r_min),abs(r_max))); + return enlarge(BBox3fa(lower,upper),upper_r); + } + + /* calculates bounds when tessellated into N line segments */ + __forceinline BBox3fa accurateFlatBounds(int N) const + { + if (likely(N == 4)) + { + const Vec4vf4 pi = eval0<4>(0,4); + const Vec3fa lower(reduce_min(pi.x),reduce_min(pi.y),reduce_min(pi.z)); + const Vec3fa upper(reduce_max(pi.x),reduce_max(pi.y),reduce_max(pi.z)); + const Vec3fa upper_r = Vec3fa(reduce_max(abs(pi.w))); + return enlarge(BBox3fa(min(lower,v3),max(upper,v3)),max(upper_r,Vec3fa(abs(v3.w)))); + } + else + { + Vec3vfx pl(pos_inf), pu(neg_inf); vfloatx ru(0.0f); + for (int i=0; i(i,N); + + pl.x = select(valid,min(pl.x,pi.x),pl.x); // FIXME: use masked min + pl.y = select(valid,min(pl.y,pi.y),pl.y); + pl.z = select(valid,min(pl.z,pi.z),pl.z); + + pu.x = select(valid,max(pu.x,pi.x),pu.x); // FIXME: use masked min + pu.y = select(valid,max(pu.y,pi.y),pu.y); + pu.z = select(valid,max(pu.z,pi.z),pu.z); + + ru = select(valid,max(ru,abs(pi.w)),ru); + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + const Vec3fa upper_r(reduce_max(ru)); + return enlarge(BBox3fa(min(lower,v3),max(upper,v3)),max(upper_r,Vec3fa(abs(v3.w)))); + } + } + + friend __forceinline embree_ostream operator<<(embree_ostream cout, const CubicBezierCurve& curve) { + return cout << "CubicBezierCurve { v0 = " << curve.v0 << ", v1 = " << curve.v1 << ", v2 = " << curve.v2 << ", v3 = " << curve.v3 << " }"; + } + }; + +#if defined(__AVX__) + template<> + __forceinline CubicBezierCurve CubicBezierCurve::clip(const Interval1f& u1) const + { + const vfloat8 p00 = vfloat8(v0); + const vfloat8 p01 = vfloat8(v1); + const vfloat8 p02 = vfloat8(v2); + const vfloat8 p03 = vfloat8(v3); + + const vfloat8 t(vfloat4(u1.lower),vfloat4(u1.upper)); + const vfloat8 p10 = lerp(p00,p01,t); + const vfloat8 p11 = lerp(p01,p02,t); + const vfloat8 p12 = lerp(p02,p03,t); + const vfloat8 p20 = lerp(p10,p11,t); + const vfloat8 p21 = lerp(p11,p12,t); + const vfloat8 p30 = lerp(p20,p21,t); + + const vfloat8 f01 = p30; + const vfloat8 df01 = vfloat8(3.0f)*(p21-p20); + + const vfloat4 f0 = extract4<0>(f01), f1 = extract4<1>(f01); + const vfloat4 df0 = extract4<0>(df01), df1 = extract4<1>(df01); + const float s = u1.upper-u1.lower; + return CubicBezierCurve(f0,f0+s*(1.0f/3.0f)*df0,f1-s*(1.0f/3.0f)*df1,f1); + } +#endif + + template using BezierCurveT = CubicBezierCurve; + + typedef CubicBezierCurve CubicBezierCurve1f; + typedef CubicBezierCurve CubicBezierCurve2fa; + typedef CubicBezierCurve CubicBezierCurve3fa; + typedef CubicBezierCurve BezierCurve3fa; + + template<> __forceinline int CubicBezierCurve::maxRoots() const + { + float eps = 1E-4f; + bool neg0 = v0 <= 0.0f; bool zero0 = fabs(v0) < eps; + bool neg1 = v1 <= 0.0f; bool zero1 = fabs(v1) < eps; + bool neg2 = v2 <= 0.0f; bool zero2 = fabs(v2) < eps; + bool neg3 = v3 <= 0.0f; bool zero3 = fabs(v3) < eps; + return (neg0 != neg1 || zero0) + (neg1 != neg2 || zero1) + (neg2 != neg3 || zero2 || 
zero3); + } + + template<> __forceinline int CubicBezierCurve::maxRoots() const { + return numRoots(v0,v1) + numRoots(v1,v2) + numRoots(v2,v3); + } + + __forceinline CubicBezierCurve enlargeRadiusToMinWidth(const IntersectContext* context, const CurveGeometry* geom, const Vec3fa& ray_org, const CubicBezierCurve& curve) + { + return CubicBezierCurve(enlargeRadiusToMinWidth(context,geom,ray_org,curve.v0), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v1), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v2), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v3)); + } +} diff --git a/thirdparty/embree/kernels/subdiv/bezier_patch.h b/thirdparty/embree/kernels/subdiv/bezier_patch.h new file mode 100644 index 000000000000..d87ed41ccb1f --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/bezier_patch.h @@ -0,0 +1,372 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_patch.h" +#include "bezier_curve.h" + +namespace embree +{ + template + static __forceinline T deCasteljau(const S& uu, const T& v0, const T& v1, const T& v2, const T& v3) + { + const T v0_1 = lerp(v0,v1,uu); + const T v1_1 = lerp(v1,v2,uu); + const T v2_1 = lerp(v2,v3,uu); + const T v0_2 = lerp(v0_1,v1_1,uu); + const T v1_2 = lerp(v1_1,v2_1,uu); + const T v0_3 = lerp(v0_2,v1_2,uu); + return v0_3; + } + + template + static __forceinline T deCasteljau_tangent(const S& uu, const T& v0, const T& v1, const T& v2, const T& v3) + { + const T v0_1 = lerp(v0,v1,uu); + const T v1_1 = lerp(v1,v2,uu); + const T v2_1 = lerp(v2,v3,uu); + const T v0_2 = lerp(v0_1,v1_1,uu); + const T v1_2 = lerp(v1_1,v2_1,uu); + return S(3.0f)*(v1_2-v0_2); + } + + template + __forceinline Vertex computeInnerBezierControlPoint(const Vertex v[4][4], const size_t y, const size_t x) { + return 1.0f / 36.0f * (16.0f * v[y][x] + 4.0f * (v[y-1][x] + v[y+1][x] + v[y][x-1] + v[y][x+1]) + (v[y-1][x-1] + v[y+1][x+1] + v[y-1][x+1] + v[y+1][x-1])); + } + + template + __forceinline Vertex computeTopEdgeBezierControlPoint(const Vertex v[4][4], const size_t y, const size_t x) { + return 1.0f / 18.0f * (8.0f * v[y][x] + 4.0f * v[y-1][x] + 2.0f * (v[y][x-1] + v[y][x+1]) + (v[y-1][x-1] + v[y-1][x+1])); + } + + template + __forceinline Vertex computeBottomEdgeBezierControlPoint(const Vertex v[4][4], const size_t y, const size_t x) { + return 1.0f / 18.0f * (8.0f * v[y][x] + 4.0f * v[y+1][x] + 2.0f * (v[y][x-1] + v[y][x+1]) + v[y+1][x-1] + v[y+1][x+1]); + } + + template + __forceinline Vertex computeLeftEdgeBezierControlPoint(const Vertex v[4][4], const size_t y, const size_t x) { + return 1.0f / 18.0f * (8.0f * v[y][x] + 4.0f * v[y][x-1] + 2.0f * (v[y-1][x] + v[y+1][x]) + v[y-1][x-1] + v[y+1][x-1]); + } + + template + __forceinline Vertex computeRightEdgeBezierControlPoint(const Vertex v[4][4], const size_t y, const size_t x) { + return 1.0f / 18.0f * (8.0f * v[y][x] + 4.0f * v[y][x+1] + 2.0f * (v[y-1][x] + v[y+1][x]) + v[y-1][x+1] + v[y+1][x+1]); + } + + template + __forceinline Vertex computeCornerBezierControlPoint(const Vertex v[4][4], const size_t y, const size_t x, const ssize_t delta_y, const ssize_t delta_x) + { + return 1.0f / 9.0f * (4.0f * v[y][x] + 2.0f * (v[y+delta_y][x] + v[y][x+delta_x]) + v[y+delta_y][x+delta_x]); + } + + template + class __aligned(64) BezierPatchT + { + public: + Vertex matrix[4][4]; + + public: + + __forceinline BezierPatchT() {} + + __forceinline BezierPatchT (const HalfEdge* edge, const char* vertices, size_t stride); + + __forceinline BezierPatchT(const 
CatmullClarkPatchT& patch); + + __forceinline BezierPatchT(const CatmullClarkPatchT& patch, + const BezierCurveT* border0, + const BezierCurveT* border1, + const BezierCurveT* border2, + const BezierCurveT* border3); + + __forceinline BezierPatchT(const BSplinePatchT& source) + { + /* compute inner bezier control points */ + matrix[0][0] = computeInnerBezierControlPoint(source.v,1,1); + matrix[0][3] = computeInnerBezierControlPoint(source.v,1,2); + matrix[3][3] = computeInnerBezierControlPoint(source.v,2,2); + matrix[3][0] = computeInnerBezierControlPoint(source.v,2,1); + + /* compute top edge control points */ + matrix[0][1] = computeRightEdgeBezierControlPoint(source.v,1,1); + matrix[0][2] = computeLeftEdgeBezierControlPoint(source.v,1,2); + + /* compute buttom edge control points */ + matrix[3][1] = computeRightEdgeBezierControlPoint(source.v,2,1); + matrix[3][2] = computeLeftEdgeBezierControlPoint(source.v,2,2); + + /* compute left edge control points */ + matrix[1][0] = computeBottomEdgeBezierControlPoint(source.v,1,1); + matrix[2][0] = computeTopEdgeBezierControlPoint(source.v,2,1); + + /* compute right edge control points */ + matrix[1][3] = computeBottomEdgeBezierControlPoint(source.v,1,2); + matrix[2][3] = computeTopEdgeBezierControlPoint(source.v,2,2); + + /* compute corner control points */ + matrix[1][1] = computeCornerBezierControlPoint(source.v,1,1, 1, 1); + matrix[1][2] = computeCornerBezierControlPoint(source.v,1,2, 1,-1); + matrix[2][2] = computeCornerBezierControlPoint(source.v,2,2,-1,-1); + matrix[2][1] = computeCornerBezierControlPoint(source.v,2,1,-1, 1); + } + + static __forceinline Vertex_t bilinear(const Vec4f Bu, const Vertex matrix[4][4], const Vec4f Bv) + { + const Vertex_t M0 = madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))); + const Vertex_t M1 = madd(Bu.x,matrix[1][0],madd(Bu.y,matrix[1][1],madd(Bu.z,matrix[1][2],Bu.w * matrix[1][3]))); + const Vertex_t M2 = madd(Bu.x,matrix[2][0],madd(Bu.y,matrix[2][1],madd(Bu.z,matrix[2][2],Bu.w * matrix[2][3]))); + const Vertex_t M3 = madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3]))); + return madd(Bv.x,M0,madd(Bv.y,M1,madd(Bv.z,M2,Bv.w*M3))); + } + + static __forceinline Vertex_t eval(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vec4f Bu = BezierBasis::eval(uu); + const Vec4f Bv = BezierBasis::eval(vv); + return bilinear(Bu,matrix,Bv); + } + + static __forceinline Vertex_t eval_du(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vec4f Bu = BezierBasis::derivative(uu); + const Vec4f Bv = BezierBasis::eval(vv); + return bilinear(Bu,matrix,Bv); + } + + static __forceinline Vertex_t eval_dv(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vec4f Bu = BezierBasis::eval(uu); + const Vec4f Bv = BezierBasis::derivative(vv); + return bilinear(Bu,matrix,Bv); + } + + static __forceinline Vertex_t eval_dudu(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vec4f Bu = BezierBasis::derivative2(uu); + const Vec4f Bv = BezierBasis::eval(vv); + return bilinear(Bu,matrix,Bv); + } + + static __forceinline Vertex_t eval_dvdv(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vec4f Bu = BezierBasis::eval(uu); + const Vec4f Bv = BezierBasis::derivative2(vv); + return bilinear(Bu,matrix,Bv); + } + + static __forceinline Vertex_t eval_dudv(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vec4f Bu = BezierBasis::derivative(uu); + 
const Vec4f Bv = BezierBasis::derivative(vv); + return bilinear(Bu,matrix,Bv); + } + + static __forceinline Vertex_t normal(const Vertex matrix[4][4], const float uu, const float vv) + { + const Vertex_t dPdu = eval_du(matrix,uu,vv); + const Vertex_t dPdv = eval_dv(matrix,uu,vv); + return cross(dPdu,dPdv); + } + + __forceinline Vertex_t normal(const float uu, const float vv) + { + const Vertex_t dPdu = eval_du(matrix,uu,vv); + const Vertex_t dPdv = eval_dv(matrix,uu,vv); + return cross(dPdu,dPdv); + } + + __forceinline Vertex_t eval(const float uu, const float vv) const { + return eval(matrix,uu,vv); + } + + __forceinline Vertex_t eval_du(const float uu, const float vv) const { + return eval_du(matrix,uu,vv); + } + + __forceinline Vertex_t eval_dv(const float uu, const float vv) const { + return eval_dv(matrix,uu,vv); + } + + __forceinline Vertex_t eval_dudu(const float uu, const float vv) const { + return eval_dudu(matrix,uu,vv); + } + + __forceinline Vertex_t eval_dvdv(const float uu, const float vv) const { + return eval_dvdv(matrix,uu,vv); + } + + __forceinline Vertex_t eval_dudv(const float uu, const float vv) const { + return eval_dudv(matrix,uu,vv); + } + + __forceinline void eval(const float u, const float v, Vertex* P, Vertex* dPdu, Vertex* dPdv, Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv, const float dscale = 1.0f) const + { + if (P) { + *P = eval(u,v); + } + if (dPdu) { + assert(dPdu); *dPdu = eval_du(u,v)*dscale; + assert(dPdv); *dPdv = eval_dv(u,v)*dscale; + } + if (ddPdudu) { + assert(ddPdudu); *ddPdudu = eval_dudu(u,v)*sqr(dscale); + assert(ddPdvdv); *ddPdvdv = eval_dvdv(u,v)*sqr(dscale); + assert(ddPdudv); *ddPdudv = eval_dudv(u,v)*sqr(dscale); + } + } + + template + __forceinline vfloat eval(const size_t i, const vfloat& uu, const vfloat& vv, const Vec4& u_n, const Vec4& v_n) const + { + const vfloat curve0_x = v_n[0] * vfloat(matrix[0][0][i]) + v_n[1] * vfloat(matrix[1][0][i]) + v_n[2] * vfloat(matrix[2][0][i]) + v_n[3] * vfloat(matrix[3][0][i]); + const vfloat curve1_x = v_n[0] * vfloat(matrix[0][1][i]) + v_n[1] * vfloat(matrix[1][1][i]) + v_n[2] * vfloat(matrix[2][1][i]) + v_n[3] * vfloat(matrix[3][1][i]); + const vfloat curve2_x = v_n[0] * vfloat(matrix[0][2][i]) + v_n[1] * vfloat(matrix[1][2][i]) + v_n[2] * vfloat(matrix[2][2][i]) + v_n[3] * vfloat(matrix[3][2][i]); + const vfloat curve3_x = v_n[0] * vfloat(matrix[0][3][i]) + v_n[1] * vfloat(matrix[1][3][i]) + v_n[2] * vfloat(matrix[2][3][i]) + v_n[3] * vfloat(matrix[3][3][i]); + return u_n[0] * curve0_x + u_n[1] * curve1_x + u_n[2] * curve2_x + u_n[3] * curve3_x; + } + + template + __forceinline void eval(const vbool& valid, const vfloat& uu, const vfloat& vv, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, + const float dscale, const size_t dstride, const size_t N) const + { + if (P) { + const Vec4 u_n = BezierBasis::eval(uu); + const Vec4 v_n = BezierBasis::eval(vv); + for (size_t i=0; i u_n = BezierBasis::derivative(uu); + const Vec4 v_n = BezierBasis::eval(vv); + for (size_t i=0; i u_n = BezierBasis::eval(uu); + const Vec4 v_n = BezierBasis::derivative(vv); + for (size_t i=0; i u_n = BezierBasis::derivative2(uu); + const Vec4 v_n = BezierBasis::eval(vv); + for (size_t i=0; i u_n = BezierBasis::eval(uu); + const Vec4 v_n = BezierBasis::derivative2(vv); + for (size_t i=0; i u_n = BezierBasis::derivative(uu); + const Vec4 v_n = BezierBasis::derivative(vv); + for (size_t i=0; i + static __forceinline Vec3 eval(const Vertex matrix[4][4], const T& uu, const T& vv) + { + 
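+      // Evaluates the patch for many (u,v) parameter pairs at once (T is typically a SIMD float type):
+      // the cubic Bernstein weights are expanded explicitly below and combined as a tensor product
+      // over the 4x4 control-point matrix.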
const T one_minus_uu = 1.0f - uu; + const T one_minus_vv = 1.0f - vv; + + const T B0_u = one_minus_uu * one_minus_uu * one_minus_uu; + const T B0_v = one_minus_vv * one_minus_vv * one_minus_vv; + const T B1_u = 3.0f * (one_minus_uu * uu * one_minus_uu); + const T B1_v = 3.0f * (one_minus_vv * vv * one_minus_vv); + const T B2_u = 3.0f * (uu * one_minus_uu * uu); + const T B2_v = 3.0f * (vv * one_minus_vv * vv); + const T B3_u = uu * uu * uu; + const T B3_v = vv * vv * vv; + + const T x = + madd(B0_v,madd(B0_u,matrix[0][0].x,madd(B1_u,matrix[0][1].x,madd(B2_u,matrix[0][2].x,B3_u*matrix[0][3].x))), + madd(B1_v,madd(B0_u,matrix[1][0].x,madd(B1_u,matrix[1][1].x,madd(B2_u,matrix[1][2].x,B3_u*matrix[1][3].x))), + madd(B2_v,madd(B0_u,matrix[2][0].x,madd(B1_u,matrix[2][1].x,madd(B2_u,matrix[2][2].x,B3_u*matrix[2][3].x))), + B3_v*madd(B0_u,matrix[3][0].x,madd(B1_u,matrix[3][1].x,madd(B2_u,matrix[3][2].x,B3_u*matrix[3][3].x)))))); + + const T y = + madd(B0_v,madd(B0_u,matrix[0][0].y,madd(B1_u,matrix[0][1].y,madd(B2_u,matrix[0][2].y,B3_u*matrix[0][3].y))), + madd(B1_v,madd(B0_u,matrix[1][0].y,madd(B1_u,matrix[1][1].y,madd(B2_u,matrix[1][2].y,B3_u*matrix[1][3].y))), + madd(B2_v,madd(B0_u,matrix[2][0].y,madd(B1_u,matrix[2][1].y,madd(B2_u,matrix[2][2].y,B3_u*matrix[2][3].y))), + B3_v*madd(B0_u,matrix[3][0].y,madd(B1_u,matrix[3][1].y,madd(B2_u,matrix[3][2].y,B3_u*matrix[3][3].y)))))); + + const T z = + madd(B0_v,madd(B0_u,matrix[0][0].z,madd(B1_u,matrix[0][1].z,madd(B2_u,matrix[0][2].z,B3_u*matrix[0][3].z))), + madd(B1_v,madd(B0_u,matrix[1][0].z,madd(B1_u,matrix[1][1].z,madd(B2_u,matrix[1][2].z,B3_u*matrix[1][3].z))), + madd(B2_v,madd(B0_u,matrix[2][0].z,madd(B1_u,matrix[2][1].z,madd(B2_u,matrix[2][2].z,B3_u*matrix[2][3].z))), + B3_v*madd(B0_u,matrix[3][0].z,madd(B1_u,matrix[3][1].z,madd(B2_u,matrix[3][2].z,B3_u*matrix[3][3].z)))))); + + return Vec3(x,y,z); + } + + template + __forceinline Vec3 eval(const vfloat& uu, const vfloat& vv) const { + return eval(matrix,uu,vv); + } + + template + static __forceinline Vec3 normal(const Vertex matrix[4][4], const T& uu, const T& vv) + { + + const Vec3 matrix_00 = Vec3(matrix[0][0].x,matrix[0][0].y,matrix[0][0].z); + const Vec3 matrix_01 = Vec3(matrix[0][1].x,matrix[0][1].y,matrix[0][1].z); + const Vec3 matrix_02 = Vec3(matrix[0][2].x,matrix[0][2].y,matrix[0][2].z); + const Vec3 matrix_03 = Vec3(matrix[0][3].x,matrix[0][3].y,matrix[0][3].z); + + const Vec3 matrix_10 = Vec3(matrix[1][0].x,matrix[1][0].y,matrix[1][0].z); + const Vec3 matrix_11 = Vec3(matrix[1][1].x,matrix[1][1].y,matrix[1][1].z); + const Vec3 matrix_12 = Vec3(matrix[1][2].x,matrix[1][2].y,matrix[1][2].z); + const Vec3 matrix_13 = Vec3(matrix[1][3].x,matrix[1][3].y,matrix[1][3].z); + + const Vec3 matrix_20 = Vec3(matrix[2][0].x,matrix[2][0].y,matrix[2][0].z); + const Vec3 matrix_21 = Vec3(matrix[2][1].x,matrix[2][1].y,matrix[2][1].z); + const Vec3 matrix_22 = Vec3(matrix[2][2].x,matrix[2][2].y,matrix[2][2].z); + const Vec3 matrix_23 = Vec3(matrix[2][3].x,matrix[2][3].y,matrix[2][3].z); + + const Vec3 matrix_30 = Vec3(matrix[3][0].x,matrix[3][0].y,matrix[3][0].z); + const Vec3 matrix_31 = Vec3(matrix[3][1].x,matrix[3][1].y,matrix[3][1].z); + const Vec3 matrix_32 = Vec3(matrix[3][2].x,matrix[3][2].y,matrix[3][2].z); + const Vec3 matrix_33 = Vec3(matrix[3][3].x,matrix[3][3].y,matrix[3][3].z); + + /* tangentU */ + const Vec3 col0 = deCasteljau(vv, matrix_00, matrix_10, matrix_20, matrix_30); + const Vec3 col1 = deCasteljau(vv, matrix_01, matrix_11, matrix_21, matrix_31); + const Vec3 col2 = 
deCasteljau(vv, matrix_02, matrix_12, matrix_22, matrix_32); + const Vec3 col3 = deCasteljau(vv, matrix_03, matrix_13, matrix_23, matrix_33); + + const Vec3 tangentU = deCasteljau_tangent(uu, col0, col1, col2, col3); + + /* tangentV */ + const Vec3 row0 = deCasteljau(uu, matrix_00, matrix_01, matrix_02, matrix_03); + const Vec3 row1 = deCasteljau(uu, matrix_10, matrix_11, matrix_12, matrix_13); + const Vec3 row2 = deCasteljau(uu, matrix_20, matrix_21, matrix_22, matrix_23); + const Vec3 row3 = deCasteljau(uu, matrix_30, matrix_31, matrix_32, matrix_33); + + const Vec3 tangentV = deCasteljau_tangent(vv, row0, row1, row2, row3); + + /* normal = tangentU x tangentV */ + const Vec3 n = cross(tangentU,tangentV); + return n; + } + + template + __forceinline Vec3 normal(const vfloat& uu, const vfloat& vv) const { + return normal(matrix,uu,vv); + } + }; + + typedef BezierPatchT BezierPatch3fa; +} diff --git a/thirdparty/embree/kernels/subdiv/bilinear_patch.h b/thirdparty/embree/kernels/subdiv/bilinear_patch.h new file mode 100644 index 000000000000..35748754bd02 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/bilinear_patch.h @@ -0,0 +1,191 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_patch.h" +#include "bezier_curve.h" + +namespace embree +{ + template + class __aligned(64) BilinearPatchT + { + typedef CatmullClark1RingT CatmullClarkRing; + typedef CatmullClarkPatchT CatmullClarkPatch; + + public: + Vertex v[4]; + + public: + + __forceinline BilinearPatchT () {} + + __forceinline BilinearPatchT (const HalfEdge* edge, const BufferView& vertices) { + init(edge,vertices.getPtr(),vertices.getStride()); + } + + __forceinline BilinearPatchT (const HalfEdge* edge, const char* vertices, size_t stride) { + init(edge,vertices,stride); + } + + __forceinline void init (const HalfEdge* edge, const char* vertices, size_t stride) + { + v[0] = Vertex::loadu(vertices+edge->getStartVertexIndex()*stride); edge = edge->next(); + v[1] = Vertex::loadu(vertices+edge->getStartVertexIndex()*stride); edge = edge->next(); + v[2] = Vertex::loadu(vertices+edge->getStartVertexIndex()*stride); edge = edge->next(); + v[3] = Vertex::loadu(vertices+edge->getStartVertexIndex()*stride); edge = edge->next(); + } + + __forceinline BilinearPatchT (const CatmullClarkPatch& patch) + { + v[0] = patch.ring[0].getLimitVertex(); + v[1] = patch.ring[1].getLimitVertex(); + v[2] = patch.ring[2].getLimitVertex(); + v[3] = patch.ring[3].getLimitVertex(); + } + + __forceinline BBox bounds() const + { + + BBox bounds (v[0]); + bounds.extend(v[1]); + bounds.extend(v[2]); + bounds.extend(v[3]); + return bounds; + } + + __forceinline Vertex eval(const float uu, const float vv) const { + return lerp(lerp(v[0],v[1],uu),lerp(v[3],v[2],uu),vv); + } + + __forceinline Vertex eval_du(const float uu, const float vv) const { + return lerp(v[1]-v[0],v[2]-v[3],vv); + } + + __forceinline Vertex eval_dv(const float uu, const float vv) const { + return lerp(v[3]-v[0],v[2]-v[1],uu); + } + + __forceinline Vertex eval_dudu(const float uu, const float vv) const { + return Vertex(zero); + } + + __forceinline Vertex eval_dvdv(const float uu, const float vv) const { + return Vertex(zero); + } + + __forceinline Vertex eval_dudv(const float uu, const float vv) const { + return (v[2]-v[3]) - (v[1]-v[0]); + } + + __forceinline Vertex normal(const float uu, const float vv) const { + return cross(eval_du(uu,vv),eval_dv(uu,vv)); + } + + __forceinline void eval(const float u, const float v, + 
Vertex* P, Vertex* dPdu, Vertex* dPdv, Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv, + const float dscale = 1.0f) const + { + if (P) { + *P = eval(u,v); + } + if (dPdu) { + assert(dPdu); *dPdu = eval_du(u,v)*dscale; + assert(dPdv); *dPdv = eval_dv(u,v)*dscale; + } + if (ddPdudu) { + assert(ddPdudu); *ddPdudu = eval_dudu(u,v)*sqr(dscale); + assert(ddPdvdv); *ddPdvdv = eval_dvdv(u,v)*sqr(dscale); + assert(ddPdudv); *ddPdudv = eval_dudv(u,v)*sqr(dscale); + } + } + + template + __forceinline Vec3 eval(const vfloat& uu, const vfloat& vv) const + { + const vfloat x = lerp(lerp(v[0].x,v[1].x,uu),lerp(v[3].x,v[2].x,uu),vv); + const vfloat y = lerp(lerp(v[0].y,v[1].y,uu),lerp(v[3].y,v[2].y,uu),vv); + const vfloat z = lerp(lerp(v[0].z,v[1].z,uu),lerp(v[3].z,v[2].z,uu),vv); + return Vec3(x,y,z); + } + + template + __forceinline Vec3 eval_du(const vfloat& uu, const vfloat& vv) const + { + const vfloat x = lerp(v[1].x-v[0].x,v[2].x-v[3].x,vv); + const vfloat y = lerp(v[1].y-v[0].y,v[2].y-v[3].y,vv); + const vfloat z = lerp(v[1].z-v[0].z,v[2].z-v[3].z,vv); + return Vec3(x,y,z); + } + + template + __forceinline Vec3 eval_dv(const vfloat& uu, const vfloat& vv) const + { + const vfloat x = lerp(v[3].x-v[0].x,v[2].x-v[1].x,uu); + const vfloat y = lerp(v[3].y-v[0].y,v[2].y-v[1].y,uu); + const vfloat z = lerp(v[3].z-v[0].z,v[2].z-v[1].z,uu); + return Vec3(x,y,z); + } + + template + __forceinline Vec3 normal(const vfloat& uu, const vfloat& vv) const { + return cross(eval_du(uu,vv),eval_dv(uu,vv)); + } + + template + __forceinline vfloat eval(const size_t i, const vfloat& uu, const vfloat& vv) const { + return lerp(lerp(v[0][i],v[1][i],uu),lerp(v[3][i],v[2][i],uu),vv); + } + + template + __forceinline vfloat eval_du(const size_t i, const vfloat& uu, const vfloat& vv) const { + return lerp(v[1][i]-v[0][i],v[2][i]-v[3][i],vv); + } + + template + __forceinline vfloat eval_dv(const size_t i, const vfloat& uu, const vfloat& vv) const { + return lerp(v[3][i]-v[0][i],v[2][i]-v[1][i],uu); + } + + template + __forceinline vfloat eval_dudu(const size_t i, const vfloat& uu, const vfloat& vv) const { + return vfloat(zero); + } + + template + __forceinline vfloat eval_dvdv(const size_t i, const vfloat& uu, const vfloat& vv) const { + return vfloat(zero); + } + + template + __forceinline vfloat eval_dudv(const size_t i, const vfloat& uu, const vfloat& vv) const { + return (v[2][i]-v[3][i]) - (v[1][i]-v[0][i]); + } + + template + __forceinline void eval(const vbool& valid, const vfloat& uu, const vfloat& vv, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, + const float dscale, const size_t dstride, const size_t N) const + { + if (P) { + for (size_t i=0; i BilinearPatch3fa; +} diff --git a/thirdparty/embree/kernels/subdiv/bspline_curve.h b/thirdparty/embree/kernels/subdiv/bspline_curve.h new file mode 100644 index 000000000000..a3256673282f --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/bspline_curve.h @@ -0,0 +1,319 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "bezier_curve.h" + +namespace embree +{ + class BSplineBasis + { + public: + + template + static __forceinline Vec4 eval(const T& u) + { + const T t = u; + const T s = T(1.0f) - u; + const T n0 = s*s*s; + const T n1 = (4.0f*(s*s*s)+(t*t*t)) + (12.0f*((s*t)*s) + 6.0f*((t*s)*t)); + const T n2 = (4.0f*(t*t*t)+(s*s*s)) + (12.0f*((t*s)*t) + 6.0f*((s*t)*s)); + const T n3 = t*t*t; + return T(1.0f/6.0f)*Vec4(n0,n1,n2,n3); + } + + 
template + static __forceinline Vec4 derivative(const T& u) + { + const T t = u; + const T s = 1.0f - u; + const T n0 = -s*s; + const T n1 = -t*t - 4.0f*(t*s); + const T n2 = s*s + 4.0f*(s*t); + const T n3 = t*t; + return T(0.5f)*Vec4(n0,n1,n2,n3); + } + + template + static __forceinline Vec4 derivative2(const T& u) + { + const T t = u; + const T s = 1.0f - u; + const T n0 = s; + const T n1 = t - 2.0f*s; + const T n2 = s - 2.0f*t; + const T n3 = t; + return Vec4(n0,n1,n2,n3); + } + }; + + struct PrecomputedBSplineBasis + { + enum { N = 16 }; + public: + PrecomputedBSplineBasis() {} + PrecomputedBSplineBasis(int shift); + + /* basis for bspline evaluation */ + public: + float c0[N+1][N+1]; + float c1[N+1][N+1]; + float c2[N+1][N+1]; + float c3[N+1][N+1]; + + /* basis for bspline derivative evaluation */ + public: + float d0[N+1][N+1]; + float d1[N+1][N+1]; + float d2[N+1][N+1]; + float d3[N+1][N+1]; + }; + extern PrecomputedBSplineBasis bspline_basis0; + extern PrecomputedBSplineBasis bspline_basis1; + + template + struct BSplineCurveT + { + Vertex v0,v1,v2,v3; + + __forceinline BSplineCurveT() {} + + __forceinline BSplineCurveT(const Vertex& v0, const Vertex& v1, const Vertex& v2, const Vertex& v3) + : v0(v0), v1(v1), v2(v2), v3(v3) {} + + __forceinline Vertex begin() const { + return madd(1.0f/6.0f,v0,madd(2.0f/3.0f,v1,1.0f/6.0f*v2)); + } + + __forceinline Vertex end() const { + return madd(1.0f/6.0f,v1,madd(2.0f/3.0f,v2,1.0f/6.0f*v3)); + } + + __forceinline Vertex center() const { + return 0.25f*(v0+v1+v2+v3); + } + + __forceinline BBox bounds() const { + return merge(BBox(v0),BBox(v1),BBox(v2),BBox(v3)); + } + + __forceinline friend BSplineCurveT operator -( const BSplineCurveT& a, const Vertex& b ) { + return BSplineCurveT(a.v0-b,a.v1-b,a.v2-b,a.v3-b); + } + + __forceinline BSplineCurveT xfm_pr(const LinearSpace3fa& space, const Vec3fa& p) const + { + const Vec3ff q0(xfmVector(space,(Vec3fa)v0-p), v0.w); + const Vec3ff q1(xfmVector(space,(Vec3fa)v1-p), v1.w); + const Vec3ff q2(xfmVector(space,(Vec3fa)v2-p), v2.w); + const Vec3ff q3(xfmVector(space,(Vec3fa)v3-p), v3.w); + return BSplineCurveT(q0,q1,q2,q3); + } + + __forceinline Vertex eval(const float t) const + { + const Vec4 b = BSplineBasis::eval(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline Vertex eval_du(const float t) const + { + const Vec4 b = BSplineBasis::derivative(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline Vertex eval_dudu(const float t) const + { + const Vec4 b = BSplineBasis::derivative2(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline void eval(const float t, Vertex& p, Vertex& dp, Vertex& ddp) const + { + p = eval(t); + dp = eval_du(t); + ddp = eval_dudu(t); + } + + template + __forceinline Vec4vf veval(const vfloat& t) const + { + const Vec4vf b = BSplineBasis::eval(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf veval_du(const vfloat& t) const + { + const Vec4vf b = BSplineBasis::derivative(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf veval_dudu(const vfloat& t) const + { + const Vec4vf b = BSplineBasis::derivative2(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline void veval(const vfloat& t, Vec4vf& p, Vec4vf& dp) const + { + p = veval(t); + dp = 
veval_du(t); + } + + template + __forceinline Vec4vf eval0(const int ofs, const int size) const + { + assert(size <= PrecomputedBSplineBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bspline_basis0.c0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&bspline_basis0.c1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&bspline_basis0.c2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&bspline_basis0.c3[size][ofs]) * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf eval1(const int ofs, const int size) const + { + assert(size <= PrecomputedBSplineBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bspline_basis1.c0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&bspline_basis1.c1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&bspline_basis1.c2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&bspline_basis1.c3[size][ofs]) * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf derivative0(const int ofs, const int size) const + { + assert(size <= PrecomputedBSplineBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bspline_basis0.d0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&bspline_basis0.d1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&bspline_basis0.d2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&bspline_basis0.d3[size][ofs]) * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf derivative1(const int ofs, const int size) const + { + assert(size <= PrecomputedBSplineBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&bspline_basis1.d0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&bspline_basis1.d1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&bspline_basis1.d2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&bspline_basis1.d3[size][ofs]) * Vec4vf(v3)))); + } + + /* calculates bounds of bspline curve geometry */ + __forceinline BBox3fa accurateRoundBounds() const + { + const int N = 7; + const float scale = 1.0f/(3.0f*(N-1)); + Vec4vfx pl(pos_inf), pu(neg_inf); + for (int i=0; i<=N; i+=VSIZEX) + { + vintx vi = vintx(i)+vintx(step); + vboolx valid = vi <= vintx(N); + const Vec4vfx p = eval0(i,N); + const Vec4vfx dp = derivative0(i,N); + const Vec4vfx pm = p-Vec4vfx(scale)*select(vi!=vintx(0),dp,Vec4vfx(zero)); + const Vec4vfx pp = p+Vec4vfx(scale)*select(vi!=vintx(N),dp,Vec4vfx(zero)); + pl = select(valid,min(pl,p,pm,pp),pl); // FIXME: use masked min + pu = select(valid,max(pu,p,pm,pp),pu); // FIXME: use masked min + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + const float r_min = reduce_min(pl.w); + const float r_max = reduce_max(pu.w); + const Vec3fa upper_r = Vec3fa(max(abs(r_min),abs(r_max))); + return enlarge(BBox3fa(lower,upper),upper_r); + } + + /* calculates bounds when tessellated into N line segments */ + __forceinline BBox3fa accurateFlatBounds(int N) const + { + if (likely(N == 4)) + { + const Vec4vf4 pi = eval0<4>(0,4); + const Vec3fa lower(reduce_min(pi.x),reduce_min(pi.y),reduce_min(pi.z)); + const Vec3fa upper(reduce_max(pi.x),reduce_max(pi.y),reduce_max(pi.z)); + const Vec3fa upper_r = Vec3fa(reduce_max(abs(pi.w))); + const Vec3ff pe = end(); + return enlarge(BBox3fa(min(lower,pe),max(upper,pe)),max(upper_r,Vec3fa(abs(pe.w)))); + } + else + { + Vec3vfx pl(pos_inf), pu(neg_inf); vfloatx ru(0.0f); + for (int i=0; i<=N; i+=VSIZEX) + { + vboolx valid = vintx(i)+vintx(step) <= vintx(N); + const Vec4vfx pi = eval0(i,N); + + pl.x = select(valid,min(pl.x,pi.x),pl.x); // FIXME: use masked min + pl.y = select(valid,min(pl.y,pi.y),pl.y); + pl.z = 
select(valid,min(pl.z,pi.z),pl.z); + + pu.x = select(valid,max(pu.x,pi.x),pu.x); // FIXME: use masked min + pu.y = select(valid,max(pu.y,pi.y),pu.y); + pu.z = select(valid,max(pu.z,pi.z),pu.z); + + ru = select(valid,max(ru,abs(pi.w)),ru); + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + const Vec3fa upper_r(reduce_max(ru)); + return enlarge(BBox3fa(lower,upper),upper_r); + } + } + + friend __forceinline embree_ostream operator<<(embree_ostream cout, const BSplineCurveT& curve) { + return cout << "BSplineCurve { v0 = " << curve.v0 << ", v1 = " << curve.v1 << ", v2 = " << curve.v2 << ", v3 = " << curve.v3 << " }"; + } + }; + + template + __forceinline void convert(const BezierCurveT& icurve, BezierCurveT& ocurve) { + ocurve = icurve; + } + + template + __forceinline void convert(const BSplineCurveT& icurve, BSplineCurveT& ocurve) { + ocurve = icurve; + } + + template + __forceinline void convert(const BezierCurveT& icurve, BSplineCurveT& ocurve) + { + const Vertex v0 = madd(6.0f,icurve.v0,madd(-7.0f,icurve.v1,2.0f*icurve.v2)); + const Vertex v1 = msub(2.0f,icurve.v1,icurve.v2); + const Vertex v2 = msub(2.0f,icurve.v2,icurve.v1); + const Vertex v3 = madd(2.0f,icurve.v1,madd(-7.0f,icurve.v2,6.0f*icurve.v3)); + ocurve = BSplineCurveT(v0,v1,v2,v3); + } + + template + __forceinline void convert(const BSplineCurveT& icurve, BezierCurveT& ocurve) + { + const Vertex v0 = madd(1.0f/6.0f,icurve.v0,madd(2.0f/3.0f,icurve.v1,1.0f/6.0f*icurve.v2)); + const Vertex v1 = madd(2.0f/3.0f,icurve.v1,1.0f/3.0f*icurve.v2); + const Vertex v2 = madd(1.0f/3.0f,icurve.v1,2.0f/3.0f*icurve.v2); + const Vertex v3 = madd(1.0f/6.0f,icurve.v1,madd(2.0f/3.0f,icurve.v2,1.0f/6.0f*icurve.v3)); + ocurve = BezierCurveT(v0,v1,v2,v3); + } + + __forceinline BSplineCurveT enlargeRadiusToMinWidth(const IntersectContext* context, const CurveGeometry* geom, const Vec3fa& ray_org, const BSplineCurveT& curve) + { + return BSplineCurveT(enlargeRadiusToMinWidth(context,geom,ray_org,curve.v0), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v1), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v2), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v3)); + } + + typedef BSplineCurveT BSplineCurve3fa; +} + diff --git a/thirdparty/embree/kernels/subdiv/bspline_patch.h b/thirdparty/embree/kernels/subdiv/bspline_patch.h new file mode 100644 index 000000000000..9769bc17bdfe --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/bspline_patch.h @@ -0,0 +1,449 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_patch.h" +#include "bspline_curve.h" + +namespace embree +{ + template + class __aligned(64) BSplinePatchT + { + typedef CatmullClark1RingT CatmullClarkRing; + typedef CatmullClarkPatchT CatmullClarkPatch; + + public: + + __forceinline BSplinePatchT () {} + + __forceinline BSplinePatchT (const CatmullClarkPatch& patch) { + init(patch); + } + + __forceinline BSplinePatchT(const CatmullClarkPatch& patch, + const BezierCurveT* border0, + const BezierCurveT* border1, + const BezierCurveT* border2, + const BezierCurveT* border3) + { + init(patch); + } + + __forceinline BSplinePatchT (const HalfEdge* edge, const char* vertices, size_t stride) { + init(edge,vertices,stride); + } + + __forceinline Vertex hard_corner(const Vertex& v01, const Vertex& v02, + const Vertex& v10, const Vertex& v11, const Vertex& v12, + const Vertex& v20, const Vertex& v21, const Vertex& 
v22) + { + return 4.0f*v11 - 2.0f*(v12+v21) + v22; + } + + __forceinline Vertex soft_convex_corner( const Vertex& v01, const Vertex& v02, + const Vertex& v10, const Vertex& v11, const Vertex& v12, + const Vertex& v20, const Vertex& v21, const Vertex& v22) + { + return -8.0f*v11 + 4.0f*(v12+v21) + v22; + } + + __forceinline Vertex convex_corner(const float vertex_crease_weight, + const Vertex& v01, const Vertex& v02, + const Vertex& v10, const Vertex& v11, const Vertex& v12, + const Vertex& v20, const Vertex& v21, const Vertex& v22) + { + if (std::isinf(vertex_crease_weight)) return hard_corner(v01,v02,v10,v11,v12,v20,v21,v22); + else return soft_convex_corner(v01,v02,v10,v11,v12,v20,v21,v22); + } + + __forceinline Vertex load(const HalfEdge* edge, const char* vertices, size_t stride) { + return Vertex_t::loadu(vertices+edge->getStartVertexIndex()*stride); + } + + __forceinline void init_border(const CatmullClarkRing& edge0, + Vertex& v01, Vertex& v02, + const Vertex& v11, const Vertex& v12, + const Vertex& v21, const Vertex& v22) + { + if (likely(edge0.has_opposite_back(0))) + { + v01 = edge0.back(2); + v02 = edge0.back(1); + } else { + v01 = 2.0f*v11-v21; + v02 = 2.0f*v12-v22; + } + } + + __forceinline void init_corner(const CatmullClarkRing& edge0, + Vertex& v00, const Vertex& v01, const Vertex& v02, + const Vertex& v10, const Vertex& v11, const Vertex& v12, + const Vertex& v20, const Vertex& v21, const Vertex& v22) + { + const bool MAYBE_UNUSED has_back1 = edge0.has_opposite_back(1); + const bool has_back0 = edge0.has_opposite_back(0); + const bool has_front1 = edge0.has_opposite_front(1); + const bool MAYBE_UNUSED has_front2 = edge0.has_opposite_front(2); + + if (likely(has_back0)) { + if (likely(has_front1)) { assert(has_back1 && has_front2); v00 = edge0.back(3); } + else { assert(!has_back1); v00 = 2.0f*v01-v02; } + } + else { + if (likely(has_front1)) { assert(!has_front2); v00 = 2.0f*v10-v20; } + else v00 = convex_corner(edge0.vertex_crease_weight,v01,v02,v10,v11,v12,v20,v21,v22); + } + } + + void init(const CatmullClarkPatch& patch) + { + /* fill inner vertices */ + const Vertex v11 = v[1][1] = patch.ring[0].vtx; + const Vertex v12 = v[1][2] = patch.ring[1].vtx; + const Vertex v22 = v[2][2] = patch.ring[2].vtx; + const Vertex v21 = v[2][1] = patch.ring[3].vtx; + + /* fill border vertices */ + init_border(patch.ring[0],v[0][1],v[0][2],v11,v12,v21,v22); + init_border(patch.ring[1],v[1][3],v[2][3],v12,v22,v11,v21); + init_border(patch.ring[2],v[3][2],v[3][1],v22,v21,v12,v11); + init_border(patch.ring[3],v[2][0],v[1][0],v21,v11,v22,v12); + + /* fill corner vertices */ + init_corner(patch.ring[0],v[0][0],v[0][1],v[0][2],v[1][0],v11,v12,v[2][0],v21,v22); + init_corner(patch.ring[1],v[0][3],v[1][3],v[2][3],v[0][2],v12,v22,v[0][1],v11,v21); + init_corner(patch.ring[2],v[3][3],v[3][2],v[3][1],v[2][3],v22,v21,v[1][3],v12,v11); + init_corner(patch.ring[3],v[3][0],v[2][0],v[1][0],v[3][1],v21,v11,v[3][2],v22,v12); + } + + void init_border(const HalfEdge* edge0, const char* vertices, size_t stride, + Vertex& v01, Vertex& v02, + const Vertex& v11, const Vertex& v12, + const Vertex& v21, const Vertex& v22) + { + if (likely(edge0->hasOpposite())) + { + const HalfEdge* e = edge0->opposite()->next()->next(); + v01 = load(e,vertices,stride); + v02 = load(e->next(),vertices,stride); + } else { + v01 = 2.0f*v11-v21; + v02 = 2.0f*v12-v22; + } + } + + void init_corner(const HalfEdge* edge0, const char* vertices, size_t stride, + Vertex& v00, const Vertex& v01, const Vertex& v02, + const Vertex& v10, 
const Vertex& v11, const Vertex& v12, + const Vertex& v20, const Vertex& v21, const Vertex& v22) + { + const bool has_back0 = edge0->hasOpposite(); + const bool has_front1 = edge0->prev()->hasOpposite(); + + if (likely(has_back0)) + { + const HalfEdge* e = edge0->opposite()->next(); + if (likely(has_front1)) + { + assert(e->hasOpposite()); + assert(edge0->prev()->opposite()->prev()->hasOpposite()); + v00 = load(e->opposite()->prev(),vertices,stride); + } + else { + assert(!e->hasOpposite()); + v00 = 2.0f*v01-v02; + } + } + else + { + if (likely(has_front1)) { + assert(!edge0->prev()->opposite()->prev()->hasOpposite()); + v00 = 2.0f*v10-v20; + } + else { + assert(edge0->vertex_crease_weight == 0.0f || std::isinf(edge0->vertex_crease_weight)); + v00 = convex_corner(edge0->vertex_crease_weight,v01,v02,v10,v11,v12,v20,v21,v22); + } + } + } + + void init(const HalfEdge* edge0, const char* vertices, size_t stride) + { + assert( edge0->isRegularFace() ); + + /* fill inner vertices */ + const Vertex v11 = v[1][1] = load(edge0,vertices,stride); const HalfEdge* edge1 = edge0->next(); + const Vertex v12 = v[1][2] = load(edge1,vertices,stride); const HalfEdge* edge2 = edge1->next(); + const Vertex v22 = v[2][2] = load(edge2,vertices,stride); const HalfEdge* edge3 = edge2->next(); + const Vertex v21 = v[2][1] = load(edge3,vertices,stride); assert(edge0 == edge3->next()); + + /* fill border vertices */ + init_border(edge0,vertices,stride,v[0][1],v[0][2],v11,v12,v21,v22); + init_border(edge1,vertices,stride,v[1][3],v[2][3],v12,v22,v11,v21); + init_border(edge2,vertices,stride,v[3][2],v[3][1],v22,v21,v12,v11); + init_border(edge3,vertices,stride,v[2][0],v[1][0],v21,v11,v22,v12); + + /* fill corner vertices */ + init_corner(edge0,vertices,stride,v[0][0],v[0][1],v[0][2],v[1][0],v11,v12,v[2][0],v21,v22); + init_corner(edge1,vertices,stride,v[0][3],v[1][3],v[2][3],v[0][2],v12,v22,v[0][1],v11,v21); + init_corner(edge2,vertices,stride,v[3][3],v[3][2],v[3][1],v[2][3],v22,v21,v[1][3],v12,v11); + init_corner(edge3,vertices,stride,v[3][0],v[2][0],v[1][0],v[3][1],v21,v11,v[3][2],v22,v12); + } + + __forceinline BBox bounds() const + { + const Vertex* const cv = &v[0][0]; + BBox bounds (cv[0]); + for (size_t i=1; i<16 ; i++) + bounds.extend( cv[i] ); + return bounds; + } + + __forceinline Vertex eval(const float uu, const float vv) const + { + const Vec4f v_n = BSplineBasis::eval(vv); + const Vertex_t curve0 = madd(v_n[0],v[0][0],madd(v_n[1],v[1][0],madd(v_n[2],v[2][0],v_n[3] * v[3][0]))); + const Vertex_t curve1 = madd(v_n[0],v[0][1],madd(v_n[1],v[1][1],madd(v_n[2],v[2][1],v_n[3] * v[3][1]))); + const Vertex_t curve2 = madd(v_n[0],v[0][2],madd(v_n[1],v[1][2],madd(v_n[2],v[2][2],v_n[3] * v[3][2]))); + const Vertex_t curve3 = madd(v_n[0],v[0][3],madd(v_n[1],v[1][3],madd(v_n[2],v[2][3],v_n[3] * v[3][3]))); + + const Vec4f u_n = BSplineBasis::eval(uu); + return madd(u_n[0],curve0,madd(u_n[1],curve1,madd(u_n[2],curve2,u_n[3] * curve3))); + } + + __forceinline Vertex eval_du(const float uu, const float vv) const + { + const Vec4f v_n = BSplineBasis::eval(vv); + const Vertex_t curve0 = madd(v_n[0],v[0][0],madd(v_n[1],v[1][0],madd(v_n[2],v[2][0],v_n[3] * v[3][0]))); + const Vertex_t curve1 = madd(v_n[0],v[0][1],madd(v_n[1],v[1][1],madd(v_n[2],v[2][1],v_n[3] * v[3][1]))); + const Vertex_t curve2 = madd(v_n[0],v[0][2],madd(v_n[1],v[1][2],madd(v_n[2],v[2][2],v_n[3] * v[3][2]))); + const Vertex_t curve3 = madd(v_n[0],v[0][3],madd(v_n[1],v[1][3],madd(v_n[2],v[2][3],v_n[3] * v[3][3]))); + + const Vec4f u_n = 
BSplineBasis::derivative(uu); + return madd(u_n[0],curve0,madd(u_n[1],curve1,madd(u_n[2],curve2,u_n[3] * curve3))); + } + + __forceinline Vertex eval_dv(const float uu, const float vv) const + { + const Vec4f v_n = BSplineBasis::derivative(vv); + const Vertex_t curve0 = madd(v_n[0],v[0][0],madd(v_n[1],v[1][0],madd(v_n[2],v[2][0],v_n[3] * v[3][0]))); + const Vertex_t curve1 = madd(v_n[0],v[0][1],madd(v_n[1],v[1][1],madd(v_n[2],v[2][1],v_n[3] * v[3][1]))); + const Vertex_t curve2 = madd(v_n[0],v[0][2],madd(v_n[1],v[1][2],madd(v_n[2],v[2][2],v_n[3] * v[3][2]))); + const Vertex_t curve3 = madd(v_n[0],v[0][3],madd(v_n[1],v[1][3],madd(v_n[2],v[2][3],v_n[3] * v[3][3]))); + + const Vec4f u_n = BSplineBasis::eval(uu); + return madd(u_n[0],curve0,madd(u_n[1],curve1,madd(u_n[2],curve2,u_n[3] * curve3))); + } + + __forceinline Vertex eval_dudu(const float uu, const float vv) const + { + const Vec4f v_n = BSplineBasis::eval(vv); + const Vertex_t curve0 = madd(v_n[0],v[0][0],madd(v_n[1],v[1][0],madd(v_n[2],v[2][0],v_n[3] * v[3][0]))); + const Vertex_t curve1 = madd(v_n[0],v[0][1],madd(v_n[1],v[1][1],madd(v_n[2],v[2][1],v_n[3] * v[3][1]))); + const Vertex_t curve2 = madd(v_n[0],v[0][2],madd(v_n[1],v[1][2],madd(v_n[2],v[2][2],v_n[3] * v[3][2]))); + const Vertex_t curve3 = madd(v_n[0],v[0][3],madd(v_n[1],v[1][3],madd(v_n[2],v[2][3],v_n[3] * v[3][3]))); + + const Vec4f u_n = BSplineBasis::derivative2(uu); + return madd(u_n[0],curve0,madd(u_n[1],curve1,madd(u_n[2],curve2,u_n[3] * curve3))); + } + + __forceinline Vertex eval_dvdv(const float uu, const float vv) const + { + const Vec4f v_n = BSplineBasis::derivative2(vv); + const Vertex_t curve0 = madd(v_n[0],v[0][0],madd(v_n[1],v[1][0],madd(v_n[2],v[2][0],v_n[3] * v[3][0]))); + const Vertex_t curve1 = madd(v_n[0],v[0][1],madd(v_n[1],v[1][1],madd(v_n[2],v[2][1],v_n[3] * v[3][1]))); + const Vertex_t curve2 = madd(v_n[0],v[0][2],madd(v_n[1],v[1][2],madd(v_n[2],v[2][2],v_n[3] * v[3][2]))); + const Vertex_t curve3 = madd(v_n[0],v[0][3],madd(v_n[1],v[1][3],madd(v_n[2],v[2][3],v_n[3] * v[3][3]))); + + const Vec4f u_n = BSplineBasis::eval(uu); + return madd(u_n[0],curve0,madd(u_n[1],curve1,madd(u_n[2],curve2,u_n[3] * curve3))); + } + + __forceinline Vertex eval_dudv(const float uu, const float vv) const + { + const Vec4f v_n = BSplineBasis::derivative(vv); + const Vertex_t curve0 = madd(v_n[0],v[0][0],madd(v_n[1],v[1][0],madd(v_n[2],v[2][0],v_n[3] * v[3][0]))); + const Vertex_t curve1 = madd(v_n[0],v[0][1],madd(v_n[1],v[1][1],madd(v_n[2],v[2][1],v_n[3] * v[3][1]))); + const Vertex_t curve2 = madd(v_n[0],v[0][2],madd(v_n[1],v[1][2],madd(v_n[2],v[2][2],v_n[3] * v[3][2]))); + const Vertex_t curve3 = madd(v_n[0],v[0][3],madd(v_n[1],v[1][3],madd(v_n[2],v[2][3],v_n[3] * v[3][3]))); + + const Vec4f u_n = BSplineBasis::derivative(uu); + return madd(u_n[0],curve0,madd(u_n[1],curve1,madd(u_n[2],curve2,u_n[3] * curve3))); + } + + __forceinline Vertex normal(const float uu, const float vv) const + { + const Vertex tu = eval_du(uu,vv); + const Vertex tv = eval_dv(uu,vv); + return cross(tu,tv); + } + + template + __forceinline Vec3 eval(const T& uu, const T& vv, const Vec4& u_n, const Vec4& v_n) const + { + const T curve0_x = madd(v_n[0],T(v[0][0].x),madd(v_n[1],T(v[1][0].x),madd(v_n[2],T(v[2][0].x),v_n[3] * T(v[3][0].x)))); + const T curve1_x = madd(v_n[0],T(v[0][1].x),madd(v_n[1],T(v[1][1].x),madd(v_n[2],T(v[2][1].x),v_n[3] * T(v[3][1].x)))); + const T curve2_x = madd(v_n[0],T(v[0][2].x),madd(v_n[1],T(v[1][2].x),madd(v_n[2],T(v[2][2].x),v_n[3] * T(v[3][2].x)))); + const T 
curve3_x = madd(v_n[0],T(v[0][3].x),madd(v_n[1],T(v[1][3].x),madd(v_n[2],T(v[2][3].x),v_n[3] * T(v[3][3].x)))); + const T x = madd(u_n[0],curve0_x,madd(u_n[1],curve1_x,madd(u_n[2],curve2_x,u_n[3] * curve3_x))); + + const T curve0_y = madd(v_n[0],T(v[0][0].y),madd(v_n[1],T(v[1][0].y),madd(v_n[2],T(v[2][0].y),v_n[3] * T(v[3][0].y)))); + const T curve1_y = madd(v_n[0],T(v[0][1].y),madd(v_n[1],T(v[1][1].y),madd(v_n[2],T(v[2][1].y),v_n[3] * T(v[3][1].y)))); + const T curve2_y = madd(v_n[0],T(v[0][2].y),madd(v_n[1],T(v[1][2].y),madd(v_n[2],T(v[2][2].y),v_n[3] * T(v[3][2].y)))); + const T curve3_y = madd(v_n[0],T(v[0][3].y),madd(v_n[1],T(v[1][3].y),madd(v_n[2],T(v[2][3].y),v_n[3] * T(v[3][3].y)))); + const T y = madd(u_n[0],curve0_y,madd(u_n[1],curve1_y,madd(u_n[2],curve2_y,u_n[3] * curve3_y))); + + const T curve0_z = madd(v_n[0],T(v[0][0].z),madd(v_n[1],T(v[1][0].z),madd(v_n[2],T(v[2][0].z),v_n[3] * T(v[3][0].z)))); + const T curve1_z = madd(v_n[0],T(v[0][1].z),madd(v_n[1],T(v[1][1].z),madd(v_n[2],T(v[2][1].z),v_n[3] * T(v[3][1].z)))); + const T curve2_z = madd(v_n[0],T(v[0][2].z),madd(v_n[1],T(v[1][2].z),madd(v_n[2],T(v[2][2].z),v_n[3] * T(v[3][2].z)))); + const T curve3_z = madd(v_n[0],T(v[0][3].z),madd(v_n[1],T(v[1][3].z),madd(v_n[2],T(v[2][3].z),v_n[3] * T(v[3][3].z)))); + const T z = madd(u_n[0],curve0_z,madd(u_n[1],curve1_z,madd(u_n[2],curve2_z,u_n[3] * curve3_z))); + + return Vec3(x,y,z); + } + + template + __forceinline Vec3 eval(const T& uu, const T& vv) const + { + const Vec4 u_n = BSplineBasis::eval(uu); + const Vec4 v_n = BSplineBasis::eval(vv); + return eval(uu,vv,u_n,v_n); + } + + template + __forceinline Vec3 eval_du(const T& uu, const T& vv) const + { + const Vec4 u_n = BSplineBasis::derivative(uu); + const Vec4 v_n = BSplineBasis::eval(vv); + return eval(uu,vv,u_n,v_n); + } + + template + __forceinline Vec3 eval_dv(const T& uu, const T& vv) const + { + const Vec4 u_n = BSplineBasis::eval(uu); + const Vec4 v_n = BSplineBasis::derivative(vv); + return eval(uu,vv,u_n,v_n); + } + + template + __forceinline Vec3 eval_dudu(const T& uu, const T& vv) const + { + const Vec4 u_n = BSplineBasis::derivative2(uu); + const Vec4 v_n = BSplineBasis::eval(vv); + return eval(uu,vv,u_n,v_n); + } + + template + __forceinline Vec3 eval_dvdv(const T& uu, const T& vv) const + { + const Vec4 u_n = BSplineBasis::eval(uu); + const Vec4 v_n = BSplineBasis::derivative2(vv); + return eval(uu,vv,u_n,v_n); + } + + template + __forceinline Vec3 eval_dudv(const T& uu, const T& vv) const + { + const Vec4 u_n = BSplineBasis::derivative(uu); + const Vec4 v_n = BSplineBasis::derivative(vv); + return eval(uu,vv,u_n,v_n); + } + + template + __forceinline Vec3 normal(const T& uu, const T& vv) const { + return cross(eval_du(uu,vv),eval_dv(uu,vv)); + } + + void eval(const float u, const float v, + Vertex* P, Vertex* dPdu, Vertex* dPdv, Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv, + const float dscale = 1.0f) const + { + if (P) { + *P = eval(u,v); + } + if (dPdu) { + assert(dPdu); *dPdu = eval_du(u,v)*dscale; + assert(dPdv); *dPdv = eval_dv(u,v)*dscale; + } + if (ddPdudu) { + assert(ddPdudu); *ddPdudu = eval_dudu(u,v)*sqr(dscale); + assert(ddPdvdv); *ddPdvdv = eval_dvdv(u,v)*sqr(dscale); + assert(ddPdudv); *ddPdudv = eval_dudv(u,v)*sqr(dscale); + } + } + + template + __forceinline vfloat eval(const size_t i, const vfloat& uu, const vfloat& vv, const Vec4& u_n, const Vec4& v_n) const + { + const vfloat curve0_x = madd(v_n[0],vfloat(v[0][0][i]),madd(v_n[1],vfloat(v[1][0][i]),madd(v_n[2],vfloat(v[2][0][i]),v_n[3] * 
vfloat(v[3][0][i])))); + const vfloat curve1_x = madd(v_n[0],vfloat(v[0][1][i]),madd(v_n[1],vfloat(v[1][1][i]),madd(v_n[2],vfloat(v[2][1][i]),v_n[3] * vfloat(v[3][1][i])))); + const vfloat curve2_x = madd(v_n[0],vfloat(v[0][2][i]),madd(v_n[1],vfloat(v[1][2][i]),madd(v_n[2],vfloat(v[2][2][i]),v_n[3] * vfloat(v[3][2][i])))); + const vfloat curve3_x = madd(v_n[0],vfloat(v[0][3][i]),madd(v_n[1],vfloat(v[1][3][i]),madd(v_n[2],vfloat(v[2][3][i]),v_n[3] * vfloat(v[3][3][i])))); + return madd(u_n[0],curve0_x,madd(u_n[1],curve1_x,madd(u_n[2],curve2_x,u_n[3] * curve3_x))); + } + + template + void eval(const vbool& valid, const vfloat& uu, const vfloat& vv, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, + const float dscale, const size_t dstride, const size_t N) const + { + if (P) { + const Vec4 u_n = BSplineBasis::eval(uu); + const Vec4 v_n = BSplineBasis::eval(vv); + for (size_t i=0; i u_n = BSplineBasis::derivative(uu); + const Vec4 v_n = BSplineBasis::eval(vv); + for (size_t i=0; i u_n = BSplineBasis::eval(uu); + const Vec4 v_n = BSplineBasis::derivative(vv); + for (size_t i=0; i u_n = BSplineBasis::derivative2(uu); + const Vec4 v_n = BSplineBasis::eval(vv); + for (size_t i=0; i u_n = BSplineBasis::eval(uu); + const Vec4 v_n = BSplineBasis::derivative2(vv); + for (size_t i=0; i u_n = BSplineBasis::derivative(uu); + const Vec4 v_n = BSplineBasis::derivative(vv); + for (size_t i=0; i BSplinePatch3fa; +} diff --git a/thirdparty/embree/kernels/subdiv/catmullclark_coefficients.h b/thirdparty/embree/kernels/subdiv/catmullclark_coefficients.h new file mode 100644 index 000000000000..05031cf6b9ae --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/catmullclark_coefficients.h @@ -0,0 +1,85 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/geometry.h" + +namespace embree +{ + static const size_t MAX_PATCH_VALENCE = 16; //!< maximum number of vertices of a patch + static const size_t MAX_RING_FACE_VALENCE = 64; //!< maximum number of faces per ring + static const size_t MAX_RING_EDGE_VALENCE = 2*64; //!< maximum number of edges per ring + + class CatmullClarkPrecomputedCoefficients + { + private: + + float table_cos_2PI_div_n[MAX_RING_FACE_VALENCE+1]; + + float* table_limittangent_a[MAX_RING_FACE_VALENCE+1]; + float* table_limittangent_b[MAX_RING_FACE_VALENCE+1]; + float table_limittangent_c[MAX_RING_FACE_VALENCE+1]; + + __forceinline float set_cos_2PI_div_n(const size_t n) { + if (unlikely(n == 0)) return 1.0f; + return cosf(2.0f*float(pi)/(float)n); + } + + __forceinline float set_limittangent_a(const size_t i, const size_t n) + { + if (unlikely(n == 0)) return 1.0f; + const float c0 = 1.0f/(float)n * 1.0f / sqrtf(4.0f + cosf(float(pi)/(float)n)*cosf(float(pi)/(float)n)); + const float c1 = (1.0f/(float)n + cosf(float(pi)/(float)n) * c0); + return cosf(2.0f*float(pi)*(float)i/(float)n) * c1; + } + + __forceinline float set_limittangent_b(const size_t i, const size_t n) + { + if (unlikely(n == 0)) return 1.0f; + const float c0 = 1.0f/(float)n * 1.0f / sqrtf(4.0f + cosf(float(pi)/(float)n)*cosf(float(pi)/(float)n)); + return cosf((2.0f*float(pi)*i+float(pi))/(float)n) * c0; + } + + __forceinline float set_limittangent_c(const size_t n) + { + if (unlikely(n == 0)) return 1.0f; + return 2.0f/16.0f * (5.0f + cosf(2.0f*float(pi)/(float)n) + cosf(float(pi)/(float)n) * sqrtf(18.0f+2.0f*cosf(2.0f*float(pi)/(float)n))); + } + + public: + + __forceinline float cos_2PI_div_n(const size_t n) + { + if 
(likely(n <= MAX_RING_FACE_VALENCE)) + return table_cos_2PI_div_n[n]; + else + return set_cos_2PI_div_n(n); + } + + __forceinline float limittangent_a(const size_t i, const size_t n) + { + assert(n <= MAX_RING_FACE_VALENCE); + assert(i < n); + return table_limittangent_a[n][i]; + } + + __forceinline float limittangent_b(const size_t i, const size_t n) + { + assert(n <= MAX_RING_FACE_VALENCE); + assert(i < n); + return table_limittangent_b[n][i]; + } + + __forceinline float limittangent_c(const size_t n) + { + assert(n <= MAX_RING_FACE_VALENCE); + return table_limittangent_c[n]; + } + + static CatmullClarkPrecomputedCoefficients table; + + CatmullClarkPrecomputedCoefficients(); + ~CatmullClarkPrecomputedCoefficients(); + }; +} diff --git a/thirdparty/embree/kernels/subdiv/catmullclark_patch.h b/thirdparty/embree/kernels/subdiv/catmullclark_patch.h new file mode 100644 index 000000000000..ab1d63594a06 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/catmullclark_patch.h @@ -0,0 +1,562 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_ring.h" +#include "bezier_curve.h" + +namespace embree +{ + template + class __aligned(64) CatmullClarkPatchT + { + public: + typedef CatmullClark1RingT CatmullClark1Ring; + typedef typename CatmullClark1Ring::Type Type; + + array_t,4> ring; + + public: + __forceinline CatmullClarkPatchT () {} + + __forceinline CatmullClarkPatchT (const HalfEdge* first_half_edge, const char* vertices, size_t stride) { + init(first_half_edge,vertices,stride); + } + + __forceinline CatmullClarkPatchT (const HalfEdge* first_half_edge, const BufferView& vertices) { + init(first_half_edge,vertices.getPtr(),vertices.getStride()); + } + + __forceinline void init (const HalfEdge* first_half_edge, const char* vertices, size_t stride) + { + for (unsigned i=0; i<4; i++) + ring[i].init(first_half_edge+i,vertices,stride); + + assert(verify()); + } + + __forceinline size_t bytes() const { + return ring[0].bytes()+ring[1].bytes()+ring[2].bytes()+ring[3].bytes(); + } + + __forceinline void serialize(void* ptr, size_t& ofs) const + { + for (size_t i=0; i<4; i++) + ring[i].serialize((char*)ptr,ofs); + } + + __forceinline void deserialize(void* ptr) + { + size_t ofs = 0; + for (size_t i=0; i<4; i++) + ring[i].deserialize((char*)ptr,ofs); + } + + __forceinline BBox3fa bounds() const + { + BBox3fa bounds (ring[0].bounds()); + for (size_t i=1; i<4; i++) + bounds.extend(ring[i].bounds()); + return bounds; + } + + __forceinline Type type() const + { + const int ty0 = ring[0].type() ^ CatmullClark1Ring::TYPE_CREASES; + const int ty1 = ring[1].type() ^ CatmullClark1Ring::TYPE_CREASES; + const int ty2 = ring[2].type() ^ CatmullClark1Ring::TYPE_CREASES; + const int ty3 = ring[3].type() ^ CatmullClark1Ring::TYPE_CREASES; + return (Type) ((ty0 & ty1 & ty2 & ty3) ^ CatmullClark1Ring::TYPE_CREASES); + } + + __forceinline bool isFinalResolution(float res) const { + return ring[0].isFinalResolution(res) && ring[1].isFinalResolution(res) && ring[2].isFinalResolution(res) && ring[3].isFinalResolution(res); + } + + static __forceinline void init_regular(const CatmullClark1RingT& p0, + const CatmullClark1RingT& p1, + CatmullClark1RingT& dest0, + CatmullClark1RingT& dest1) + { + assert(p1.face_valence > 2); + dest1.vertex_level = dest0.vertex_level = p0.edge_level; + dest1.face_valence = dest0.face_valence = 4; + dest1.edge_valence = dest0.edge_valence = 8; + dest1.border_index = dest0.border_index = -1; + dest1.vtx = dest0.vtx = 
(Vertex_t)p0.ring[0]; + dest1.vertex_crease_weight = dest0.vertex_crease_weight = 0.0f; + + dest1.ring[2] = dest0.ring[0] = (Vertex_t)p0.ring[1]; + dest1.ring[1] = dest0.ring[7] = (Vertex_t)p1.ring[0]; + dest1.ring[0] = dest0.ring[6] = (Vertex_t)p1.vtx; + dest1.ring[7] = dest0.ring[5] = (Vertex_t)p1.ring[4]; + dest1.ring[6] = dest0.ring[4] = (Vertex_t)p0.ring[p0.edge_valence-1]; + dest1.ring[5] = dest0.ring[3] = (Vertex_t)p0.ring[p0.edge_valence-2]; + dest1.ring[4] = dest0.ring[2] = (Vertex_t)p0.vtx; + dest1.ring[3] = dest0.ring[1] = (Vertex_t)p0.ring[2]; + + dest1.crease_weight[1] = dest0.crease_weight[0] = 0.0f; + dest1.crease_weight[0] = dest0.crease_weight[3] = p1.crease_weight[1]; + dest1.crease_weight[3] = dest0.crease_weight[2] = 0.0f; + dest1.crease_weight[2] = dest0.crease_weight[1] = p0.crease_weight[0]; + + if (p0.eval_unique_identifier <= p1.eval_unique_identifier) + { + dest0.eval_start_index = 3; + dest1.eval_start_index = 0; + dest0.eval_unique_identifier = p0.eval_unique_identifier; + dest1.eval_unique_identifier = p0.eval_unique_identifier; + } + else + { + dest0.eval_start_index = 1; + dest1.eval_start_index = 2; + dest0.eval_unique_identifier = p1.eval_unique_identifier; + dest1.eval_unique_identifier = p1.eval_unique_identifier; + } + } + + static __forceinline void init_border(const CatmullClark1RingT &p0, + const CatmullClark1RingT &p1, + CatmullClark1RingT &dest0, + CatmullClark1RingT &dest1) + { + dest1.vertex_level = dest0.vertex_level = p0.edge_level; + dest1.face_valence = dest0.face_valence = 3; + dest1.edge_valence = dest0.edge_valence = 6; + dest0.border_index = 2; + dest1.border_index = 4; + dest1.vtx = dest0.vtx = (Vertex_t)p0.ring[0]; + dest1.vertex_crease_weight = dest0.vertex_crease_weight = 0.0f; + + dest1.ring[2] = dest0.ring[0] = (Vertex_t)p0.ring[1]; + dest1.ring[1] = dest0.ring[5] = (Vertex_t)p1.ring[0]; + dest1.ring[0] = dest0.ring[4] = (Vertex_t)p1.vtx; + dest1.ring[5] = dest0.ring[3] = (Vertex_t)p0.ring[p0.border_index+1]; // dummy + dest1.ring[4] = dest0.ring[2] = (Vertex_t)p0.vtx; + dest1.ring[3] = dest0.ring[1] = (Vertex_t)p0.ring[2]; + + dest1.crease_weight[1] = dest0.crease_weight[0] = 0.0f; + dest1.crease_weight[0] = dest0.crease_weight[2] = p1.crease_weight[1]; + dest1.crease_weight[2] = dest0.crease_weight[1] = p0.crease_weight[0]; + + if (p0.eval_unique_identifier <= p1.eval_unique_identifier) + { + dest0.eval_start_index = 1; + dest1.eval_start_index = 2; + dest0.eval_unique_identifier = p0.eval_unique_identifier; + dest1.eval_unique_identifier = p0.eval_unique_identifier; + } + else + { + dest0.eval_start_index = 2; + dest1.eval_start_index = 0; + dest0.eval_unique_identifier = p1.eval_unique_identifier; + dest1.eval_unique_identifier = p1.eval_unique_identifier; + } + } + + static __forceinline void init_regular(const Vertex_t ¢er, const Vertex_t center_ring[8], const unsigned int offset, CatmullClark1RingT &dest) + { + dest.vertex_level = 0.0f; + dest.face_valence = 4; + dest.edge_valence = 8; + dest.border_index = -1; + dest.vtx = (Vertex_t)center; + dest.vertex_crease_weight = 0.0f; + for (size_t i=0; i<8; i++) + dest.ring[i] = (Vertex_t)center_ring[(offset+i)%8]; + for (size_t i=0; i<4; i++) + dest.crease_weight[i] = 0.0f; + + dest.eval_start_index = (8-offset)>>1; + if (dest.eval_start_index >= dest.face_valence) dest.eval_start_index -= dest.face_valence; + assert( dest.eval_start_index < dest.face_valence ); + dest.eval_unique_identifier = 0; + } + + __noinline void subdivide(array_t& patch) const + { + 
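+        /* One Catmull-Clark subdivision step of the quad patch: the four corner
+           rings are subdivided in place, the edge levels of the child patches are
+           halved (interior edges use half the average of the two opposite patch
+           edge levels), the rings shared between neighbouring children are built
+           with init_regular() or init_border(), and the four rings around the new
+           face center are filled from center_ring. */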
ring[0].subdivide(patch[0].ring[0]); + ring[1].subdivide(patch[1].ring[1]); + ring[2].subdivide(patch[2].ring[2]); + ring[3].subdivide(patch[3].ring[3]); + + patch[0].ring[0].edge_level = 0.5f*ring[0].edge_level; + patch[0].ring[1].edge_level = 0.25f*(ring[1].edge_level+ring[3].edge_level); + patch[0].ring[2].edge_level = 0.25f*(ring[0].edge_level+ring[2].edge_level); + patch[0].ring[3].edge_level = 0.5f*ring[3].edge_level; + + patch[1].ring[0].edge_level = 0.5f*ring[0].edge_level; + patch[1].ring[1].edge_level = 0.5f*ring[1].edge_level; + patch[1].ring[2].edge_level = 0.25f*(ring[0].edge_level+ring[2].edge_level); + patch[1].ring[3].edge_level = 0.25f*(ring[1].edge_level+ring[3].edge_level); + + patch[2].ring[0].edge_level = 0.25f*(ring[0].edge_level+ring[2].edge_level); + patch[2].ring[1].edge_level = 0.5f*ring[1].edge_level; + patch[2].ring[2].edge_level = 0.5f*ring[2].edge_level; + patch[2].ring[3].edge_level = 0.25f*(ring[1].edge_level+ring[3].edge_level); + + patch[3].ring[0].edge_level = 0.25f*(ring[0].edge_level+ring[2].edge_level); + patch[3].ring[1].edge_level = 0.25f*(ring[1].edge_level+ring[3].edge_level); + patch[3].ring[2].edge_level = 0.5f*ring[2].edge_level; + patch[3].ring[3].edge_level = 0.5f*ring[3].edge_level; + + const bool regular0 = ring[0].has_last_face() && ring[1].face_valence > 2; + if (likely(regular0)) + init_regular(patch[0].ring[0],patch[1].ring[1],patch[0].ring[1],patch[1].ring[0]); + else + init_border(patch[0].ring[0],patch[1].ring[1],patch[0].ring[1],patch[1].ring[0]); + + const bool regular1 = ring[1].has_last_face() && ring[2].face_valence > 2; + if (likely(regular1)) + init_regular(patch[1].ring[1],patch[2].ring[2],patch[1].ring[2],patch[2].ring[1]); + else + init_border(patch[1].ring[1],patch[2].ring[2],patch[1].ring[2],patch[2].ring[1]); + + const bool regular2 = ring[2].has_last_face() && ring[3].face_valence > 2; + if (likely(regular2)) + init_regular(patch[2].ring[2],patch[3].ring[3],patch[2].ring[3],patch[3].ring[2]); + else + init_border(patch[2].ring[2],patch[3].ring[3],patch[2].ring[3],patch[3].ring[2]); + + const bool regular3 = ring[3].has_last_face() && ring[0].face_valence > 2; + if (likely(regular3)) + init_regular(patch[3].ring[3],patch[0].ring[0],patch[3].ring[0],patch[0].ring[3]); + else + init_border(patch[3].ring[3],patch[0].ring[0],patch[3].ring[0],patch[0].ring[3]); + + Vertex_t center = (ring[0].vtx + ring[1].vtx + ring[2].vtx + ring[3].vtx) * 0.25f; + + Vertex_t center_ring[8]; + center_ring[0] = (Vertex_t)patch[3].ring[3].ring[0]; + center_ring[7] = (Vertex_t)patch[3].ring[3].vtx; + center_ring[6] = (Vertex_t)patch[2].ring[2].ring[0]; + center_ring[5] = (Vertex_t)patch[2].ring[2].vtx; + center_ring[4] = (Vertex_t)patch[1].ring[1].ring[0]; + center_ring[3] = (Vertex_t)patch[1].ring[1].vtx; + center_ring[2] = (Vertex_t)patch[0].ring[0].ring[0]; + center_ring[1] = (Vertex_t)patch[0].ring[0].vtx; + + init_regular(center,center_ring,0,patch[0].ring[2]); + init_regular(center,center_ring,2,patch[1].ring[3]); + init_regular(center,center_ring,4,patch[2].ring[0]); + init_regular(center,center_ring,6,patch[3].ring[1]); + + assert(patch[0].verify()); + assert(patch[1].verify()); + assert(patch[2].verify()); + assert(patch[3].verify()); + } + + bool verify() const { + return ring[0].hasValidPositions() && ring[1].hasValidPositions() && ring[2].hasValidPositions() && ring[3].hasValidPositions(); + } + + __forceinline void init( FinalQuad& quad ) const + { + quad.vtx[0] = (Vertex_t)ring[0].vtx; + quad.vtx[1] = (Vertex_t)ring[1].vtx; + 
quad.vtx[2] = (Vertex_t)ring[2].vtx; + quad.vtx[3] = (Vertex_t)ring[3].vtx; + }; + + friend __forceinline embree_ostream operator<<(embree_ostream o, const CatmullClarkPatchT &p) + { + o << "CatmullClarkPatch { " << embree_endl; + for (size_t i=0; i<4; i++) + o << "ring" << i << ": " << p.ring[i] << embree_endl; + o << "}" << embree_endl; + return o; + } + }; + + typedef CatmullClarkPatchT CatmullClarkPatch3fa; + + template + class __aligned(64) GeneralCatmullClarkPatchT + { + public: + typedef CatmullClarkPatchT CatmullClarkPatch; + typedef CatmullClark1RingT CatmullClark1Ring; + typedef BezierCurveT BezierCurve; + + static const unsigned SIZE = MAX_PATCH_VALENCE; + DynamicStackArray,8,SIZE> ring; + unsigned N; + + __forceinline GeneralCatmullClarkPatchT () + : N(0) {} + + GeneralCatmullClarkPatchT (const HalfEdge* h, const char* vertices, size_t stride) { + init(h,vertices,stride); + } + + __forceinline GeneralCatmullClarkPatchT (const HalfEdge* first_half_edge, const BufferView& vertices) { + init(first_half_edge,vertices.getPtr(),vertices.getStride()); + } + + __forceinline void init (const HalfEdge* h, const char* vertices, size_t stride) + { + unsigned int i = 0; + const HalfEdge* edge = h; + do { + ring[i].init(edge,vertices,stride); + edge = edge->next(); + i++; + } while ((edge != h) && (i < SIZE)); + N = i; + } + + __forceinline unsigned size() const { + return N; + } + + __forceinline bool isQuadPatch() const { + return (N == 4) && ring[0].only_quads && ring[1].only_quads && ring[2].only_quads && ring[3].only_quads; + } + + static __forceinline void init_regular(const CatmullClark1RingT& p0, + const CatmullClark1RingT& p1, + CatmullClark1RingT& dest0, + CatmullClark1RingT& dest1) + { + assert(p1.face_valence > 2); + dest1.vertex_level = dest0.vertex_level = p0.edge_level; + dest1.face_valence = dest0.face_valence = 4; + dest1.edge_valence = dest0.edge_valence = 8; + dest1.border_index = dest0.border_index = -1; + dest1.vtx = dest0.vtx = (Vertex_t)p0.ring[0]; + dest1.vertex_crease_weight = dest0.vertex_crease_weight = 0.0f; + + dest1.ring[2] = dest0.ring[0] = (Vertex_t)p0.ring[1]; + dest1.ring[1] = dest0.ring[7] = (Vertex_t)p1.ring[0]; + dest1.ring[0] = dest0.ring[6] = (Vertex_t)p1.vtx; + dest1.ring[7] = dest0.ring[5] = (Vertex_t)p1.ring[4]; + dest1.ring[6] = dest0.ring[4] = (Vertex_t)p0.ring[p0.edge_valence-1]; + dest1.ring[5] = dest0.ring[3] = (Vertex_t)p0.ring[p0.edge_valence-2]; + dest1.ring[4] = dest0.ring[2] = (Vertex_t)p0.vtx; + dest1.ring[3] = dest0.ring[1] = (Vertex_t)p0.ring[2]; + + dest1.crease_weight[1] = dest0.crease_weight[0] = 0.0f; + dest1.crease_weight[0] = dest0.crease_weight[3] = p1.crease_weight[1]; + dest1.crease_weight[3] = dest0.crease_weight[2] = 0.0f; + dest1.crease_weight[2] = dest0.crease_weight[1] = p0.crease_weight[0]; + + if (p0.eval_unique_identifier <= p1.eval_unique_identifier) + { + dest0.eval_start_index = 3; + dest1.eval_start_index = 0; + dest0.eval_unique_identifier = p0.eval_unique_identifier; + dest1.eval_unique_identifier = p0.eval_unique_identifier; + } + else + { + dest0.eval_start_index = 1; + dest1.eval_start_index = 2; + dest0.eval_unique_identifier = p1.eval_unique_identifier; + dest1.eval_unique_identifier = p1.eval_unique_identifier; + } + } + + + static __forceinline void init_border(const CatmullClark1RingT &p0, + const CatmullClark1RingT &p1, + CatmullClark1RingT &dest0, + CatmullClark1RingT &dest1) + { + dest1.vertex_level = dest0.vertex_level = p0.edge_level; + dest1.face_valence = dest0.face_valence = 3; + dest1.edge_valence 
= dest0.edge_valence = 6; + dest0.border_index = 2; + dest1.border_index = 4; + dest1.vtx = dest0.vtx = (Vertex_t)p0.ring[0]; + dest1.vertex_crease_weight = dest0.vertex_crease_weight = 0.0f; + + dest1.ring[2] = dest0.ring[0] = (Vertex_t)p0.ring[1]; + dest1.ring[1] = dest0.ring[5] = (Vertex_t)p1.ring[0]; + dest1.ring[0] = dest0.ring[4] = (Vertex_t)p1.vtx; + dest1.ring[5] = dest0.ring[3] = (Vertex_t)p0.ring[p0.border_index+1]; // dummy + dest1.ring[4] = dest0.ring[2] = (Vertex_t)p0.vtx; + dest1.ring[3] = dest0.ring[1] = (Vertex_t)p0.ring[2]; + + dest1.crease_weight[1] = dest0.crease_weight[0] = 0.0f; + dest1.crease_weight[0] = dest0.crease_weight[2] = p1.crease_weight[1]; + dest1.crease_weight[2] = dest0.crease_weight[1] = p0.crease_weight[0]; + + if (p0.eval_unique_identifier <= p1.eval_unique_identifier) + { + dest0.eval_start_index = 1; + dest1.eval_start_index = 2; + dest0.eval_unique_identifier = p0.eval_unique_identifier; + dest1.eval_unique_identifier = p0.eval_unique_identifier; + } + else + { + dest0.eval_start_index = 2; + dest1.eval_start_index = 0; + dest0.eval_unique_identifier = p1.eval_unique_identifier; + dest1.eval_unique_identifier = p1.eval_unique_identifier; + } + } + + static __forceinline void init_regular(const Vertex_t ¢er, const array_t& center_ring, const float vertex_level, const unsigned int N, const unsigned int offset, CatmullClark1RingT &dest) + { + assert(N<(MAX_RING_FACE_VALENCE)); + assert(2*N<(MAX_RING_EDGE_VALENCE)); + dest.vertex_level = vertex_level; + dest.face_valence = N; + dest.edge_valence = 2*N; + dest.border_index = -1; + dest.vtx = (Vertex_t)center; + dest.vertex_crease_weight = 0.0f; + for (unsigned i=0; i<2*N; i++) { + dest.ring[i] = (Vertex_t)center_ring[(2*N+offset+i-1)%(2*N)]; + assert(isvalid(dest.ring[i])); + } + for (unsigned i=0; i>1; + if (dest.eval_start_index >= dest.face_valence) dest.eval_start_index -= dest.face_valence; + + assert( dest.eval_start_index < dest.face_valence ); + dest.eval_unique_identifier = 0; + } + + __noinline void subdivide(array_t& patch, unsigned& N_o) const + { + N_o = N; + assert( N ); + for (unsigned i=0; i center_ring; + float center_vertex_level = 2.0f; // guarantees that irregular vertices get always isolated also for non-quads + + for (unsigned i=0; i 2; + if (likely(regular)) init_regular(patch[i].ring[0],patch[ip1].ring[0],patch[i].ring[1],patch[ip1].ring[3]); + else init_border (patch[i].ring[0],patch[ip1].ring[0],patch[i].ring[1],patch[ip1].ring[3]); + + assert( patch[i].ring[1].hasValidPositions() ); + assert( patch[ip1].ring[3].hasValidPositions() ); + + float level = 0.25f*(ring[im1].edge_level+ring[ip1].edge_level); + patch[i].ring[1].edge_level = patch[ip1].ring[2].edge_level = level; + center_vertex_level = max(center_vertex_level,level); + + center += ring[i].vtx; + center_ring[2*i+0] = (Vertex_t)patch[i].ring[0].vtx; + center_ring[2*i+1] = (Vertex_t)patch[i].ring[0].ring[0]; + } + center /= float(N); + + for (unsigned int i=0; i& patches) + { + CatmullClark1Ring patches1ring1 = patches[1].ring[1]; + patches[1].ring[1] = patches[1].ring[0]; // FIXME: optimize these assignments + patches[1].ring[0] = patches[1].ring[3]; + patches[1].ring[3] = patches[1].ring[2]; + patches[1].ring[2] = patches1ring1; + + CatmullClark1Ring patches2ring2 = patches[2].ring[2]; + patches[2].ring[2] = patches[2].ring[0]; + patches[2].ring[0] = patches2ring2; + CatmullClark1Ring patches2ring3 = patches[2].ring[3]; + patches[2].ring[3] = patches[2].ring[1]; + patches[2].ring[1] = patches2ring3; + + CatmullClark1Ring 
patches3ring3 = patches[3].ring[3]; + patches[3].ring[3] = patches[3].ring[0]; + patches[3].ring[0] = patches[3].ring[1]; + patches[3].ring[1] = patches[3].ring[2]; + patches[3].ring[2] = patches3ring3; + } + + __forceinline void getLimitBorder(BezierCurve curves[GeneralCatmullClarkPatchT::SIZE]) const + { + Vertex P0 = ring[0].getLimitVertex(); + for (unsigned i=0; i GeneralCatmullClarkPatch3fa; +} diff --git a/thirdparty/embree/kernels/subdiv/catmullclark_ring.h b/thirdparty/embree/kernels/subdiv/catmullclark_ring.h new file mode 100644 index 000000000000..73b41fd4ff46 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/catmullclark_ring.h @@ -0,0 +1,826 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/geometry.h" +#include "../common/buffer.h" +#include "half_edge.h" +#include "catmullclark_coefficients.h" + +namespace embree +{ + struct __aligned(64) FinalQuad { + Vec3fa vtx[4]; + }; + + template + struct __aligned(64) CatmullClark1RingT + { + ALIGNED_STRUCT_(64); + + int border_index; //!< edge index where border starts + unsigned int face_valence; //!< number of adjacent quad faces + unsigned int edge_valence; //!< number of adjacent edges (2*face_valence) + float vertex_crease_weight; //!< weight of vertex crease (0 if no vertex crease) + DynamicStackArray crease_weight; //!< edge crease weights for each adjacent edge + float vertex_level; //!< maximum level of all adjacent edges + float edge_level; //!< level of first edge + unsigned int eval_start_index; //!< topology dependent index to start evaluation + unsigned int eval_unique_identifier; //!< topology dependent unique identifier for this ring + Vertex vtx; //!< center vertex + DynamicStackArray ring; //!< ring of neighboring vertices + + public: + CatmullClark1RingT () + : eval_start_index(0), eval_unique_identifier(0) {} // FIXME: default constructor should be empty + + /*! calculates number of bytes required to serialize this structure */ + __forceinline size_t bytes() const + { + size_t ofs = 0; + ofs += sizeof(border_index); + ofs += sizeof(face_valence); + assert(2*face_valence == edge_valence); + ofs += sizeof(vertex_crease_weight); + ofs += face_valence*sizeof(float); + ofs += sizeof(vertex_level); + ofs += sizeof(edge_level); + ofs += sizeof(eval_start_index); + ofs += sizeof(eval_unique_identifier); + ofs += sizeof(vtx); + ofs += edge_valence*sizeof(Vertex); + return ofs; + } + + template + static __forceinline void store(char* ptr, size_t& ofs, const Ty& v) { + *(Ty*)&ptr[ofs] = v; ofs += sizeof(Ty); + } + + template + static __forceinline void load(char* ptr, size_t& ofs, Ty& v) { + v = *(Ty*)&ptr[ofs]; ofs += sizeof(Ty); + } + + /*! 
serializes the ring to some memory location */ + __forceinline void serialize(char* ptr, size_t& ofs) const + { + store(ptr,ofs,border_index); + store(ptr,ofs,face_valence); + store(ptr,ofs,vertex_crease_weight); + for (size_t i=0; ii); + return ring[i]; + } + + __forceinline const Vertex& back(size_t i) const { + assert(edge_valence>=i); + return ring[edge_valence-i]; + } + + __forceinline bool has_last_face() const { + return (size_t)border_index != (size_t)edge_valence-2; + } + + __forceinline bool has_opposite_front(size_t i) const { + return (size_t)border_index != 2*i; + } + + __forceinline bool has_opposite_back(size_t i) const { + return (size_t)border_index != ((size_t)edge_valence-2-2*i); + } + + __forceinline BBox3fa bounds() const + { + BBox3fa bounds ( vtx ); + for (size_t i = 0; igetStartVertexIndex()*stride); + vertex_crease_weight = h->vertex_crease_weight; + + HalfEdge* p = (HalfEdge*) h; + + unsigned i=0; + unsigned min_vertex_index = (unsigned)-1; + unsigned min_vertex_index_face = (unsigned)-1; + edge_level = p->edge_level; + vertex_level = 0.0f; + + do + { + vertex_level = max(vertex_level,p->edge_level); + crease_weight[i/2] = p->edge_crease_weight; + assert(p->hasOpposite() || p->edge_crease_weight == float(inf)); + + /* store first two vertices of face */ + p = p->next(); + const unsigned index0 = p->getStartVertexIndex(); + ring[i++] = Vertex_t::loadu(vertices+index0*stride); + if (index0 < min_vertex_index) { min_vertex_index = index0; min_vertex_index_face = i>>1; } + p = p->next(); + + const unsigned index1 = p->getStartVertexIndex(); + ring[i++] = Vertex_t::loadu(vertices+index1*stride); + p = p->next(); + + /* continue with next face */ + if (likely(p->hasOpposite())) + p = p->opposite(); + + /* if there is no opposite go the long way to the other side of the border */ + else + { + /* find minimum start vertex */ + const unsigned index0 = p->getStartVertexIndex(); + if (index0 < min_vertex_index) { min_vertex_index = index0; min_vertex_index_face = i>>1; } + + /*! mark first border edge and store dummy vertex for face between the two border edges */ + border_index = i; + crease_weight[i/2] = inf; + ring[i++] = Vertex_t::loadu(vertices+index0*stride); + ring[i++] = vtx; // dummy vertex + + /*! 
goto other side of border */ + p = (HalfEdge*) h; + while (p->hasOpposite()) + p = p->opposite()->next(); + } + + } while (p != h); + + edge_valence = i; + face_valence = i >> 1; + eval_unique_identifier = min_vertex_index; + eval_start_index = min_vertex_index_face; + + assert( hasValidPositions() ); + } + + __forceinline void subdivide(CatmullClark1RingT& dest) const + { + dest.edge_level = 0.5f*edge_level; + dest.vertex_level = 0.5f*vertex_level; + dest.face_valence = face_valence; + dest.edge_valence = edge_valence; + dest.border_index = border_index; + dest.vertex_crease_weight = max(0.0f,vertex_crease_weight-1.0f); + dest.eval_start_index = eval_start_index; + dest.eval_unique_identifier = eval_unique_identifier; + + /* calculate face points */ + Vertex_t S = Vertex_t(0.0f); + for (size_t i=0; i= face_valence) face_index -= face_valence; assert(face_index < face_valence); + size_t index0 = 2*face_index+0; if (index0 >= edge_valence) index0 -= edge_valence; assert(index0 < edge_valence); + size_t index1 = 2*face_index+1; if (index1 >= edge_valence) index1 -= edge_valence; assert(index1 < edge_valence); + size_t index2 = 2*face_index+2; if (index2 >= edge_valence) index2 -= edge_valence; assert(index2 < edge_valence); + S += dest.ring[index1] = ((vtx + ring[index1]) + (ring[index0] + ring[index2])) * 0.25f; + } + + /* calculate new edge points */ + size_t num_creases = 0; + array_t crease_id; + + for (size_t i=0; i= face_valence) face_index -= face_valence; + const float edge_crease = crease_weight[face_index]; + dest.crease_weight[face_index] = max(edge_crease-1.0f,0.0f); + + size_t index = 2*face_index; + size_t prev_index = face_index == 0 ? edge_valence-1 : 2*face_index-1; + size_t next_index = 2*face_index+1; + + const Vertex_t v = vtx + ring[index]; + const Vertex_t f = dest.ring[prev_index] + dest.ring[next_index]; + S += ring[index]; + + /* fast path for regular edge points */ + if (likely(edge_crease <= 0.0f)) { + dest.ring[index] = (v+f) * 0.25f; + } + + /* slower path for hard edge rule */ + else { + crease_id[num_creases++] = face_index; + dest.ring[index] = v*0.5f; + + /* even slower path for blended edge rule */ + if (unlikely(edge_crease < 1.0f)) { + dest.ring[index] = lerp((v+f)*0.25f,v*0.5f,edge_crease); + } + } + } + + /* compute new vertex using smooth rule */ + const float inv_face_valence = 1.0f / (float)face_valence; + const Vertex_t v_smooth = (Vertex_t) madd(inv_face_valence,S,(float(face_valence)-2.0f)*vtx)*inv_face_valence; + dest.vtx = v_smooth; + + /* compute new vertex using vertex_crease_weight rule */ + if (unlikely(vertex_crease_weight > 0.0f)) + { + if (vertex_crease_weight >= 1.0f) { + dest.vtx = vtx; + } else { + dest.vtx = lerp(v_smooth,vtx,vertex_crease_weight); + } + return; + } + + /* no edge crease rule and dart rule */ + if (likely(num_creases <= 1)) + return; + + /* compute new vertex using crease rule */ + if (likely(num_creases == 2)) + { + /* update vertex using crease rule */ + const size_t crease0 = crease_id[0], crease1 = crease_id[1]; + const Vertex_t v_sharp = (Vertex_t)(ring[2*crease0] + 6.0f*vtx + ring[2*crease1]) * (1.0f / 8.0f); + dest.vtx = v_sharp; + + /* update crease_weights using chaikin rule */ + const float crease_weight0 = crease_weight[crease0], crease_weight1 = crease_weight[crease1]; + dest.crease_weight[crease0] = max(0.25f*(3.0f*crease_weight0 + crease_weight1)-1.0f,0.0f); + dest.crease_weight[crease1] = max(0.25f*(3.0f*crease_weight1 + crease_weight0)-1.0f,0.0f); + + /* interpolate between sharp and smooth rule */ + 
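+        /* v_blend is the average of the two incident edge crease weights: if it is
+           >= 1 the sharp (crease-rule) vertex already stored in dest.vtx stands;
+           otherwise the vertex is lerped from the smooth-rule position towards the
+           sharp-rule position, so weak creases fade out over subdivision levels. */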
const float v_blend = 0.5f*(crease_weight0+crease_weight1); + if (unlikely(v_blend < 1.0f)) { + dest.vtx = lerp(v_smooth,v_sharp,v_blend); + } + } + + /* compute new vertex using corner rule */ + else { + dest.vtx = vtx; + } + } + + __forceinline bool isRegular1() const + { + if (border_index == -1) { + if (face_valence == 4) return true; + } else { + if (face_valence < 4) return true; + } + return false; + } + + __forceinline size_t numEdgeCreases() const + { + ssize_t numCreases = 0; + for (size_t i=0; i 0.0f; + } + return numCreases; + } + + enum Type { + TYPE_NONE = 0, //!< invalid type + TYPE_REGULAR = 1, //!< regular patch when ignoring creases + TYPE_REGULAR_CREASES = 2, //!< regular patch when considering creases + TYPE_GREGORY = 4, //!< gregory patch when ignoring creases + TYPE_GREGORY_CREASES = 8, //!< gregory patch when considering creases + TYPE_CREASES = 16 //!< patch has crease features + }; + + __forceinline Type type() const + { + /* check if there is an edge crease anywhere */ + const size_t numCreases = numEdgeCreases(); + const bool noInnerCreases = hasBorder() ? numCreases == 2 : numCreases == 0; + + Type crease_mask = (Type) (TYPE_REGULAR | TYPE_GREGORY); + if (noInnerCreases ) crease_mask = (Type) (crease_mask | TYPE_REGULAR_CREASES | TYPE_GREGORY_CREASES); + if (numCreases != 0) crease_mask = (Type) (crease_mask | TYPE_CREASES); + + /* calculate if this vertex is regular */ + bool hasBorder = border_index != -1; + if (face_valence == 2 && hasBorder) { + if (vertex_crease_weight == 0.0f ) return (Type) (crease_mask & (TYPE_REGULAR | TYPE_REGULAR_CREASES | TYPE_GREGORY | TYPE_GREGORY_CREASES | TYPE_CREASES)); + else if (vertex_crease_weight == float(inf)) return (Type) (crease_mask & (TYPE_REGULAR | TYPE_REGULAR_CREASES | TYPE_GREGORY | TYPE_GREGORY_CREASES | TYPE_CREASES)); + else return TYPE_CREASES; + } + else if (vertex_crease_weight != 0.0f) return TYPE_CREASES; + else if (face_valence == 3 && hasBorder) return (Type) (crease_mask & (TYPE_REGULAR | TYPE_REGULAR_CREASES | TYPE_GREGORY | TYPE_GREGORY_CREASES | TYPE_CREASES)); + else if (face_valence == 4 && !hasBorder) return (Type) (crease_mask & (TYPE_REGULAR | TYPE_REGULAR_CREASES | TYPE_GREGORY | TYPE_GREGORY_CREASES | TYPE_CREASES)); + else return (Type) (crease_mask & (TYPE_GREGORY | TYPE_GREGORY_CREASES | TYPE_CREASES)); + } + + __forceinline bool isFinalResolution(float res) const { + return vertex_level <= res; + } + + /* computes the limit vertex */ + __forceinline Vertex getLimitVertex() const + { + /* return hard corner */ + if (unlikely(std::isinf(vertex_crease_weight))) + return vtx; + + /* border vertex rule */ + if (unlikely(border_index != -1)) + { + const unsigned int second_border_index = border_index+2 >= int(edge_valence) ? 
0 : border_index+2; + return (4.0f * vtx + (ring[border_index] + ring[second_border_index])) * 1.0f/6.0f; + } + + Vertex_t F( 0.0f ); + Vertex_t E( 0.0f ); + + assert(eval_start_index < face_valence); + + for (size_t i=0; i= face_valence) index -= face_valence; + F += ring[2*index+1]; + E += ring[2*index]; + } + + const float n = (float)face_valence; + return (Vertex_t)(n*n*vtx+4.0f*E+F) / ((n+5.0f)*n); + } + + /* gets limit tangent in the direction of egde vtx -> ring[0] */ + __forceinline Vertex getLimitTangent() const + { + if (unlikely(std::isinf(vertex_crease_weight))) + return ring[0] - vtx; + + /* border vertex rule */ + if (unlikely(border_index != -1)) + { + if (border_index != (int)edge_valence-2 ) { + return ring[0] - vtx; + } + else + { + const unsigned int second_border_index = border_index+2 >= int(edge_valence) ? 0 : border_index+2; + return (ring[second_border_index] - ring[border_index]) * 0.5f; + } + } + + Vertex_t alpha( 0.0f ); + Vertex_t beta ( 0.0f ); + + const size_t n = face_valence; + + assert(eval_start_index < face_valence); + + Vertex_t q( 0.0f ); + for (size_t i=0; i= face_valence) index -= face_valence; + const float a = CatmullClarkPrecomputedCoefficients::table.limittangent_a(index,n); + const float b = CatmullClarkPrecomputedCoefficients::table.limittangent_b(index,n); + alpha += a * ring[2*index]; + beta += b * ring[2*index+1]; + } + + const float sigma = CatmullClarkPrecomputedCoefficients::table.limittangent_c(n); + return sigma * (alpha + beta); + } + + /* gets limit tangent in the direction of egde vtx -> ring[edge_valence-2] */ + __forceinline Vertex getSecondLimitTangent() const + { + if (unlikely(std::isinf(vertex_crease_weight))) + return ring[2] - vtx; + + /* border vertex rule */ + if (unlikely(border_index != -1)) + { + if (border_index != 2) { + return ring[2] - vtx; + } + else { + const unsigned int second_border_index = border_index+2 >= int(edge_valence) ? 0 : border_index+2; + return (ring[border_index] - ring[second_border_index]) * 0.5f; + } + } + + Vertex_t alpha( 0.0f ); + Vertex_t beta ( 0.0f ); + + const size_t n = face_valence; + + assert(eval_start_index < face_valence); + + for (size_t i=0; i= face_valence) index -= face_valence; + + size_t prev_index = index == 0 ? face_valence-1 : index-1; // need to be bit-wise exact in cosf eval + const float a = CatmullClarkPrecomputedCoefficients::table.limittangent_a(prev_index,n); + const float b = CatmullClarkPrecomputedCoefficients::table.limittangent_b(prev_index,n); + alpha += a * ring[2*index]; + beta += b * ring[2*index+1]; + } + + const float sigma = CatmullClarkPrecomputedCoefficients::table.limittangent_c(n); + return sigma* (alpha + beta); + } + + /* gets surface normal */ + const Vertex getNormal() const { + return cross(getLimitTangent(),getSecondLimitTangent()); + } + + /* returns center of the n-th quad in the 1-ring */ + __forceinline Vertex getQuadCenter(const size_t index) const + { + const Vertex_t &p0 = vtx; + const Vertex_t &p1 = ring[2*index+0]; + const Vertex_t &p2 = ring[2*index+1]; + const Vertex_t &p3 = index == face_valence-1 ? 
ring[0] : ring[2*index+2]; + const Vertex p = (p0+p1+p2+p3) * 0.25f; + return p; + } + + /* returns center of the n-th edge in the 1-ring */ + __forceinline Vertex getEdgeCenter(const size_t index) const { + return (vtx + ring[index*2]) * 0.5f; + } + + bool hasValidPositions() const + { + for (size_t i=0; i " << c.ring[i]; + if (i % 2 == 0) o << " crease = " << c.crease_weight[i/2]; + o << embree_endl; + } + return o; + } + }; + + typedef CatmullClark1RingT CatmullClark1Ring3fa; + + template + struct __aligned(64) GeneralCatmullClark1RingT + { + ALIGNED_STRUCT_(64); + + typedef CatmullClark1RingT CatmullClark1Ring; + + struct Face + { + __forceinline Face() {} + __forceinline Face (int size, float crease_weight) + : size(size), crease_weight(crease_weight) {} + + // FIXME: add member that returns total number of vertices + + int size; // number of vertices-2 of nth face in ring + float crease_weight; + }; + + Vertex vtx; + DynamicStackArray ring; + DynamicStackArray faces; + unsigned int face_valence; + unsigned int edge_valence; + int border_face; + float vertex_crease_weight; + float vertex_level; //!< maximum level of adjacent edges + float edge_level; // level of first edge + bool only_quads; // true if all faces are quads + unsigned int eval_start_face_index; + unsigned int eval_start_vertex_index; + unsigned int eval_unique_identifier; + + public: + GeneralCatmullClark1RingT() + : eval_start_face_index(0), eval_start_vertex_index(0), eval_unique_identifier(0) {} + + __forceinline bool isRegular() const + { + if (border_face == -1 && face_valence == 4) return true; + return false; + } + + __forceinline bool has_last_face() const { + return border_face != (int)face_valence-1; + } + + __forceinline bool has_second_face() const { + return (border_face == -1) || (border_face >= 2); + } + + bool hasValidPositions() const + { + for (size_t i=0; igetStartVertexIndex()*stride); + vertex_crease_weight = h->vertex_crease_weight; + HalfEdge* p = (HalfEdge*) h; + + unsigned int e=0, f=0; + unsigned min_vertex_index = (unsigned)-1; + unsigned min_vertex_index_face = (unsigned)-1; + unsigned min_vertex_index_vertex = (unsigned)-1; + edge_level = p->edge_level; + vertex_level = 0.0f; + do + { + HalfEdge* p_prev = p->prev(); + HalfEdge* p_next = p->next(); + const float crease_weight = p->edge_crease_weight; + assert(p->hasOpposite() || p->edge_crease_weight == float(inf)); + vertex_level = max(vertex_level,p->edge_level); + + /* find minimum start vertex */ + unsigned vertex_index = p_next->getStartVertexIndex(); + if (vertex_index < min_vertex_index) { min_vertex_index = vertex_index; min_vertex_index_face = f; min_vertex_index_vertex = e; } + + /* store first N-2 vertices of face */ + unsigned int vn = 0; + for (p = p_next; p!=p_prev; p=p->next()) { + ring[e++] = Vertex_t::loadu(vertices+p->getStartVertexIndex()*stride); + vn++; + } + faces[f++] = Face(vn,crease_weight); + only_quads &= (vn == 2); + + /* continue with next face */ + if (likely(p->hasOpposite())) + p = p->opposite(); + + /* if there is no opposite go the long way to the other side of the border */ + else + { + /* find minimum start vertex */ + unsigned vertex_index = p->getStartVertexIndex(); + if (vertex_index < min_vertex_index) { min_vertex_index = vertex_index; min_vertex_index_face = f; min_vertex_index_vertex = e; } + + /*! 
mark first border edge and store dummy vertex for face between the two border edges */ + border_face = f; + faces[f++] = Face(2,inf); + ring[e++] = Vertex_t::loadu(vertices+p->getStartVertexIndex()*stride); + ring[e++] = vtx; // dummy vertex + + /*! goto other side of border */ + p = (HalfEdge*) h; + while (p->hasOpposite()) + p = p->opposite()->next(); + } + + } while (p != h); + + edge_valence = e; + face_valence = f; + eval_unique_identifier = min_vertex_index; + eval_start_face_index = min_vertex_index_face; + eval_start_vertex_index = min_vertex_index_vertex; + + assert( hasValidPositions() ); + } + + __forceinline void subdivide(CatmullClark1Ring& dest) const + { + dest.edge_level = 0.5f*edge_level; + dest.vertex_level = 0.5f*vertex_level; + dest.face_valence = face_valence; + dest.edge_valence = 2*face_valence; + dest.border_index = border_face == -1 ? -1 : 2*border_face; // FIXME: + dest.vertex_crease_weight = max(0.0f,vertex_crease_weight-1.0f); + dest.eval_start_index = eval_start_face_index; + dest.eval_unique_identifier = eval_unique_identifier; + assert(dest.face_valence <= MAX_RING_FACE_VALENCE); + + /* calculate face points */ + Vertex_t S = Vertex_t(0.0f); + for (size_t face=0, v=eval_start_vertex_index; face crease_id; + Vertex_t C = Vertex_t(0.0f); + for (size_t face=0, j=eval_start_vertex_index; face 0.0f)) + { + if (vertex_crease_weight >= 1.0f) { + dest.vtx = vtx; + } else { + dest.vtx = lerp(vtx,v_smooth,vertex_crease_weight); + } + return; + } + + if (likely(num_creases <= 1)) + return; + + /* compute new vertex using crease rule */ + if (likely(num_creases == 2)) { + const Vertex_t v_sharp = (Vertex_t)(C + 6.0f * vtx) * (1.0f / 8.0f); + const float crease_weight0 = faces[crease_id[0]].crease_weight; + const float crease_weight1 = faces[crease_id[1]].crease_weight; + dest.vtx = v_sharp; + dest.crease_weight[crease_id[0]] = max(0.25f*(3.0f*crease_weight0 + crease_weight1)-1.0f,0.0f); + dest.crease_weight[crease_id[1]] = max(0.25f*(3.0f*crease_weight1 + crease_weight0)-1.0f,0.0f); + const float v_blend = 0.5f*(crease_weight0+crease_weight1); + if (unlikely(v_blend < 1.0f)) { + dest.vtx = lerp(v_sharp,v_smooth,v_blend); + } + } + + /* compute new vertex using corner rule */ + else { + dest.vtx = vtx; + } + } + + void convert(CatmullClark1Ring& dst) const + { + dst.edge_level = edge_level; + dst.vertex_level = vertex_level; + dst.vtx = vtx; + dst.face_valence = face_valence; + dst.edge_valence = 2*face_valence; + dst.border_index = border_face == -1 ? 
-1 : 2*border_face; + for (size_t i=0; i ring[0] */ + __forceinline Vertex getLimitTangent() const + { + CatmullClark1Ring cc_vtx; + + /* fast path for quad only rings */ + if (only_quads) + { + convert(cc_vtx); + return cc_vtx.getLimitTangent(); + } + + subdivide(cc_vtx); + return 2.0f * cc_vtx.getLimitTangent(); + } + + /* gets limit tangent in the direction of egde vtx -> ring[edge_valence-2] */ + __forceinline Vertex getSecondLimitTangent() const + { + CatmullClark1Ring cc_vtx; + + /* fast path for quad only rings */ + if (only_quads) + { + convert(cc_vtx); + return cc_vtx.getSecondLimitTangent(); + } + + subdivide(cc_vtx); + return 2.0f * cc_vtx.getSecondLimitTangent(); + } + + + /* gets limit vertex */ + __forceinline Vertex getLimitVertex() const + { + CatmullClark1Ring cc_vtx; + + /* fast path for quad only rings */ + if (only_quads) + convert(cc_vtx); + else + subdivide(cc_vtx); + return cc_vtx.getLimitVertex(); + } + + friend __forceinline embree_ostream operator<<(embree_ostream o, const GeneralCatmullClark1RingT &c) + { + o << "vtx " << c.vtx << " size = " << c.edge_valence << ", border_face = " << c.border_face << ", " << " face_valence = " << c.face_valence << + ", edge_level = " << c.edge_level << ", vertex_level = " << c.vertex_level << ", ring: " << embree_endl; + for (size_t v=0, f=0; f " << c.ring[i]; + if (i == v) o << " crease = " << c.faces[f].crease_weight; + o << embree_endl; + } + } + return o; + } + }; +} diff --git a/thirdparty/embree/kernels/subdiv/catmullrom_curve.h b/thirdparty/embree/kernels/subdiv/catmullrom_curve.h new file mode 100644 index 000000000000..b244af481cb1 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/catmullrom_curve.h @@ -0,0 +1,296 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "../common/scene_curves.h" + +/* + + Implements Catmul Rom curves with control points p0, p1, p2, p3. At + t=0 the curve goes through p1, with tangent (p2-p0)/3, and for t=1 + the curve goes through p2 with tangent (p3-p2)/2. 
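+
+   Expanded, the uniform basis used below is
+
+     C(t) = 0.5 * [ (-t^3 + 2t^2 - t) * p0
+                  + ( 3t^3 - 5t^2 + 2) * p1
+                  + (-3t^3 + 4t^2 + t) * p2
+                  + (  t^3 -   t^2  ) * p3 ],
+
+   which interpolates p1 at t=0 and p2 at t=1; differentiating this basis
+   (CatmullRomBasis::derivative below) gives the end tangents (p2-p0)/2 at t=0
+   and (p3-p1)/2 at t=1.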
+ + */ + +namespace embree +{ + class CatmullRomBasis + { + public: + + template + static __forceinline Vec4 eval(const T& u) + { + const T t = u; + const T s = T(1.0f) - u; + const T n0 = - t * s * s; + const T n1 = 2.0f + t * t * (3.0f * t - 5.0f); + const T n2 = 2.0f + s * s * (3.0f * s - 5.0f); + const T n3 = - s * t * t; + return T(0.5f) * Vec4(n0, n1, n2, n3); + } + + template + static __forceinline Vec4 derivative(const T& u) + { + const T t = u; + const T s = 1.0f - u; + const T n0 = - s * s + 2.0f * s * t; + const T n1 = 2.0f * t * (3.0f * t - 5.0f) + 3.0f * t * t; + const T n2 = 2.0f * s * (3.0f * t + 2.0f) - 3.0f * s * s; + const T n3 = -2.0f * s * t + t * t; + return T(0.5f) * Vec4(n0, n1, n2, n3); + } + + template + static __forceinline Vec4 derivative2(const T& u) + { + const T t = u; + const T n0 = -3.0f * t + 2.0f; + const T n1 = 9.0f * t - 5.0f; + const T n2 = -9.0f * t + 4.0f; + const T n3 = 3.0f * t - 1.0f; + return Vec4(n0, n1, n2, n3); + } + }; + + struct PrecomputedCatmullRomBasis + { + enum { N = 16 }; + public: + PrecomputedCatmullRomBasis() {} + PrecomputedCatmullRomBasis(int shift); + + /* basis for bspline evaluation */ + public: + float c0[N+1][N+1]; + float c1[N+1][N+1]; + float c2[N+1][N+1]; + float c3[N+1][N+1]; + + /* basis for bspline derivative evaluation */ + public: + float d0[N+1][N+1]; + float d1[N+1][N+1]; + float d2[N+1][N+1]; + float d3[N+1][N+1]; + }; + extern PrecomputedCatmullRomBasis catmullrom_basis0; + extern PrecomputedCatmullRomBasis catmullrom_basis1; + + template + struct CatmullRomCurveT + { + Vertex v0,v1,v2,v3; + + __forceinline CatmullRomCurveT() {} + + __forceinline CatmullRomCurveT(const Vertex& v0, const Vertex& v1, const Vertex& v2, const Vertex& v3) + : v0(v0), v1(v1), v2(v2), v3(v3) {} + + __forceinline Vertex begin() const { + return madd(1.0f/6.0f,v0,madd(2.0f/3.0f,v1,1.0f/6.0f*v2)); + } + + __forceinline Vertex end() const { + return madd(1.0f/6.0f,v1,madd(2.0f/3.0f,v2,1.0f/6.0f*v3)); + } + + __forceinline Vertex center() const { + return 0.25f*(v0+v1+v2+v3); + } + + __forceinline BBox bounds() const { + return merge(BBox(v0),BBox(v1),BBox(v2),BBox(v3)); + } + + __forceinline friend CatmullRomCurveT operator -( const CatmullRomCurveT& a, const Vertex& b ) { + return CatmullRomCurveT(a.v0-b,a.v1-b,a.v2-b,a.v3-b); + } + + __forceinline CatmullRomCurveT xfm_pr(const LinearSpace3fa& space, const Vec3fa& p) const + { + const Vec3ff q0(xfmVector(space,v0-p), v0.w); + const Vec3ff q1(xfmVector(space,v1-p), v1.w); + const Vec3ff q2(xfmVector(space,v2-p), v2.w); + const Vec3ff q3(xfmVector(space,v3-p), v3.w); + return CatmullRomCurveT(q0,q1,q2,q3); + } + + __forceinline Vertex eval(const float t) const + { + const Vec4 b = CatmullRomBasis::eval(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline Vertex eval_du(const float t) const + { + const Vec4 b = CatmullRomBasis::derivative(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline Vertex eval_dudu(const float t) const + { + const Vec4 b = CatmullRomBasis::derivative2(t); + return madd(b.x,v0,madd(b.y,v1,madd(b.z,v2,b.w*v3))); + } + + __forceinline void eval(const float t, Vertex& p, Vertex& dp, Vertex& ddp) const + { + p = eval(t); + dp = eval_du(t); + ddp = eval_dudu(t); + } + + template + __forceinline Vec4vf veval(const vfloat& t) const + { + const Vec4vf b = CatmullRomBasis::eval(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf 
veval_du(const vfloat& t) const + { + const Vec4vf b = CatmullRomBasis::derivative(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf veval_dudu(const vfloat& t) const + { + const Vec4vf b = CatmullRomBasis::derivative2(t); + return madd(b.x, Vec4vf(v0), madd(b.y, Vec4vf(v1), madd(b.z, Vec4vf(v2), b.w * Vec4vf(v3)))); + } + + template + __forceinline void veval(const vfloat& t, Vec4vf& p, Vec4vf& dp) const + { + p = veval(t); + dp = veval_du(t); + } + + template + __forceinline Vec4vf eval0(const int ofs, const int size) const + { + assert(size <= PrecomputedCatmullRomBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&catmullrom_basis0.c0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&catmullrom_basis0.c1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&catmullrom_basis0.c2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&catmullrom_basis0.c3[size][ofs]) * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf eval1(const int ofs, const int size) const + { + assert(size <= PrecomputedCatmullRomBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&catmullrom_basis1.c0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&catmullrom_basis1.c1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&catmullrom_basis1.c2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&catmullrom_basis1.c3[size][ofs]) * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf derivative0(const int ofs, const int size) const + { + assert(size <= PrecomputedCatmullRomBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&catmullrom_basis0.d0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&catmullrom_basis0.d1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&catmullrom_basis0.d2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&catmullrom_basis0.d3[size][ofs]) * Vec4vf(v3)))); + } + + template + __forceinline Vec4vf derivative1(const int ofs, const int size) const + { + assert(size <= PrecomputedCatmullRomBasis::N); + assert(ofs <= size); + return madd(vfloat::loadu(&catmullrom_basis1.d0[size][ofs]), Vec4vf(v0), + madd(vfloat::loadu(&catmullrom_basis1.d1[size][ofs]), Vec4vf(v1), + madd(vfloat::loadu(&catmullrom_basis1.d2[size][ofs]), Vec4vf(v2), + vfloat::loadu(&catmullrom_basis1.d3[size][ofs]) * Vec4vf(v3)))); + } + + /* calculates bounds of catmull-rom curve geometry */ + __forceinline BBox3fa accurateRoundBounds() const + { + const int N = 7; + const float scale = 1.0f/(3.0f*(N-1)); + Vec4vfx pl(pos_inf), pu(neg_inf); + for (int i=0; i<=N; i+=VSIZEX) + { + vintx vi = vintx(i)+vintx(step); + vboolx valid = vi <= vintx(N); + const Vec4vfx p = eval0(i,N); + const Vec4vfx dp = derivative0(i,N); + const Vec4vfx pm = p-Vec4vfx(scale)*select(vi!=vintx(0),dp,Vec4vfx(zero)); + const Vec4vfx pp = p+Vec4vfx(scale)*select(vi!=vintx(N),dp,Vec4vfx(zero)); + pl = select(valid,min(pl,p,pm,pp),pl); // FIXME: use masked min + pu = select(valid,max(pu,p,pm,pp),pu); // FIXME: use masked min + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + const float r_min = reduce_min(pl.w); + const float r_max = reduce_max(pu.w); + const Vec3fa upper_r = Vec3fa(max(abs(r_min),abs(r_max))); + return enlarge(BBox3fa(lower,upper),upper_r); + } + + /* calculates bounds when tessellated into N line segments */ + __forceinline BBox3fa accurateFlatBounds(int N) const + { + if (likely(N == 4)) + { + const Vec4vf4 pi = eval0<4>(0,4); + const Vec3fa 
lower(reduce_min(pi.x),reduce_min(pi.y),reduce_min(pi.z)); + const Vec3fa upper(reduce_max(pi.x),reduce_max(pi.y),reduce_max(pi.z)); + const Vec3fa upper_r = Vec3fa(reduce_max(abs(pi.w))); + const Vec3ff pe = end(); + return enlarge(BBox3fa(min(lower,pe),max(upper,pe)),max(upper_r,Vec3fa(abs(pe.w)))); + } + else + { + Vec3vfx pl(pos_inf), pu(neg_inf); vfloatx ru(0.0f); + for (int i=0; i<=N; i+=VSIZEX) + { + vboolx valid = vintx(i)+vintx(step) <= vintx(N); + const Vec4vfx pi = eval0(i,N); + + pl.x = select(valid,min(pl.x,pi.x),pl.x); // FIXME: use masked min + pl.y = select(valid,min(pl.y,pi.y),pl.y); + pl.z = select(valid,min(pl.z,pi.z),pl.z); + + pu.x = select(valid,max(pu.x,pi.x),pu.x); // FIXME: use masked min + pu.y = select(valid,max(pu.y,pi.y),pu.y); + pu.z = select(valid,max(pu.z,pi.z),pu.z); + + ru = select(valid,max(ru,abs(pi.w)),ru); + } + const Vec3fa lower(reduce_min(pl.x),reduce_min(pl.y),reduce_min(pl.z)); + const Vec3fa upper(reduce_max(pu.x),reduce_max(pu.y),reduce_max(pu.z)); + const Vec3fa upper_r(reduce_max(ru)); + return enlarge(BBox3fa(lower,upper),upper_r); + } + } + + friend __forceinline embree_ostream operator<<(embree_ostream cout, const CatmullRomCurveT& curve) { + return cout << "CatmullRomCurve { v0 = " << curve.v0 << ", v1 = " << curve.v1 << ", v2 = " << curve.v2 << ", v3 = " << curve.v3 << " }"; + } + }; + + __forceinline CatmullRomCurveT enlargeRadiusToMinWidth(const IntersectContext* context, const CurveGeometry* geom, const Vec3fa& ray_org, const CatmullRomCurveT& curve) + { + return CatmullRomCurveT(enlargeRadiusToMinWidth(context,geom,ray_org,curve.v0), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v1), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v2), + enlargeRadiusToMinWidth(context,geom,ray_org,curve.v3)); + } + + typedef CatmullRomCurveT CatmullRomCurve3fa; +} + diff --git a/thirdparty/embree/kernels/subdiv/feature_adaptive_eval.h b/thirdparty/embree/kernels/subdiv/feature_adaptive_eval.h new file mode 100644 index 000000000000..23f24c360c8d --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/feature_adaptive_eval.h @@ -0,0 +1,226 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "patch.h" + +namespace embree +{ + namespace isa + { + template + struct FeatureAdaptiveEval + { + public: + + typedef PatchT Patch; + typedef typename Patch::Ref Ref; + typedef GeneralCatmullClarkPatchT GeneralCatmullClarkPatch; + typedef CatmullClark1RingT CatmullClarkRing; + typedef CatmullClarkPatchT CatmullClarkPatch; + typedef BSplinePatchT BSplinePatch; + typedef BezierPatchT BezierPatch; + typedef GregoryPatchT GregoryPatch; + typedef BilinearPatchT BilinearPatch; + typedef BezierCurveT BezierCurve; + + public: + + FeatureAdaptiveEval (const HalfEdge* edge, const char* vertices, size_t stride, const float u, const float v, + Vertex* P, Vertex* dPdu, Vertex* dPdv, Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv) + : P(P), dPdu(dPdu), dPdv(dPdv), ddPdudu(ddPdudu), ddPdvdv(ddPdvdv), ddPdudv(ddPdudv) + { + switch (edge->patch_type) { + case HalfEdge::BILINEAR_PATCH: BilinearPatch(edge,vertices,stride).eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,1.0f); break; + case HalfEdge::REGULAR_QUAD_PATCH: RegularPatchT(edge,vertices,stride).eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,1.0f); break; +#if PATCH_USE_GREGORY == 2 + case HalfEdge::IRREGULAR_QUAD_PATCH: GregoryPatch(edge,vertices,stride).eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,1.0f); break; +#endif + default: { + 
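+        /* Bilinear and regular quad patches are handled directly above (and, when
+           PATCH_USE_GREGORY == 2, irregular quad patches via a Gregory patch);
+           every remaining topology falls back to recursive feature-adaptive
+           evaluation of the general Catmull-Clark patch. */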
GeneralCatmullClarkPatch patch(edge,vertices,stride); + eval(patch,Vec2f(u,v),0); + break; + } + } + } + + FeatureAdaptiveEval (CatmullClarkPatch& patch, const float u, const float v, float dscale, size_t depth, + Vertex* P, Vertex* dPdu, Vertex* dPdv, Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv) + : P(P), dPdu(dPdu), dPdv(dPdv), ddPdudu(ddPdudu), ddPdvdv(ddPdvdv), ddPdudv(ddPdudv) + { + eval(patch,Vec2f(u,v),dscale,depth); + } + + void eval_general_quad(const GeneralCatmullClarkPatch& patch, array_t& patches, const Vec2f& uv, size_t depth) + { + float u = uv.x, v = uv.y; + if (v < 0.5f) { + if (u < 0.5f) { +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[2]; patch.getLimitBorder(borders,0); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border2l,border2r; borders[1].subdivide(border2l,border2r); + eval(patches[0],Vec2f(2.0f*u,2.0f*v),2.0f,depth+1, &border0l, nullptr, nullptr, &border2r); +#else + eval(patches[0],Vec2f(2.0f*u,2.0f*v),2.0f,depth+1); +#endif + if (dPdu && dPdv) { + const Vertex dpdx = *dPdu, dpdy = *dPdv; + *dPdu = dpdx; *dPdv = dpdy; + } + } + else { +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[2]; patch.getLimitBorder(borders,1); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border2l,border2r; borders[1].subdivide(border2l,border2r); + eval(patches[1],Vec2f(2.0f*v,2.0f-2.0f*u),2.0f,depth+1, &border0l, nullptr, nullptr, &border2r); +#else + eval(patches[1],Vec2f(2.0f*v,2.0f-2.0f*u),2.0f,depth+1); +#endif + if (dPdu && dPdv) { + const Vertex dpdx = *dPdu, dpdy = *dPdv; + *dPdu = -dpdy; *dPdv = dpdx; + } + } + } else { + if (u > 0.5f) { +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[2]; patch.getLimitBorder(borders,2); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border2l,border2r; borders[1].subdivide(border2l,border2r); + eval(patches[2],Vec2f(2.0f-2.0f*u,2.0f-2.0f*v),2.0f,depth+1, &border0l, nullptr, nullptr, &border2r); +#else + eval(patches[2],Vec2f(2.0f-2.0f*u,2.0f-2.0f*v),2.0f,depth+1); +#endif + if (dPdu && dPdv) { + const Vertex dpdx = *dPdu, dpdy = *dPdv; + *dPdu = -dpdx; *dPdv = -dpdy; + } + } + else { +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[2]; patch.getLimitBorder(borders,3); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border2l,border2r; borders[1].subdivide(border2l,border2r); + eval(patches[3],Vec2f(2.0f-2.0f*v,2.0f*u),2.0f,depth+1, &border0l, nullptr, nullptr, &border2r); +#else + eval(patches[3],Vec2f(2.0f-2.0f*v,2.0f*u),2.0f,depth+1); +#endif + if (dPdu && dPdv) { + const Vertex dpdx = *dPdu, dpdy = *dPdv; + *dPdu = dpdy; *dPdv = -dpdx; + } + } + } + } + + __forceinline bool final(const CatmullClarkPatch& patch, const typename CatmullClarkRing::Type type, size_t depth) + { + const int max_eval_depth = (type & CatmullClarkRing::TYPE_CREASES) ? 
PATCH_MAX_EVAL_DEPTH_CREASE : PATCH_MAX_EVAL_DEPTH_IRREGULAR; +//#if PATCH_MIN_RESOLUTION +// return patch.isFinalResolution(PATCH_MIN_RESOLUTION) || depth>=(size_t)max_eval_depth; +//#else + return depth>=(size_t)max_eval_depth; +//#endif + } + + void eval(CatmullClarkPatch& patch, Vec2f uv, float dscale, size_t depth, + BezierCurve* border0 = nullptr, BezierCurve* border1 = nullptr, BezierCurve* border2 = nullptr, BezierCurve* border3 = nullptr) + { + while (true) + { + typename CatmullClarkPatch::Type ty = patch.type(); + + if (unlikely(final(patch,ty,depth))) + { + if (ty & CatmullClarkRing::TYPE_REGULAR) { + RegularPatch(patch,border0,border1,border2,border3).eval(uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(234423,c,c,-1); + return; + } else { + IrregularFillPatch(patch,border0,border1,border2,border3).eval(uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(34534,c,-1,c); + return; + } + } + else if (ty & CatmullClarkRing::TYPE_REGULAR_CREASES) { + assert(depth > 0); + RegularPatch(patch,border0,border1,border2,border3).eval(uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(43524,c,c,-1); + return; + } +#if PATCH_USE_GREGORY == 2 + else if (ty & CatmullClarkRing::TYPE_GREGORY_CREASES) { + assert(depth > 0); + GregoryPatch(patch,border0,border1,border2,border3).eval(uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(23498,c,-1,c); + return; + } +#endif + else + { + array_t patches; + patch.subdivide(patches); // FIXME: only have to generate one of the patches + + const float u = uv.x, v = uv.y; + if (v < 0.5f) { + if (u < 0.5f) { patch = patches[0]; uv = Vec2f(2.0f*u,2.0f*v); dscale *= 2.0f; } + else { patch = patches[1]; uv = Vec2f(2.0f*u-1.0f,2.0f*v); dscale *= 2.0f; } + } else { + if (u > 0.5f) { patch = patches[2]; uv = Vec2f(2.0f*u-1.0f,2.0f*v-1.0f); dscale *= 2.0f; } + else { patch = patches[3]; uv = Vec2f(2.0f*u,2.0f*v-1.0f); dscale *= 2.0f; } + } + depth++; + } + } + } + + void eval(const GeneralCatmullClarkPatch& patch, const Vec2f& uv, const size_t depth) + { + /* convert into standard quad patch if possible */ + if (likely(patch.isQuadPatch())) + { + CatmullClarkPatch qpatch; patch.init(qpatch); + return eval(qpatch,uv,1.0f,depth); + } + + /* subdivide patch */ + unsigned N; + array_t patches; + patch.subdivide(patches,N); // FIXME: only have to generate one of the patches + + /* parametrization for quads */ + if (N == 4) + eval_general_quad(patch,patches,uv,depth); + + /* parametrization for arbitrary polygons */ + else + { + const unsigned l = (unsigned) floor(0.5f*uv.x); const float u = 2.0f*frac(0.5f*uv.x)-0.5f; + const unsigned h = (unsigned) floor(0.5f*uv.y); const float v = 2.0f*frac(0.5f*uv.y)-0.5f; + const unsigned i = 4*h+l; assert(i= N) return; + +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[2]; patch.getLimitBorder(borders,i); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border2l,border2r; borders[1].subdivide(border2l,border2r); + eval(patches[i],Vec2f(u,v),1.0f,depth+1, &border0l, nullptr, nullptr, &border2r); +#else + eval(patches[i],Vec2f(u,v),1.0f,depth+1); +#endif + } + } + + private: + Vertex* const P; + Vertex* const dPdu; + Vertex* const dPdv; + Vertex* const ddPdudu; + Vertex* const ddPdvdv; + Vertex* const ddPdudv; + }; + } +} diff --git a/thirdparty/embree/kernels/subdiv/feature_adaptive_eval_grid.h b/thirdparty/embree/kernels/subdiv/feature_adaptive_eval_grid.h new file mode 
100644 index 000000000000..76583b2e5dce --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/feature_adaptive_eval_grid.h @@ -0,0 +1,359 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "patch.h" +#include "catmullclark_patch.h" +#include "bspline_patch.h" +#include "gregory_patch.h" +#include "tessellation.h" + +namespace embree +{ + namespace isa + { + struct FeatureAdaptiveEvalGrid + { + typedef CatmullClark1Ring3fa CatmullClarkRing; + typedef CatmullClarkPatch3fa CatmullClarkPatch; + typedef BilinearPatch3fa BilinearPatch; + typedef BSplinePatch3fa BSplinePatch; + typedef BezierPatch3fa BezierPatch; + typedef GregoryPatch3fa GregoryPatch; + + private: + const unsigned x0,x1; + const unsigned y0,y1; + const unsigned swidth,sheight; + const float rcp_swidth, rcp_sheight; + float* const Px; + float* const Py; + float* const Pz; + float* const U; + float* const V; + float* const Nx; + float* const Ny; + float* const Nz; + const unsigned dwidth; + //const unsigned dheight; + unsigned count; + + + public: + FeatureAdaptiveEvalGrid (const GeneralCatmullClarkPatch3fa& patch, unsigned subPatch, + const unsigned x0, const unsigned x1, const unsigned y0, const unsigned y1, const unsigned swidth, const unsigned sheight, + float* Px, float* Py, float* Pz, float* U, float* V, + float* Nx, float* Ny, float* Nz, + const unsigned dwidth, const unsigned dheight) + : x0(x0), x1(x1), y0(y0), y1(y1), swidth(swidth), sheight(sheight), rcp_swidth(1.0f/(swidth-1.0f)), rcp_sheight(1.0f/(sheight-1.0f)), + Px(Px), Py(Py), Pz(Pz), U(U), V(V), Nx(Nx), Ny(Ny), Nz(Nz), dwidth(dwidth), /*dheight(dheight),*/ count(0) + { + assert(swidth < (2<<20) && sheight < (2<<20)); + const BBox2f srange(Vec2f(0.0f,0.0f),Vec2f(float(swidth-1),float(sheight-1))); + const BBox2f erange(Vec2f((float)x0,(float)y0),Vec2f((float)x1,(float)y1)); + + /* convert into standard quad patch if possible */ + if (likely(patch.isQuadPatch())) + { + CatmullClarkPatch3fa qpatch; patch.init(qpatch); + eval(qpatch, srange, erange, 0); + assert(count == (x1-x0+1)*(y1-y0+1)); + return; + } + + /* subdivide patch */ + unsigned N; + array_t patches; + patch.subdivide(patches,N); + + if (N == 4) + { + const Vec2f c = srange.center(); + const BBox2f srange0(srange.lower,c); + const BBox2f srange1(Vec2f(c.x,srange.lower.y),Vec2f(srange.upper.x,c.y)); + const BBox2f srange2(c,srange.upper); + const BBox2f srange3(Vec2f(srange.lower.x,c.y),Vec2f(c.x,srange.upper.y)); + +#if PATCH_USE_GREGORY == 2 + BezierCurve3fa borders[GeneralCatmullClarkPatch3fa::SIZE]; patch.getLimitBorder(borders); + BezierCurve3fa border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve3fa border1l,border1r; borders[1].subdivide(border1l,border1r); + BezierCurve3fa border2l,border2r; borders[2].subdivide(border2l,border2r); + BezierCurve3fa border3l,border3r; borders[3].subdivide(border3l,border3r); + GeneralCatmullClarkPatch3fa::fix_quad_ring_order(patches); + eval(patches[0],srange0,intersect(srange0,erange),1,&border0l,nullptr,nullptr,&border3r); + eval(patches[1],srange1,intersect(srange1,erange),1,&border0r,&border1l,nullptr,nullptr); + eval(patches[2],srange2,intersect(srange2,erange),1,nullptr,&border1r,&border2l,nullptr); + eval(patches[3],srange3,intersect(srange3,erange),1,nullptr,nullptr,&border2r,&border3l); +#else + GeneralCatmullClarkPatch3fa::fix_quad_ring_order(patches); + eval(patches[0],srange0,intersect(srange0,erange),1); + eval(patches[1],srange1,intersect(srange1,erange),1); + 
eval(patches[2],srange2,intersect(srange2,erange),1); + eval(patches[3],srange3,intersect(srange3,erange),1); +#endif + } + else + { + assert(subPatch < N); + +#if PATCH_USE_GREGORY == 2 + BezierCurve3fa borders[2]; patch.getLimitBorder(borders,subPatch); + BezierCurve3fa border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve3fa border2l,border2r; borders[1].subdivide(border2l,border2r); + eval(patches[subPatch], srange, erange, 1, &border0l, nullptr, nullptr, &border2r); +#else + eval(patches[subPatch], srange, erange, 1); +#endif + + } + assert(count == (x1-x0+1)*(y1-y0+1)); + } + + FeatureAdaptiveEvalGrid (const CatmullClarkPatch3fa& patch, + const BBox2f& srange, const BBox2f& erange, const unsigned depth, + const unsigned x0, const unsigned x1, const unsigned y0, const unsigned y1, const unsigned swidth, const unsigned sheight, + float* Px, float* Py, float* Pz, float* U, float* V, + float* Nx, float* Ny, float* Nz, + const unsigned dwidth, const unsigned dheight) + : x0(x0), x1(x1), y0(y0), y1(y1), swidth(swidth), sheight(sheight), rcp_swidth(1.0f/(swidth-1.0f)), rcp_sheight(1.0f/(sheight-1.0f)), + Px(Px), Py(Py), Pz(Pz), U(U), V(V), Nx(Nx), Ny(Ny), Nz(Nz), dwidth(dwidth), /*dheight(dheight),*/ count(0) + { + eval(patch,srange,erange,depth); + } + + template + void evalLocalGrid(const Patch& patch, const BBox2f& srange, const int lx0, const int lx1, const int ly0, const int ly1) + { + const float scale_x = rcp(srange.upper.x-srange.lower.x); + const float scale_y = rcp(srange.upper.y-srange.lower.y); + count += (lx1-lx0)*(ly1-ly0); + +#if 0 + for (unsigned iy=ly0; iy=max_eval_depth; +//#else + return depth>=max_eval_depth; +//#endif + } + + void eval(const CatmullClarkPatch3fa& patch, const BBox2f& srange, const BBox2f& erange, const unsigned depth, + const BezierCurve3fa* border0 = nullptr, const BezierCurve3fa* border1 = nullptr, const BezierCurve3fa* border2 = nullptr, const BezierCurve3fa* border3 = nullptr) + { + if (erange.empty()) + return; + + int lx0 = (int) ceilf(erange.lower.x); + int lx1 = (int) ceilf(erange.upper.x) + (erange.upper.x == x1 && (srange.lower.x < erange.upper.x || erange.upper.x == 0)); + int ly0 = (int) ceilf(erange.lower.y); + int ly1 = (int) ceilf(erange.upper.y) + (erange.upper.y == y1 && (srange.lower.y < erange.upper.y || erange.upper.y == 0)); + if (lx0 >= lx1 || ly0 >= ly1) return; + + CatmullClarkPatch::Type ty = patch.type(); + + if (unlikely(final(patch,ty,depth))) + { + if (ty & CatmullClarkRing::TYPE_REGULAR) { + RegularPatch rpatch(patch,border0,border1,border2,border3); + evalLocalGrid(rpatch,srange,lx0,lx1,ly0,ly1); + return; + } else { + IrregularFillPatch ipatch(patch,border0,border1,border2,border3); + evalLocalGrid(ipatch,srange,lx0,lx1,ly0,ly1); + return; + } + } + else if (ty & CatmullClarkRing::TYPE_REGULAR_CREASES) { + assert(depth > 0); + RegularPatch rpatch(patch,border0,border1,border2,border3); + evalLocalGrid(rpatch,srange,lx0,lx1,ly0,ly1); + return; + } +#if PATCH_USE_GREGORY == 2 + else if (ty & CatmullClarkRing::TYPE_GREGORY_CREASES) { + assert(depth > 0); + GregoryPatch gpatch(patch,border0,border1,border2,border3); + evalLocalGrid(gpatch,srange,lx0,lx1,ly0,ly1); + } +#endif + else + { + array_t patches; + patch.subdivide(patches); + + const Vec2f c = srange.center(); + const BBox2f srange0(srange.lower,c); + const BBox2f srange1(Vec2f(c.x,srange.lower.y),Vec2f(srange.upper.x,c.y)); + const BBox2f srange2(c,srange.upper); + const BBox2f srange3(Vec2f(srange.lower.x,c.y),Vec2f(c.x,srange.upper.y)); + + 
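+          // The four sub-ranges above are the parameter-space quadrants of the four
+          // child patches produced by one Catmull-Clark subdivision step (lower-left,
+          // lower-right, upper-right and upper-left, matching patches[0..3]). Each
+          // child is evaluated recursively below with the eval range clipped to its
+          // quadrant, so only grid samples falling into that quadrant are computed there.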
eval(patches[0],srange0,intersect(srange0,erange),depth+1); + eval(patches[1],srange1,intersect(srange1,erange),depth+1); + eval(patches[2],srange2,intersect(srange2,erange),depth+1); + eval(patches[3],srange3,intersect(srange3,erange),depth+1); + } + } + }; + + template + bool stitch_col(const Patch& patch, int subPatch, + const bool right, const unsigned y0, const unsigned y1, const int fine_y, const int coarse_y, + float* Px, float* Py, float* Pz, float* U, float* V, float* Nx, float* Ny, float* Nz, const unsigned dx0, const unsigned dwidth, const unsigned dheight) + { + assert(coarse_y <= fine_y); + if (likely(fine_y == coarse_y)) + return false; + + const unsigned y0s = stitch(y0,fine_y,coarse_y); + const unsigned y1s = stitch(y1,fine_y,coarse_y); + const unsigned M = y1s-y0s+1 + VSIZEX; + + dynamic_large_stack_array(float,px,M,64*sizeof(float)); + dynamic_large_stack_array(float,py,M,64*sizeof(float)); + dynamic_large_stack_array(float,pz,M,64*sizeof(float)); + dynamic_large_stack_array(float,u,M,64*sizeof(float)); + dynamic_large_stack_array(float,v,M,64*sizeof(float)); + dynamic_large_stack_array(float,nx,M,64*sizeof(float)); + dynamic_large_stack_array(float,ny,M,64*sizeof(float)); + dynamic_large_stack_array(float,nz,M,64*sizeof(float)); + const bool has_Nxyz = Nx; assert(!Nx || (Ny && Nz)); + Eval(patch,subPatch, right,right, y0s,y1s, 2,coarse_y+1, px,py,pz,u,v, + has_Nxyz ? (float*)nx : nullptr,has_Nxyz ? (float*)ny : nullptr ,has_Nxyz ? (float*)nz : nullptr, 1,4097); + + for (unsigned y=y0; y<=y1; y++) + { + const unsigned ys = stitch(y,fine_y,coarse_y)-y0s; + Px[(y-y0)*dwidth+dx0] = px[ys]; + Py[(y-y0)*dwidth+dx0] = py[ys]; + Pz[(y-y0)*dwidth+dx0] = pz[ys]; + U [(y-y0)*dwidth+dx0] = u[ys]; + V [(y-y0)*dwidth+dx0] = v[ys]; + if (unlikely(has_Nxyz)) { + Nx[(y-y0)*dwidth+dx0] = nx[ys]; + Ny[(y-y0)*dwidth+dx0] = ny[ys]; + Nz[(y-y0)*dwidth+dx0] = nz[ys]; + } + } + return true; + } + + template + bool stitch_row(const Patch& patch, int subPatch, + const bool bottom, const unsigned x0, const unsigned x1, const int fine_x, const int coarse_x, + float* Px, float* Py, float* Pz, float* U, float* V, float* Nx, float* Ny, float* Nz, const unsigned dy0, const unsigned dwidth, const unsigned dheight) + { + assert(coarse_x <= fine_x); + if (likely(fine_x == coarse_x)) + return false; + + const unsigned x0s = stitch(x0,fine_x,coarse_x); + const unsigned x1s = stitch(x1,fine_x,coarse_x); + const unsigned M = x1s-x0s+1 + VSIZEX; + + dynamic_large_stack_array(float,px,M,32*sizeof(float)); + dynamic_large_stack_array(float,py,M,32*sizeof(float)); + dynamic_large_stack_array(float,pz,M,32*sizeof(float)); + dynamic_large_stack_array(float,u,M,32*sizeof(float)); + dynamic_large_stack_array(float,v,M,32*sizeof(float)); + dynamic_large_stack_array(float,nx,M,32*sizeof(float)); + dynamic_large_stack_array(float,ny,M,32*sizeof(float)); + dynamic_large_stack_array(float,nz,M,32*sizeof(float)); + const bool has_Nxyz = Nx; assert(!Nx || (Ny && Nz)); + Eval(patch,subPatch, x0s,x1s, bottom,bottom, coarse_x+1,2, px,py,pz,u,v, + has_Nxyz ? (float*)nx :nullptr, has_Nxyz ? (float*)ny : nullptr , has_Nxyz ? 
(float*)nz : nullptr, 4097,1); + + for (unsigned x=x0; x<=x1; x++) + { + const unsigned xs = stitch(x,fine_x,coarse_x)-x0s; + Px[dy0*dwidth+x-x0] = px[xs]; + Py[dy0*dwidth+x-x0] = py[xs]; + Pz[dy0*dwidth+x-x0] = pz[xs]; + U [dy0*dwidth+x-x0] = u[xs]; + V [dy0*dwidth+x-x0] = v[xs]; + if (unlikely(has_Nxyz)) { + Nx[dy0*dwidth+x-x0] = nx[xs]; + Ny[dy0*dwidth+x-x0] = ny[xs]; + Nz[dy0*dwidth+x-x0] = nz[xs]; + } + } + return true; + } + + template + void feature_adaptive_eval_grid (const Patch& patch, unsigned subPatch, const float levels[4], + const unsigned x0, const unsigned x1, const unsigned y0, const unsigned y1, const unsigned swidth, const unsigned sheight, + float* Px, float* Py, float* Pz, float* U, float* V, float* Nx, float* Ny, float* Nz, const unsigned dwidth, const unsigned dheight) + { + bool sl = false, sr = false, st = false, sb = false; + if (levels) { + sl = x0 == 0 && stitch_col(patch,subPatch,0,y0,y1,sheight-1,int(levels[3]), Px,Py,Pz,U,V,Nx,Ny,Nz, 0 ,dwidth,dheight); + sr = x1 == swidth-1 && stitch_col(patch,subPatch,1,y0,y1,sheight-1,int(levels[1]), Px,Py,Pz,U,V,Nx,Ny,Nz, x1-x0,dwidth,dheight); + st = y0 == 0 && stitch_row(patch,subPatch,0,x0,x1,swidth-1,int(levels[0]), Px,Py,Pz,U,V,Nx,Ny,Nz, 0 ,dwidth,dheight); + sb = y1 == sheight-1 && stitch_row(patch,subPatch,1,x0,x1,swidth-1,int(levels[2]), Px,Py,Pz,U,V,Nx,Ny,Nz, y1-y0,dwidth,dheight); + } + const unsigned ofs = st*dwidth+sl; + Eval(patch,subPatch,x0+sl,x1-sr,y0+st,y1-sb, swidth,sheight, Px+ofs,Py+ofs,Pz+ofs,U+ofs,V+ofs,Nx?Nx+ofs:nullptr,Ny?Ny+ofs:nullptr,Nz?Nz+ofs:nullptr, dwidth,dheight); + } + } +} + diff --git a/thirdparty/embree/kernels/subdiv/feature_adaptive_eval_simd.h b/thirdparty/embree/kernels/subdiv/feature_adaptive_eval_simd.h new file mode 100644 index 000000000000..fa3216730f82 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/feature_adaptive_eval_simd.h @@ -0,0 +1,186 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "patch.h" + +namespace embree +{ + namespace isa + { + template + struct FeatureAdaptiveEvalSimd + { + public: + + typedef PatchT Patch; + typedef typename Patch::Ref Ref; + typedef GeneralCatmullClarkPatchT GeneralCatmullClarkPatch; + typedef CatmullClark1RingT CatmullClarkRing; + typedef CatmullClarkPatchT CatmullClarkPatch; + typedef BSplinePatchT BSplinePatch; + typedef BezierPatchT BezierPatch; + typedef GregoryPatchT GregoryPatch; + typedef BilinearPatchT BilinearPatch; + typedef BezierCurveT BezierCurve; + + FeatureAdaptiveEvalSimd (const HalfEdge* edge, const char* vertices, size_t stride, const vbool& valid, const vfloat& u, const vfloat& v, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, const size_t dstride, const size_t N) + : P(P), dPdu(dPdu), dPdv(dPdv), ddPdudu(ddPdudu), ddPdvdv(ddPdvdv), ddPdudv(ddPdudv), dstride(dstride), N(N) + { + switch (edge->patch_type) { + case HalfEdge::BILINEAR_PATCH: BilinearPatch(edge,vertices,stride).eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,1.0f,dstride,N); break; + case HalfEdge::REGULAR_QUAD_PATCH: RegularPatchT(edge,vertices,stride).eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,1.0f,dstride,N); break; +#if PATCH_USE_GREGORY == 2 + case HalfEdge::IRREGULAR_QUAD_PATCH: GregoryPatchT(edge,vertices,stride).eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,1.0f,dstride,N); break; +#endif + default: { + GeneralCatmullClarkPatch patch(edge,vertices,stride); + eval_direct(valid,patch,Vec2(u,v),0); + break; + } + } + } + + 
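+      // The constructor above dispatches on the half edge's precomputed patch type:
+      // bilinear and regular B-spline patches are evaluated directly (and Gregory
+      // patches when PATCH_USE_GREGORY == 2), while all remaining cases build a
+      // general Catmull-Clark patch and refine it recursively in eval_direct().
+      // The constructor below instead starts from an already subdivided Catmull-Clark
+      // patch at the given depth; dscale carries the derivative scaling accumulated
+      // by the preceding subdivision steps.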
FeatureAdaptiveEvalSimd (const CatmullClarkPatch& patch, const vbool& valid, const vfloat& u, const vfloat& v, float dscale, size_t depth, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, const size_t dstride, const size_t N) + : P(P), dPdu(dPdu), dPdv(dPdv), ddPdudu(ddPdudu), ddPdvdv(ddPdvdv), ddPdudv(ddPdudv), dstride(dstride), N(N) + { + eval_direct(valid,patch,Vec2(u,v),dscale,depth); + } + + template + __forceinline void eval_quad_direct(const vbool& valid, array_t& patches, const Vec2& uv, float dscale, size_t depth) + { + const vfloat u = uv.x, v = uv.y; + const vbool u0_mask = u < 0.5f, u1_mask = u >= 0.5f; + const vbool v0_mask = v < 0.5f, v1_mask = v >= 0.5f; + const vbool u0v0_mask = valid & u0_mask & v0_mask; + const vbool u0v1_mask = valid & u0_mask & v1_mask; + const vbool u1v0_mask = valid & u1_mask & v0_mask; + const vbool u1v1_mask = valid & u1_mask & v1_mask; + if (any(u0v0_mask)) eval_direct(u0v0_mask,patches[0],Vec2(2.0f*u,2.0f*v),2.0f*dscale,depth+1); + if (any(u1v0_mask)) eval_direct(u1v0_mask,patches[1],Vec2(2.0f*u-1.0f,2.0f*v),2.0f*dscale,depth+1); + if (any(u1v1_mask)) eval_direct(u1v1_mask,patches[2],Vec2(2.0f*u-1.0f,2.0f*v-1.0f),2.0f*dscale,depth+1); + if (any(u0v1_mask)) eval_direct(u0v1_mask,patches[3],Vec2(2.0f*u,2.0f*v-1.0f),2.0f*dscale,depth+1); + } + + template + __forceinline void eval_general_quad_direct(const vbool& valid, const GeneralCatmullClarkPatch& patch, array_t& patches, const Vec2& uv, float dscale, size_t depth) + { +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[GeneralCatmullClarkPatch::SIZE]; patch.getLimitBorder(borders); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border1l,border1r; borders[1].subdivide(border1l,border1r); + BezierCurve border2l,border2r; borders[2].subdivide(border2l,border2r); + BezierCurve border3l,border3r; borders[3].subdivide(border3l,border3r); +#endif + GeneralCatmullClarkPatch::fix_quad_ring_order(patches); + const vfloat u = uv.x, v = uv.y; + const vbool u0_mask = u < 0.5f, u1_mask = u >= 0.5f; + const vbool v0_mask = v < 0.5f, v1_mask = v >= 0.5f; + const vbool u0v0_mask = valid & u0_mask & v0_mask; + const vbool u0v1_mask = valid & u0_mask & v1_mask; + const vbool u1v0_mask = valid & u1_mask & v0_mask; + const vbool u1v1_mask = valid & u1_mask & v1_mask; +#if PATCH_USE_GREGORY == 2 + if (any(u0v0_mask)) eval_direct(u0v0_mask,patches[0],Vec2(2.0f*u,2.0f*v),2.0f*dscale,depth+1,&border0l,nullptr,nullptr,&border3r); + if (any(u1v0_mask)) eval_direct(u1v0_mask,patches[1],Vec2(2.0f*u-1.0f,2.0f*v),2.0f*dscale,depth+1,&border0r,&border1l,nullptr,nullptr); + if (any(u1v1_mask)) eval_direct(u1v1_mask,patches[2],Vec2(2.0f*u-1.0f,2.0f*v-1.0f),2.0f*dscale,depth+1,nullptr,&border1r,&border2l,nullptr); + if (any(u0v1_mask)) eval_direct(u0v1_mask,patches[3],Vec2(2.0f*u,2.0f*v-1.0f),2.0f*dscale,depth+1,nullptr,nullptr,&border2r,&border3l); +#else + if (any(u0v0_mask)) eval_direct(u0v0_mask,patches[0],Vec2(2.0f*u,2.0f*v),2.0f*dscale,depth+1); + if (any(u1v0_mask)) eval_direct(u1v0_mask,patches[1],Vec2(2.0f*u-1.0f,2.0f*v),2.0f*dscale,depth+1); + if (any(u1v1_mask)) eval_direct(u1v1_mask,patches[2],Vec2(2.0f*u-1.0f,2.0f*v-1.0f),2.0f*dscale,depth+1); + if (any(u0v1_mask)) eval_direct(u0v1_mask,patches[3],Vec2(2.0f*u,2.0f*v-1.0f),2.0f*dscale,depth+1); +#endif + } + + __forceinline bool final(const CatmullClarkPatch& patch, const typename CatmullClarkRing::Type type, size_t depth) + { + const size_t max_eval_depth = (type & 
CatmullClarkRing::TYPE_CREASES) ? PATCH_MAX_EVAL_DEPTH_CREASE : PATCH_MAX_EVAL_DEPTH_IRREGULAR; +//#if PATCH_MIN_RESOLUTION +// return patch.isFinalResolution(PATCH_MIN_RESOLUTION) || depth>=max_eval_depth; +//#else + return depth>=max_eval_depth; +//#endif + } + + void eval_direct(const vbool& valid, const CatmullClarkPatch& patch, const Vec2& uv, float dscale, size_t depth, + BezierCurve* border0 = nullptr, BezierCurve* border1 = nullptr, BezierCurve* border2 = nullptr, BezierCurve* border3 = nullptr) + { + typename CatmullClarkPatch::Type ty = patch.type(); + + if (unlikely(final(patch,ty,depth))) + { + if (ty & CatmullClarkRing::TYPE_REGULAR) { + RegularPatch(patch,border0,border1,border2,border3).eval(valid,uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + } else { + IrregularFillPatch(patch,border0,border1,border2,border3).eval(valid,uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + } + } + else if (ty & CatmullClarkRing::TYPE_REGULAR_CREASES) { + assert(depth > 0); RegularPatch(patch,border0,border1,border2,border3).eval(valid,uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + } +#if PATCH_USE_GREGORY == 2 + else if (ty & CatmullClarkRing::TYPE_GREGORY_CREASES) { + assert(depth > 0); GregoryPatch(patch,border0,border1,border2,border3).eval(valid,uv.x,uv.y,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + } +#endif + else + { + array_t patches; + patch.subdivide(patches); // FIXME: only have to generate one of the patches + eval_quad_direct(valid,patches,uv,dscale,depth); + } + } + + void eval_direct(const vbool& valid, const GeneralCatmullClarkPatch& patch, const Vec2& uv, const size_t depth) + { + /* convert into standard quad patch if possible */ + if (likely(patch.isQuadPatch())) { + CatmullClarkPatch qpatch; patch.init(qpatch); + return eval_direct(valid,qpatch,uv,1.0f,depth); + } + + /* subdivide patch */ + unsigned Nc; + array_t patches; + patch.subdivide(patches,Nc); // FIXME: only have to generate one of the patches + + /* parametrization for quads */ + if (Nc == 4) + eval_general_quad_direct(valid,patch,patches,uv,1.0f,depth); + + /* parametrization for arbitrary polygons */ + else + { + const vint l = (vint)floor(0.5f*uv.x); const vfloat u = 2.0f*frac(0.5f*uv.x)-0.5f; + const vint h = (vint)floor(0.5f*uv.y); const vfloat v = 2.0f*frac(0.5f*uv.y)-0.5f; + const vint i = (h<<2)+l; assert(all(valid,i(u,v),1.0f,depth+1, &border0l, nullptr, nullptr, &border2r); +#else + eval_direct(valid,patches[i],Vec2(u,v),1.0f,depth+1); +#endif + }); + } + } + + private: + float* const P; + float* const dPdu; + float* const dPdv; + float* const ddPdudu; + float* const ddPdvdv; + float* const ddPdudv; + const size_t dstride; + const size_t N; + }; + } +} diff --git a/thirdparty/embree/kernels/subdiv/gregory_patch.h b/thirdparty/embree/kernels/subdiv/gregory_patch.h new file mode 100644 index 000000000000..2a7c4b1f2c23 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/gregory_patch.h @@ -0,0 +1,893 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_patch.h" +#include "bezier_patch.h" +#include "bezier_curve.h" +#include "catmullclark_coefficients.h" + +namespace embree +{ + template + class __aligned(64) GregoryPatchT + { + typedef CatmullClarkPatchT CatmullClarkPatch; + typedef GeneralCatmullClarkPatchT GeneralCatmullClarkPatch; + typedef CatmullClark1RingT CatmullClark1Ring; + typedef BezierCurveT BezierCurve; + + public: + Vertex v[4][4]; + Vertex f[2][2]; + 
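+      // Control point layout: the 4x4 array v holds the 16 Bezier-style control
+      // points of the patch, and f holds the four additional interior face points of
+      // the Gregory representation. Each interior point therefore exists in two
+      // variants (a "plus" point in v[1..2][1..2] and a "minus" point in f), which are
+      // blended per (u,v) in computeInnerVertices() below; the p*/e*/f* accessors map
+      // these raw slots to the conventional Gregory patch naming.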
+ __forceinline GregoryPatchT() {} + + __forceinline GregoryPatchT(const CatmullClarkPatch& patch) { + init(patch); + } + + __forceinline GregoryPatchT(const CatmullClarkPatch& patch, + const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) + { + init_crackfix(patch,border0,border1,border2,border3); + } + + __forceinline GregoryPatchT (const HalfEdge* edge, const char* vertices, size_t stride) { + init(CatmullClarkPatch(edge,vertices,stride)); + } + + __forceinline Vertex& p0() { return v[0][0]; } + __forceinline Vertex& p1() { return v[0][3]; } + __forceinline Vertex& p2() { return v[3][3]; } + __forceinline Vertex& p3() { return v[3][0]; } + + __forceinline Vertex& e0_p() { return v[0][1]; } + __forceinline Vertex& e0_m() { return v[1][0]; } + __forceinline Vertex& e1_p() { return v[1][3]; } + __forceinline Vertex& e1_m() { return v[0][2]; } + __forceinline Vertex& e2_p() { return v[3][2]; } + __forceinline Vertex& e2_m() { return v[2][3]; } + __forceinline Vertex& e3_p() { return v[2][0]; } + __forceinline Vertex& e3_m() { return v[3][1]; } + + __forceinline Vertex& f0_p() { return v[1][1]; } + __forceinline Vertex& f1_p() { return v[1][2]; } + __forceinline Vertex& f2_p() { return v[2][2]; } + __forceinline Vertex& f3_p() { return v[2][1]; } + __forceinline Vertex& f0_m() { return f[0][0]; } + __forceinline Vertex& f1_m() { return f[0][1]; } + __forceinline Vertex& f2_m() { return f[1][1]; } + __forceinline Vertex& f3_m() { return f[1][0]; } + + __forceinline const Vertex& p0() const { return v[0][0]; } + __forceinline const Vertex& p1() const { return v[0][3]; } + __forceinline const Vertex& p2() const { return v[3][3]; } + __forceinline const Vertex& p3() const { return v[3][0]; } + + __forceinline const Vertex& e0_p() const { return v[0][1]; } + __forceinline const Vertex& e0_m() const { return v[1][0]; } + __forceinline const Vertex& e1_p() const { return v[1][3]; } + __forceinline const Vertex& e1_m() const { return v[0][2]; } + __forceinline const Vertex& e2_p() const { return v[3][2]; } + __forceinline const Vertex& e2_m() const { return v[2][3]; } + __forceinline const Vertex& e3_p() const { return v[2][0]; } + __forceinline const Vertex& e3_m() const { return v[3][1]; } + + __forceinline const Vertex& f0_p() const { return v[1][1]; } + __forceinline const Vertex& f1_p() const { return v[1][2]; } + __forceinline const Vertex& f2_p() const { return v[2][2]; } + __forceinline const Vertex& f3_p() const { return v[2][1]; } + __forceinline const Vertex& f0_m() const { return f[0][0]; } + __forceinline const Vertex& f1_m() const { return f[0][1]; } + __forceinline const Vertex& f2_m() const { return f[1][1]; } + __forceinline const Vertex& f3_m() const { return f[1][0]; } + + __forceinline Vertex initCornerVertex(const CatmullClarkPatch& irreg_patch, const size_t index) { + return irreg_patch.ring[index].getLimitVertex(); + } + + __forceinline Vertex initPositiveEdgeVertex(const CatmullClarkPatch& irreg_patch, const size_t index, const Vertex& p_vtx) { + return madd(1.0f/3.0f,irreg_patch.ring[index].getLimitTangent(),p_vtx); + } + + __forceinline Vertex initNegativeEdgeVertex(const CatmullClarkPatch& irreg_patch, const size_t index, const Vertex& p_vtx) { + return madd(1.0f/3.0f,irreg_patch.ring[index].getSecondLimitTangent(),p_vtx); + } + + __forceinline Vertex initPositiveEdgeVertex2(const CatmullClarkPatch& irreg_patch, const size_t index, const Vertex& p_vtx) + { + CatmullClark1Ring3fa r0,r1,r2; + 
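+      // The one-ring is refined three times before the limit tangent is queried;
+      // the 8/3 weight below roughly compensates for the tangent length shrinking by
+      // about a factor of two per subdivision step (2^3 = 8), on top of the usual
+      // 1/3 edge-point offset used by initPositiveEdgeVertex() above.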
irreg_patch.ring[index].subdivide(r0); + r0.subdivide(r1); + r1.subdivide(r2); + return madd(8.0f/3.0f,r2.getLimitTangent(),p_vtx); + } + + __forceinline Vertex initNegativeEdgeVertex2(const CatmullClarkPatch& irreg_patch, const size_t index, const Vertex& p_vtx) + { + CatmullClark1Ring3fa r0,r1,r2; + irreg_patch.ring[index].subdivide(r0); + r0.subdivide(r1); + r1.subdivide(r2); + return madd(8.0f/3.0f,r2.getSecondLimitTangent(),p_vtx); + } + + void initFaceVertex(const CatmullClarkPatch& irreg_patch, + const size_t index, + const Vertex& p_vtx, + const Vertex& e0_p_vtx, + const Vertex& e1_m_vtx, + const unsigned int face_valence_p1, + const Vertex& e0_m_vtx, + const Vertex& e3_p_vtx, + const unsigned int face_valence_p3, + Vertex& f_p_vtx, + Vertex& f_m_vtx) + { + const unsigned int face_valence = irreg_patch.ring[index].face_valence; + const unsigned int edge_valence = irreg_patch.ring[index].edge_valence; + const unsigned int border_index = irreg_patch.ring[index].border_index; + + const Vertex& vtx = irreg_patch.ring[index].vtx; + const Vertex e_i = irreg_patch.ring[index].getEdgeCenter(0); + const Vertex c_i_m_1 = irreg_patch.ring[index].getQuadCenter(0); + const Vertex e_i_m_1 = irreg_patch.ring[index].getEdgeCenter(1); + + Vertex c_i, e_i_p_1; + const bool hasHardEdge0 = + std::isinf(irreg_patch.ring[index].vertex_crease_weight) && + std::isinf(irreg_patch.ring[index].crease_weight[0]); + + if (unlikely((border_index == edge_valence-2) || hasHardEdge0)) + { + /* mirror quad center and edge mid-point */ + c_i = madd(2.0f, e_i - c_i_m_1, c_i_m_1); + e_i_p_1 = madd(2.0f, vtx - e_i_m_1, e_i_m_1); + } + else + { + c_i = irreg_patch.ring[index].getQuadCenter( face_valence-1 ); + e_i_p_1 = irreg_patch.ring[index].getEdgeCenter( face_valence-1 ); + } + + Vertex c_i_m_2, e_i_m_2; + const bool hasHardEdge1 = + std::isinf(irreg_patch.ring[index].vertex_crease_weight) && + std::isinf(irreg_patch.ring[index].crease_weight[1]); + + if (unlikely(border_index == 2 || hasHardEdge1)) + { + /* mirror quad center and edge mid-point */ + c_i_m_2 = madd(2.0f, e_i_m_1 - c_i_m_1, c_i_m_1); + e_i_m_2 = madd(2.0f, vtx - e_i, + e_i); + } + else + { + c_i_m_2 = irreg_patch.ring[index].getQuadCenter( 1 ); + e_i_m_2 = irreg_patch.ring[index].getEdgeCenter( 2 ); + } + + const float d = 3.0f; + //const float c = cosf(2.0f*M_PI/(float)face_valence); + //const float c_e_p = cosf(2.0f*M_PI/(float)face_valence_p1); + //const float c_e_m = cosf(2.0f*M_PI/(float)face_valence_p3); + + const float c = CatmullClarkPrecomputedCoefficients::table.cos_2PI_div_n(face_valence); + const float c_e_p = CatmullClarkPrecomputedCoefficients::table.cos_2PI_div_n(face_valence_p1); + const float c_e_m = CatmullClarkPrecomputedCoefficients::table.cos_2PI_div_n(face_valence_p3); + + const Vertex r_e_p = 1.0f/3.0f * (e_i_m_1 - e_i_p_1) + 2.0f/3.0f * (c_i_m_1 - c_i); + const Vertex r_e_m = 1.0f/3.0f * (e_i - e_i_m_2) + 2.0f/3.0f * (c_i_m_1 - c_i_m_2); + + f_p_vtx = 1.0f / d * (c_e_p * p_vtx + (d - 2.0f*c - c_e_p) * e0_p_vtx + 2.0f*c* e1_m_vtx + r_e_p); + f_m_vtx = 1.0f / d * (c_e_m * p_vtx + (d - 2.0f*c - c_e_m) * e0_m_vtx + 2.0f*c* e3_p_vtx + r_e_m); + } + + __noinline void init(const CatmullClarkPatch& patch) + { + assert( patch.ring[0].hasValidPositions() ); + assert( patch.ring[1].hasValidPositions() ); + assert( patch.ring[2].hasValidPositions() ); + assert( patch.ring[3].hasValidPositions() ); + + p0() = initCornerVertex(patch,0); + p1() = initCornerVertex(patch,1); + p2() = initCornerVertex(patch,2); + p3() = 
initCornerVertex(patch,3); + + e0_p() = initPositiveEdgeVertex(patch,0, p0()); + e1_p() = initPositiveEdgeVertex(patch,1, p1()); + e2_p() = initPositiveEdgeVertex(patch,2, p2()); + e3_p() = initPositiveEdgeVertex(patch,3, p3()); + + e0_m() = initNegativeEdgeVertex(patch,0, p0()); + e1_m() = initNegativeEdgeVertex(patch,1, p1()); + e2_m() = initNegativeEdgeVertex(patch,2, p2()); + e3_m() = initNegativeEdgeVertex(patch,3, p3()); + + const unsigned int face_valence_p0 = patch.ring[0].face_valence; + const unsigned int face_valence_p1 = patch.ring[1].face_valence; + const unsigned int face_valence_p2 = patch.ring[2].face_valence; + const unsigned int face_valence_p3 = patch.ring[3].face_valence; + + initFaceVertex(patch,0,p0(),e0_p(),e1_m(),face_valence_p1,e0_m(),e3_p(),face_valence_p3,f0_p(),f0_m() ); + initFaceVertex(patch,1,p1(),e1_p(),e2_m(),face_valence_p2,e1_m(),e0_p(),face_valence_p0,f1_p(),f1_m() ); + initFaceVertex(patch,2,p2(),e2_p(),e3_m(),face_valence_p3,e2_m(),e1_p(),face_valence_p1,f2_p(),f2_m() ); + initFaceVertex(patch,3,p3(),e3_p(),e0_m(),face_valence_p0,e3_m(),e2_p(),face_valence_p3,f3_p(),f3_m() ); + + } + + __noinline void init_crackfix(const CatmullClarkPatch& patch, + const BezierCurve* border0, + const BezierCurve* border1, + const BezierCurve* border2, + const BezierCurve* border3) + { + assert( patch.ring[0].hasValidPositions() ); + assert( patch.ring[1].hasValidPositions() ); + assert( patch.ring[2].hasValidPositions() ); + assert( patch.ring[3].hasValidPositions() ); + + p0() = initCornerVertex(patch,0); + p1() = initCornerVertex(patch,1); + p2() = initCornerVertex(patch,2); + p3() = initCornerVertex(patch,3); + + e0_p() = initPositiveEdgeVertex(patch,0, p0()); + e1_p() = initPositiveEdgeVertex(patch,1, p1()); + e2_p() = initPositiveEdgeVertex(patch,2, p2()); + e3_p() = initPositiveEdgeVertex(patch,3, p3()); + + e0_m() = initNegativeEdgeVertex(patch,0, p0()); + e1_m() = initNegativeEdgeVertex(patch,1, p1()); + e2_m() = initNegativeEdgeVertex(patch,2, p2()); + e3_m() = initNegativeEdgeVertex(patch,3, p3()); + + if (unlikely(border0 != nullptr)) + { + p0() = border0->v0; + e0_p() = border0->v1; + e1_m() = border0->v2; + p1() = border0->v3; + } + + if (unlikely(border1 != nullptr)) + { + p1() = border1->v0; + e1_p() = border1->v1; + e2_m() = border1->v2; + p2() = border1->v3; + } + + if (unlikely(border2 != nullptr)) + { + p2() = border2->v0; + e2_p() = border2->v1; + e3_m() = border2->v2; + p3() = border2->v3; + } + + if (unlikely(border3 != nullptr)) + { + p3() = border3->v0; + e3_p() = border3->v1; + e0_m() = border3->v2; + p0() = border3->v3; + } + + const unsigned int face_valence_p0 = patch.ring[0].face_valence; + const unsigned int face_valence_p1 = patch.ring[1].face_valence; + const unsigned int face_valence_p2 = patch.ring[2].face_valence; + const unsigned int face_valence_p3 = patch.ring[3].face_valence; + + initFaceVertex(patch,0,p0(),e0_p(),e1_m(),face_valence_p1,e0_m(),e3_p(),face_valence_p3,f0_p(),f0_m() ); + initFaceVertex(patch,1,p1(),e1_p(),e2_m(),face_valence_p2,e1_m(),e0_p(),face_valence_p0,f1_p(),f1_m() ); + initFaceVertex(patch,2,p2(),e2_p(),e3_m(),face_valence_p3,e2_m(),e1_p(),face_valence_p1,f2_p(),f2_m() ); + initFaceVertex(patch,3,p3(),e3_p(),e0_m(),face_valence_p0,e3_m(),e2_p(),face_valence_p3,f3_p(),f3_m() ); + } + + + void computeGregoryPatchFacePoints(const unsigned int face_valence, + const Vertex& r_e_p, + const Vertex& r_e_m, + const Vertex& p_vtx, + const Vertex& e0_p_vtx, + const Vertex& e1_m_vtx, + const unsigned int face_valence_p1, 
+ const Vertex& e0_m_vtx, + const Vertex& e3_p_vtx, + const unsigned int face_valence_p3, + Vertex& f_p_vtx, + Vertex& f_m_vtx, + const float d = 3.0f) + { + //const float c = cosf(2.0*M_PI/(float)face_valence); + //const float c_e_p = cosf(2.0*M_PI/(float)face_valence_p1); + //const float c_e_m = cosf(2.0*M_PI/(float)face_valence_p3); + + const float c = CatmullClarkPrecomputedCoefficients::table.cos_2PI_div_n(face_valence); + const float c_e_p = CatmullClarkPrecomputedCoefficients::table.cos_2PI_div_n(face_valence_p1); + const float c_e_m = CatmullClarkPrecomputedCoefficients::table.cos_2PI_div_n(face_valence_p3); + + + f_p_vtx = 1.0f / d * (c_e_p * p_vtx + (d - 2.0f*c - c_e_p) * e0_p_vtx + 2.0f*c* e1_m_vtx + r_e_p); + f_m_vtx = 1.0f / d * (c_e_m * p_vtx + (d - 2.0f*c - c_e_m) * e0_m_vtx + 2.0f*c* e3_p_vtx + r_e_m); + f_p_vtx = 1.0f / d * (c_e_p * p_vtx + (d - 2.0f*c - c_e_p) * e0_p_vtx + 2.0f*c* e1_m_vtx + r_e_p); + f_m_vtx = 1.0f / d * (c_e_m * p_vtx + (d - 2.0f*c - c_e_m) * e0_m_vtx + 2.0f*c* e3_p_vtx + r_e_m); + } + + __noinline void init(const GeneralCatmullClarkPatch& patch) + { + assert(patch.size() == 4); +#if 0 + CatmullClarkPatch qpatch; patch.init(qpatch); + init(qpatch); +#else + const float face_valence_p0 = patch.ring[0].face_valence; + const float face_valence_p1 = patch.ring[1].face_valence; + const float face_valence_p2 = patch.ring[2].face_valence; + const float face_valence_p3 = patch.ring[3].face_valence; + + Vertex p0_r_p, p0_r_m; + patch.ring[0].computeGregoryPatchEdgePoints( p0(), e0_p(), e0_m(), p0_r_p, p0_r_m ); + + Vertex p1_r_p, p1_r_m; + patch.ring[1].computeGregoryPatchEdgePoints( p1(), e1_p(), e1_m(), p1_r_p, p1_r_m ); + + Vertex p2_r_p, p2_r_m; + patch.ring[2].computeGregoryPatchEdgePoints( p2(), e2_p(), e2_m(), p2_r_p, p2_r_m ); + + Vertex p3_r_p, p3_r_m; + patch.ring[3].computeGregoryPatchEdgePoints( p3(), e3_p(), e3_m(), p3_r_p, p3_r_m ); + + computeGregoryPatchFacePoints(face_valence_p0, p0_r_p, p0_r_m, p0(), e0_p(), e1_m(), face_valence_p1, e0_m(), e3_p(), face_valence_p3, f0_p(), f0_m() ); + computeGregoryPatchFacePoints(face_valence_p1, p1_r_p, p1_r_m, p1(), e1_p(), e2_m(), face_valence_p2, e1_m(), e0_p(), face_valence_p0, f1_p(), f1_m() ); + computeGregoryPatchFacePoints(face_valence_p2, p2_r_p, p2_r_m, p2(), e2_p(), e3_m(), face_valence_p3, e2_m(), e1_p(), face_valence_p1, f2_p(), f2_m() ); + computeGregoryPatchFacePoints(face_valence_p3, p3_r_p, p3_r_m, p3(), e3_p(), e0_m(), face_valence_p0, e3_m(), e2_p(), face_valence_p3, f3_p(), f3_m() ); + +#endif + } + + + __forceinline void convert_to_bezier() + { + f0_p() = (f0_p() + f0_m()) * 0.5f; + f1_p() = (f1_p() + f1_m()) * 0.5f; + f2_p() = (f2_p() + f2_m()) * 0.5f; + f3_p() = (f3_p() + f3_m()) * 0.5f; + f0_m() = Vertex( zero ); + f1_m() = Vertex( zero ); + f2_m() = Vertex( zero ); + f3_m() = Vertex( zero ); + } + + static __forceinline void computeInnerVertices(const Vertex matrix[4][4], const Vertex f_m[2][2], const float uu, const float vv, + Vertex_t& matrix_11, Vertex_t& matrix_12, Vertex_t& matrix_22, Vertex_t& matrix_21) + { + if (unlikely(uu == 0.0f || uu == 1.0f || vv == 0.0f || vv == 1.0f)) + { + matrix_11 = matrix[1][1]; + matrix_12 = matrix[1][2]; + matrix_22 = matrix[2][2]; + matrix_21 = matrix[2][1]; + } + else + { + const Vertex_t f0_p = matrix[1][1]; + const Vertex_t f1_p = matrix[1][2]; + const Vertex_t f2_p = matrix[2][2]; + const Vertex_t f3_p = matrix[2][1]; + + const Vertex_t f0_m = f_m[0][0]; + const Vertex_t f1_m = f_m[0][1]; + const Vertex_t f2_m = f_m[1][1]; + const Vertex_t 
f3_m = f_m[1][0]; + + matrix_11 = ( uu * f0_p + vv * f0_m)*rcp(uu+vv); + matrix_12 = ((1.0f-uu) * f1_m + vv * f1_p)*rcp(1.0f-uu+vv); + matrix_22 = ((1.0f-uu) * f2_p + (1.0f-vv) * f2_m)*rcp(2.0f-uu-vv); + matrix_21 = ( uu * f3_m + (1.0f-vv) * f3_p)*rcp(1.0f+uu-vv); + } + } + + template + static __forceinline void computeInnerVertices(const Vertex v[4][4], const Vertex f[2][2], + size_t i, const vfloat& uu, const vfloat& vv, vfloat& matrix_11, vfloat& matrix_12, vfloat& matrix_22, vfloat& matrix_21) + { + const auto m_border = (uu == 0.0f) | (uu == 1.0f) | (vv == 0.0f) | (vv == 1.0f); + + const vfloat f0_p = v[1][1][i]; + const vfloat f1_p = v[1][2][i]; + const vfloat f2_p = v[2][2][i]; + const vfloat f3_p = v[2][1][i]; + + const vfloat f0_m = f[0][0][i]; + const vfloat f1_m = f[0][1][i]; + const vfloat f2_m = f[1][1][i]; + const vfloat f3_m = f[1][0][i]; + + const vfloat one_minus_uu = vfloat(1.0f) - uu; + const vfloat one_minus_vv = vfloat(1.0f) - vv; + + const vfloat f0_i = ( uu * f0_p + vv * f0_m) * rcp(uu+vv); + const vfloat f1_i = (one_minus_uu * f1_m + vv * f1_p) * rcp(one_minus_uu+vv); + const vfloat f2_i = (one_minus_uu * f2_p + one_minus_vv * f2_m) * rcp(one_minus_uu+one_minus_vv); + const vfloat f3_i = ( uu * f3_m + one_minus_vv * f3_p) * rcp(uu+one_minus_vv); + + matrix_11 = select(m_border,f0_p,f0_i); + matrix_12 = select(m_border,f1_p,f1_i); + matrix_22 = select(m_border,f2_p,f2_i); + matrix_21 = select(m_border,f3_p,f3_i); + } + + static __forceinline Vertex eval(const Vertex matrix[4][4], const Vertex f[2][2], const float& uu, const float& vv) + { + Vertex_t v_11, v_12, v_22, v_21; + computeInnerVertices(matrix,f,uu,vv,v_11, v_12, v_22, v_21); + + const Vec4 Bu = BezierBasis::eval(uu); + const Vec4 Bv = BezierBasis::eval(vv); + + return madd(Bv.x,madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))), + madd(Bv.y,madd(Bu.x,matrix[1][0],madd(Bu.y,v_11 ,madd(Bu.z,v_12 ,Bu.w * matrix[1][3]))), + madd(Bv.z,madd(Bu.x,matrix[2][0],madd(Bu.y,v_21 ,madd(Bu.z,v_22 ,Bu.w * matrix[2][3]))), + Bv.w*madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3])))))); + } + + static __forceinline Vertex eval_du(const Vertex matrix[4][4], const Vertex f[2][2], const float uu, const float vv) // approximative derivative + { + Vertex_t v_11, v_12, v_22, v_21; + computeInnerVertices(matrix,f,uu,vv,v_11, v_12, v_22, v_21); + + const Vec4 Bu = BezierBasis::derivative(uu); + const Vec4 Bv = BezierBasis::eval(vv); + + return madd(Bv.x,madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))), + madd(Bv.y,madd(Bu.x,matrix[1][0],madd(Bu.y,v_11 ,madd(Bu.z,v_12 ,Bu.w * matrix[1][3]))), + madd(Bv.z,madd(Bu.x,matrix[2][0],madd(Bu.y,v_21 ,madd(Bu.z,v_22 ,Bu.w * matrix[2][3]))), + Bv.w*madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3])))))); + } + + static __forceinline Vertex eval_dv(const Vertex matrix[4][4], const Vertex f[2][2], const float uu, const float vv) // approximative derivative + { + Vertex_t v_11, v_12, v_22, v_21; + computeInnerVertices(matrix,f,uu,vv,v_11, v_12, v_22, v_21); + + const Vec4 Bu = BezierBasis::eval(uu); + const Vec4 Bv = BezierBasis::derivative(vv); + + return madd(Bv.x,madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))), + madd(Bv.y,madd(Bu.x,matrix[1][0],madd(Bu.y,v_11 ,madd(Bu.z,v_12 ,Bu.w * matrix[1][3]))), + madd(Bv.z,madd(Bu.x,matrix[2][0],madd(Bu.y,v_21 ,madd(Bu.z,v_22 ,Bu.w * matrix[2][3]))), 
+ Bv.w*madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3])))))); + } + + static __forceinline Vertex eval_dudu(const Vertex matrix[4][4], const Vertex f[2][2], const float uu, const float vv) // approximative derivative + { + Vertex_t v_11, v_12, v_22, v_21; + computeInnerVertices(matrix,f,uu,vv,v_11, v_12, v_22, v_21); + + const Vec4 Bu = BezierBasis::derivative2(uu); + const Vec4 Bv = BezierBasis::eval(vv); + + return madd(Bv.x,madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))), + madd(Bv.y,madd(Bu.x,matrix[1][0],madd(Bu.y,v_11 ,madd(Bu.z,v_12 ,Bu.w * matrix[1][3]))), + madd(Bv.z,madd(Bu.x,matrix[2][0],madd(Bu.y,v_21 ,madd(Bu.z,v_22 ,Bu.w * matrix[2][3]))), + Bv.w*madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3])))))); + } + + static __forceinline Vertex eval_dvdv(const Vertex matrix[4][4], const Vertex f[2][2], const float uu, const float vv) // approximative derivative + { + Vertex_t v_11, v_12, v_22, v_21; + computeInnerVertices(matrix,f,uu,vv,v_11, v_12, v_22, v_21); + + const Vec4 Bu = BezierBasis::eval(uu); + const Vec4 Bv = BezierBasis::derivative2(vv); + + return madd(Bv.x,madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))), + madd(Bv.y,madd(Bu.x,matrix[1][0],madd(Bu.y,v_11 ,madd(Bu.z,v_12 ,Bu.w * matrix[1][3]))), + madd(Bv.z,madd(Bu.x,matrix[2][0],madd(Bu.y,v_21 ,madd(Bu.z,v_22 ,Bu.w * matrix[2][3]))), + Bv.w*madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3])))))); + } + + static __forceinline Vertex eval_dudv(const Vertex matrix[4][4], const Vertex f[2][2], const float uu, const float vv) // approximative derivative + { + Vertex_t v_11, v_12, v_22, v_21; + computeInnerVertices(matrix,f,uu,vv,v_11, v_12, v_22, v_21); + + const Vec4 Bu = BezierBasis::derivative(uu); + const Vec4 Bv = BezierBasis::derivative(vv); + + return madd(Bv.x,madd(Bu.x,matrix[0][0],madd(Bu.y,matrix[0][1],madd(Bu.z,matrix[0][2],Bu.w * matrix[0][3]))), + madd(Bv.y,madd(Bu.x,matrix[1][0],madd(Bu.y,v_11 ,madd(Bu.z,v_12 ,Bu.w * matrix[1][3]))), + madd(Bv.z,madd(Bu.x,matrix[2][0],madd(Bu.y,v_21 ,madd(Bu.z,v_22 ,Bu.w * matrix[2][3]))), + Bv.w*madd(Bu.x,matrix[3][0],madd(Bu.y,matrix[3][1],madd(Bu.z,matrix[3][2],Bu.w * matrix[3][3])))))); + } + + __forceinline Vertex eval(const float uu, const float vv) const { + return eval(v,f,uu,vv); + } + + __forceinline Vertex eval_du( const float uu, const float vv) const { + return eval_du(v,f,uu,vv); + } + + __forceinline Vertex eval_dv( const float uu, const float vv) const { + return eval_dv(v,f,uu,vv); + } + + __forceinline Vertex eval_dudu( const float uu, const float vv) const { + return eval_dudu(v,f,uu,vv); + } + + __forceinline Vertex eval_dvdv( const float uu, const float vv) const { + return eval_dvdv(v,f,uu,vv); + } + + __forceinline Vertex eval_dudv( const float uu, const float vv) const { + return eval_dudv(v,f,uu,vv); + } + + static __forceinline Vertex normal(const Vertex matrix[4][4], const Vertex f_m[2][2], const float uu, const float vv) // FIXME: why not using basis functions + { + /* interpolate inner vertices */ + Vertex_t matrix_11, matrix_12, matrix_22, matrix_21; + computeInnerVertices(matrix,f_m,uu,vv,matrix_11, matrix_12, matrix_22, matrix_21); + + /* tangentU */ + const Vertex_t col0 = deCasteljau(vv, (Vertex_t)matrix[0][0], (Vertex_t)matrix[1][0], (Vertex_t)matrix[2][0], (Vertex_t)matrix[3][0]); + const Vertex_t col1 = deCasteljau(vv, 
(Vertex_t)matrix[0][1], (Vertex_t)matrix_11 , (Vertex_t)matrix_21 , (Vertex_t)matrix[3][1]); + const Vertex_t col2 = deCasteljau(vv, (Vertex_t)matrix[0][2], (Vertex_t)matrix_12 , (Vertex_t)matrix_22 , (Vertex_t)matrix[3][2]); + const Vertex_t col3 = deCasteljau(vv, (Vertex_t)matrix[0][3], (Vertex_t)matrix[1][3], (Vertex_t)matrix[2][3], (Vertex_t)matrix[3][3]); + + const Vertex_t tangentU = deCasteljau_tangent(uu, col0, col1, col2, col3); + + /* tangentV */ + const Vertex_t row0 = deCasteljau(uu, (Vertex_t)matrix[0][0], (Vertex_t)matrix[0][1], (Vertex_t)matrix[0][2], (Vertex_t)matrix[0][3]); + const Vertex_t row1 = deCasteljau(uu, (Vertex_t)matrix[1][0], (Vertex_t)matrix_11 , (Vertex_t)matrix_12 , (Vertex_t)matrix[1][3]); + const Vertex_t row2 = deCasteljau(uu, (Vertex_t)matrix[2][0], (Vertex_t)matrix_21 , (Vertex_t)matrix_22 , (Vertex_t)matrix[2][3]); + const Vertex_t row3 = deCasteljau(uu, (Vertex_t)matrix[3][0], (Vertex_t)matrix[3][1], (Vertex_t)matrix[3][2], (Vertex_t)matrix[3][3]); + + const Vertex_t tangentV = deCasteljau_tangent(vv, row0, row1, row2, row3); + + /* normal = tangentU x tangentV */ + const Vertex_t n = cross(tangentU,tangentV); + + return n; + } + + __forceinline Vertex normal( const float uu, const float vv) const { + return normal(v,f,uu,vv); + } + + __forceinline void eval(const float u, const float v, + Vertex* P, Vertex* dPdu, Vertex* dPdv, + Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv, + const float dscale = 1.0f) const + { + if (P) { + *P = eval(u,v); + } + if (dPdu) { + assert(dPdu); *dPdu = eval_du(u,v)*dscale; + assert(dPdv); *dPdv = eval_dv(u,v)*dscale; + } + if (ddPdudu) { + assert(ddPdudu); *ddPdudu = eval_dudu(u,v)*sqr(dscale); + assert(ddPdvdv); *ddPdvdv = eval_dvdv(u,v)*sqr(dscale); + assert(ddPdudv); *ddPdudv = eval_dudv(u,v)*sqr(dscale); + } + } + + template + static __forceinline vfloat eval(const Vertex v[4][4], const Vertex f[2][2], + const size_t i, const vfloat& uu, const vfloat& vv, const Vec4& u_n, const Vec4& v_n, + vfloat& matrix_11, vfloat& matrix_12, vfloat& matrix_22, vfloat& matrix_21) + { + const vfloat curve0_x = madd(v_n[0],vfloat(v[0][0][i]),madd(v_n[1],vfloat(v[1][0][i]),madd(v_n[2],vfloat(v[2][0][i]),v_n[3] * vfloat(v[3][0][i])))); + const vfloat curve1_x = madd(v_n[0],vfloat(v[0][1][i]),madd(v_n[1],vfloat(matrix_11 ),madd(v_n[2],vfloat(matrix_21 ),v_n[3] * vfloat(v[3][1][i])))); + const vfloat curve2_x = madd(v_n[0],vfloat(v[0][2][i]),madd(v_n[1],vfloat(matrix_12 ),madd(v_n[2],vfloat(matrix_22 ),v_n[3] * vfloat(v[3][2][i])))); + const vfloat curve3_x = madd(v_n[0],vfloat(v[0][3][i]),madd(v_n[1],vfloat(v[1][3][i]),madd(v_n[2],vfloat(v[2][3][i]),v_n[3] * vfloat(v[3][3][i])))); + return madd(u_n[0],curve0_x,madd(u_n[1],curve1_x,madd(u_n[2],curve2_x,u_n[3] * curve3_x))); + } + + template + static __forceinline void eval(const Vertex v[4][4], const Vertex f[2][2], + const vbool& valid, const vfloat& uu, const vfloat& vv, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, + const float dscale, const size_t dstride, const size_t N) + { + if (P) { + const Vec4 u_n = BezierBasis::eval(uu); + const Vec4 v_n = BezierBasis::eval(vv); + for (size_t i=0; i u_n = BezierBasis::derivative(uu); + const Vec4 v_n = BezierBasis::eval(vv); + for (size_t i=0; i u_n = BezierBasis::eval(uu); + const Vec4 v_n = BezierBasis::derivative(vv); + for (size_t i=0; i u_n = BezierBasis::derivative2(uu); + const Vec4 v_n = BezierBasis::eval(vv); + for (size_t i=0; i u_n = BezierBasis::eval(uu); + const Vec4 v_n = 
BezierBasis::derivative2(vv); + for (size_t i=0; i u_n = BezierBasis::derivative(uu); + const Vec4 v_n = BezierBasis::derivative(vv); + for (size_t i=0; i + __forceinline void eval(const vbool& valid, const vfloat& uu, const vfloat& vv, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, + const float dscale, const size_t dstride, const size_t N) const { + eval(v,f,valid,uu,vv,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + } + + template + static __forceinline Vec3 eval_t(const Vertex matrix[4][4], const Vec3 f[2][2], const T& uu, const T& vv) + { + typedef typename T::Bool M; + const M m_border = (uu == 0.0f) | (uu == 1.0f) | (vv == 0.0f) | (vv == 1.0f); + + const Vec3 f0_p = Vec3(matrix[1][1].x,matrix[1][1].y,matrix[1][1].z); + const Vec3 f1_p = Vec3(matrix[1][2].x,matrix[1][2].y,matrix[1][2].z); + const Vec3 f2_p = Vec3(matrix[2][2].x,matrix[2][2].y,matrix[2][2].z); + const Vec3 f3_p = Vec3(matrix[2][1].x,matrix[2][1].y,matrix[2][1].z); + + const Vec3 f0_m = f[0][0]; + const Vec3 f1_m = f[0][1]; + const Vec3 f2_m = f[1][1]; + const Vec3 f3_m = f[1][0]; + + const T one_minus_uu = T(1.0f) - uu; + const T one_minus_vv = T(1.0f) - vv; + + const Vec3 f0_i = ( uu * f0_p + vv * f0_m) * rcp(uu+vv); + const Vec3 f1_i = (one_minus_uu * f1_m + vv * f1_p) * rcp(one_minus_uu+vv); + const Vec3 f2_i = (one_minus_uu * f2_p + one_minus_vv * f2_m) * rcp(one_minus_uu+one_minus_vv); + const Vec3 f3_i = ( uu * f3_m + one_minus_vv * f3_p) * rcp(uu+one_minus_vv); + + const Vec3 F0( select(m_border,f0_p.x,f0_i.x), select(m_border,f0_p.y,f0_i.y), select(m_border,f0_p.z,f0_i.z) ); + const Vec3 F1( select(m_border,f1_p.x,f1_i.x), select(m_border,f1_p.y,f1_i.y), select(m_border,f1_p.z,f1_i.z) ); + const Vec3 F2( select(m_border,f2_p.x,f2_i.x), select(m_border,f2_p.y,f2_i.y), select(m_border,f2_p.z,f2_i.z) ); + const Vec3 F3( select(m_border,f3_p.x,f3_i.x), select(m_border,f3_p.y,f3_i.y), select(m_border,f3_p.z,f3_i.z) ); + + const T B0_u = one_minus_uu * one_minus_uu * one_minus_uu; + const T B0_v = one_minus_vv * one_minus_vv * one_minus_vv; + const T B1_u = 3.0f * (one_minus_uu * uu * one_minus_uu); + const T B1_v = 3.0f * (one_minus_vv * vv * one_minus_vv); + const T B2_u = 3.0f * (uu * one_minus_uu * uu); + const T B2_v = 3.0f * (vv * one_minus_vv * vv); + const T B3_u = uu * uu * uu; + const T B3_v = vv * vv * vv; + + const T x = madd(B0_v,madd(B0_u,matrix[0][0].x,madd(B1_u,matrix[0][1].x,madd(B2_u,matrix[0][2].x,B3_u * matrix[0][3].x))), + madd(B1_v,madd(B0_u,matrix[1][0].x,madd(B1_u,F0.x ,madd(B2_u,F1.x ,B3_u * matrix[1][3].x))), + madd(B2_v,madd(B0_u,matrix[2][0].x,madd(B1_u,F3.x ,madd(B2_u,F2.x ,B3_u * matrix[2][3].x))), + B3_v*madd(B0_u,matrix[3][0].x,madd(B1_u,matrix[3][1].x,madd(B2_u,matrix[3][2].x,B3_u * matrix[3][3].x)))))); + + const T y = madd(B0_v,madd(B0_u,matrix[0][0].y,madd(B1_u,matrix[0][1].y,madd(B2_u,matrix[0][2].y,B3_u * matrix[0][3].y))), + madd(B1_v,madd(B0_u,matrix[1][0].y,madd(B1_u,F0.y ,madd(B2_u,F1.y ,B3_u * matrix[1][3].y))), + madd(B2_v,madd(B0_u,matrix[2][0].y,madd(B1_u,F3.y ,madd(B2_u,F2.y ,B3_u * matrix[2][3].y))), + B3_v*madd(B0_u,matrix[3][0].y,madd(B1_u,matrix[3][1].y,madd(B2_u,matrix[3][2].y,B3_u * matrix[3][3].y)))))); + + const T z = madd(B0_v,madd(B0_u,matrix[0][0].z,madd(B1_u,matrix[0][1].z,madd(B2_u,matrix[0][2].z,B3_u * matrix[0][3].z))), + madd(B1_v,madd(B0_u,matrix[1][0].z,madd(B1_u,F0.z ,madd(B2_u,F1.z ,B3_u * matrix[1][3].z))), + madd(B2_v,madd(B0_u,matrix[2][0].z,madd(B1_u,F3.z ,madd(B2_u,F2.z ,B3_u * matrix[2][3].z))), + 
B3_v*madd(B0_u,matrix[3][0].z,madd(B1_u,matrix[3][1].z,madd(B2_u,matrix[3][2].z,B3_u * matrix[3][3].z)))))); + + return Vec3(x,y,z); + } + + template + __forceinline Vec3 eval(const T& uu, const T& vv) const + { + Vec3 ff[2][2]; + ff[0][0] = Vec3(f[0][0]); + ff[0][1] = Vec3(f[0][1]); + ff[1][1] = Vec3(f[1][1]); + ff[1][0] = Vec3(f[1][0]); + return eval_t(v,ff,uu,vv); + } + + template + static __forceinline Vec3 normal_t(const Vertex matrix[4][4], const Vec3 f[2][2], const T& uu, const T& vv) + { + typedef typename T::Bool M; + + const Vec3 f0_p = Vec3(matrix[1][1].x,matrix[1][1].y,matrix[1][1].z); + const Vec3 f1_p = Vec3(matrix[1][2].x,matrix[1][2].y,matrix[1][2].z); + const Vec3 f2_p = Vec3(matrix[2][2].x,matrix[2][2].y,matrix[2][2].z); + const Vec3 f3_p = Vec3(matrix[2][1].x,matrix[2][1].y,matrix[2][1].z); + + const Vec3 f0_m = f[0][0]; + const Vec3 f1_m = f[0][1]; + const Vec3 f2_m = f[1][1]; + const Vec3 f3_m = f[1][0]; + + const T one_minus_uu = T(1.0f) - uu; + const T one_minus_vv = T(1.0f) - vv; + + const Vec3 f0_i = ( uu * f0_p + vv * f0_m) * rcp(uu+vv); + const Vec3 f1_i = (one_minus_uu * f1_m + vv * f1_p) * rcp(one_minus_uu+vv); + const Vec3 f2_i = (one_minus_uu * f2_p + one_minus_vv * f2_m) * rcp(one_minus_uu+one_minus_vv); + const Vec3 f3_i = ( uu * f3_m + one_minus_vv * f3_p) * rcp(uu+one_minus_vv); + +#if 1 + const M m_corner0 = (uu == 0.0f) & (vv == 0.0f); + const M m_corner1 = (uu == 1.0f) & (vv == 0.0f); + const M m_corner2 = (uu == 1.0f) & (vv == 1.0f); + const M m_corner3 = (uu == 0.0f) & (vv == 1.0f); + const Vec3 matrix_11( select(m_corner0,f0_p.x,f0_i.x), select(m_corner0,f0_p.y,f0_i.y), select(m_corner0,f0_p.z,f0_i.z) ); + const Vec3 matrix_12( select(m_corner1,f1_p.x,f1_i.x), select(m_corner1,f1_p.y,f1_i.y), select(m_corner1,f1_p.z,f1_i.z) ); + const Vec3 matrix_22( select(m_corner2,f2_p.x,f2_i.x), select(m_corner2,f2_p.y,f2_i.y), select(m_corner2,f2_p.z,f2_i.z) ); + const Vec3 matrix_21( select(m_corner3,f3_p.x,f3_i.x), select(m_corner3,f3_p.y,f3_i.y), select(m_corner3,f3_p.z,f3_i.z) ); +#else + const M m_border = (uu == 0.0f) | (uu == 1.0f) | (vv == 0.0f) | (vv == 1.0f); + const Vec3 matrix_11( select(m_border,f0_p.x,f0_i.x), select(m_border,f0_p.y,f0_i.y), select(m_border,f0_p.z,f0_i.z) ); + const Vec3 matrix_12( select(m_border,f1_p.x,f1_i.x), select(m_border,f1_p.y,f1_i.y), select(m_border,f1_p.z,f1_i.z) ); + const Vec3 matrix_22( select(m_border,f2_p.x,f2_i.x), select(m_border,f2_p.y,f2_i.y), select(m_border,f2_p.z,f2_i.z) ); + const Vec3 matrix_21( select(m_border,f3_p.x,f3_i.x), select(m_border,f3_p.y,f3_i.y), select(m_border,f3_p.z,f3_i.z) ); +#endif + + const Vec3 matrix_00 = Vec3(matrix[0][0].x,matrix[0][0].y,matrix[0][0].z); + const Vec3 matrix_10 = Vec3(matrix[1][0].x,matrix[1][0].y,matrix[1][0].z); + const Vec3 matrix_20 = Vec3(matrix[2][0].x,matrix[2][0].y,matrix[2][0].z); + const Vec3 matrix_30 = Vec3(matrix[3][0].x,matrix[3][0].y,matrix[3][0].z); + + const Vec3 matrix_01 = Vec3(matrix[0][1].x,matrix[0][1].y,matrix[0][1].z); + const Vec3 matrix_02 = Vec3(matrix[0][2].x,matrix[0][2].y,matrix[0][2].z); + const Vec3 matrix_03 = Vec3(matrix[0][3].x,matrix[0][3].y,matrix[0][3].z); + + const Vec3 matrix_31 = Vec3(matrix[3][1].x,matrix[3][1].y,matrix[3][1].z); + const Vec3 matrix_32 = Vec3(matrix[3][2].x,matrix[3][2].y,matrix[3][2].z); + const Vec3 matrix_33 = Vec3(matrix[3][3].x,matrix[3][3].y,matrix[3][3].z); + + const Vec3 matrix_13 = Vec3(matrix[1][3].x,matrix[1][3].y,matrix[1][3].z); + const Vec3 matrix_23 = 
Vec3(matrix[2][3].x,matrix[2][3].y,matrix[2][3].z); + + /* tangentU */ + const Vec3 col0 = deCasteljau(vv, matrix_00, matrix_10, matrix_20, matrix_30); + const Vec3 col1 = deCasteljau(vv, matrix_01, matrix_11, matrix_21, matrix_31); + const Vec3 col2 = deCasteljau(vv, matrix_02, matrix_12, matrix_22, matrix_32); + const Vec3 col3 = deCasteljau(vv, matrix_03, matrix_13, matrix_23, matrix_33); + + const Vec3 tangentU = deCasteljau_tangent(uu, col0, col1, col2, col3); + + /* tangentV */ + const Vec3 row0 = deCasteljau(uu, matrix_00, matrix_01, matrix_02, matrix_03); + const Vec3 row1 = deCasteljau(uu, matrix_10, matrix_11, matrix_12, matrix_13); + const Vec3 row2 = deCasteljau(uu, matrix_20, matrix_21, matrix_22, matrix_23); + const Vec3 row3 = deCasteljau(uu, matrix_30, matrix_31, matrix_32, matrix_33); + + const Vec3 tangentV = deCasteljau_tangent(vv, row0, row1, row2, row3); + + /* normal = tangentU x tangentV */ + const Vec3 n = cross(tangentU,tangentV); + return n; + } + + template + __forceinline Vec3 normal(const T& uu, const T& vv) const + { + Vec3 ff[2][2]; + ff[0][0] = Vec3(f[0][0]); + ff[0][1] = Vec3(f[0][1]); + ff[1][1] = Vec3(f[1][1]); + ff[1][0] = Vec3(f[1][0]); + return normal_t(v,ff,uu,vv); + } + + __forceinline BBox bounds() const + { + const Vertex *const cv = &v[0][0]; + BBox bounds (cv[0]); + for (size_t i=1; i<16; i++) + bounds.extend( cv[i] ); + bounds.extend(f[0][0]); + bounds.extend(f[1][0]); + bounds.extend(f[1][1]); + bounds.extend(f[1][1]); + return bounds; + } + + friend embree_ostream operator<<(embree_ostream o, const GregoryPatchT& p) + { + for (size_t y=0; y<4; y++) + for (size_t x=0; x<4; x++) + o << "v[" << y << "][" << x << "] " << p.v[y][x] << embree_endl; + + for (size_t y=0; y<2; y++) + for (size_t x=0; x<2; x++) + o << "f[" << y << "][" << x << "] " << p.f[y][x] << embree_endl; + return o; + } + }; + + typedef GregoryPatchT GregoryPatch3fa; + + template + __forceinline BezierPatchT::BezierPatchT (const HalfEdge* edge, const char* vertices, size_t stride) + { + CatmullClarkPatchT patch(edge,vertices,stride); + GregoryPatchT gpatch(patch); + gpatch.convert_to_bezier(); + for (size_t y=0; y<4; y++) + for (size_t x=0; x<4; x++) + matrix[y][x] = (Vertex_t)gpatch.v[y][x]; + } + + template + __forceinline BezierPatchT::BezierPatchT(const CatmullClarkPatchT& patch) + { + GregoryPatchT gpatch(patch); + gpatch.convert_to_bezier(); + for (size_t y=0; y<4; y++) + for (size_t x=0; x<4; x++) + matrix[y][x] = (Vertex_t)gpatch.v[y][x]; + } + + template + __forceinline BezierPatchT::BezierPatchT(const CatmullClarkPatchT& patch, + const BezierCurveT* border0, + const BezierCurveT* border1, + const BezierCurveT* border2, + const BezierCurveT* border3) + { + GregoryPatchT gpatch(patch,border0,border1,border2,border3); + gpatch.convert_to_bezier(); + for (size_t y=0; y<4; y++) + for (size_t x=0; x<4; x++) + matrix[y][x] = (Vertex_t)gpatch.v[y][x]; + } +} diff --git a/thirdparty/embree/kernels/subdiv/gregory_patch_dense.h b/thirdparty/embree/kernels/subdiv/gregory_patch_dense.h new file mode 100644 index 000000000000..85effd02cfb3 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/gregory_patch_dense.h @@ -0,0 +1,113 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "gregory_patch.h" + +namespace embree +{ + class __aligned(64) DenseGregoryPatch3fa + { + typedef Vec3fa Vec3fa_4x4[4][4]; + public: + + __forceinline DenseGregoryPatch3fa (const GregoryPatch3fa& patch) + { + for (size_t y=0; y<4; y++) + for (size_t 
x=0; x<4; x++) + matrix[y][x] = Vec3ff(patch.v[y][x], 0.0f); + + matrix[0][0].w = patch.f[0][0].x; + matrix[0][1].w = patch.f[0][0].y; + matrix[0][2].w = patch.f[0][0].z; + matrix[0][3].w = 0.0f; + + matrix[1][0].w = patch.f[0][1].x; + matrix[1][1].w = patch.f[0][1].y; + matrix[1][2].w = patch.f[0][1].z; + matrix[1][3].w = 0.0f; + + matrix[2][0].w = patch.f[1][1].x; + matrix[2][1].w = patch.f[1][1].y; + matrix[2][2].w = patch.f[1][1].z; + matrix[2][3].w = 0.0f; + + matrix[3][0].w = patch.f[1][0].x; + matrix[3][1].w = patch.f[1][0].y; + matrix[3][2].w = patch.f[1][0].z; + matrix[3][3].w = 0.0f; + } + + __forceinline void extract_f_m(Vec3fa f_m[2][2]) const + { + f_m[0][0] = Vec3fa( matrix[0][0].w, matrix[0][1].w, matrix[0][2].w ); + f_m[0][1] = Vec3fa( matrix[1][0].w, matrix[1][1].w, matrix[1][2].w ); + f_m[1][1] = Vec3fa( matrix[2][0].w, matrix[2][1].w, matrix[2][2].w ); + f_m[1][0] = Vec3fa( matrix[3][0].w, matrix[3][1].w, matrix[3][2].w ); + } + + __forceinline Vec3fa eval(const float uu, const float vv) const + { + __aligned(64) Vec3fa f_m[2][2]; extract_f_m(f_m); + return GregoryPatch3fa::eval(*(Vec3fa_4x4*)&matrix,f_m,uu,vv); + } + + __forceinline Vec3fa normal(const float uu, const float vv) const + { + __aligned(64) Vec3fa f_m[2][2]; extract_f_m(f_m); + return GregoryPatch3fa::normal(*(Vec3fa_4x4*)&matrix,f_m,uu,vv); + } + + template + __forceinline Vec3 eval(const T &uu, const T &vv) const + { + Vec3 f_m[2][2]; + f_m[0][0] = Vec3( matrix[0][0].w, matrix[0][1].w, matrix[0][2].w ); + f_m[0][1] = Vec3( matrix[1][0].w, matrix[1][1].w, matrix[1][2].w ); + f_m[1][1] = Vec3( matrix[2][0].w, matrix[2][1].w, matrix[2][2].w ); + f_m[1][0] = Vec3( matrix[3][0].w, matrix[3][1].w, matrix[3][2].w ); + return GregoryPatch3fa::eval_t(*(Vec3fa_4x4*)&matrix,f_m,uu,vv); + } + + template + __forceinline Vec3 normal(const T &uu, const T &vv) const + { + Vec3 f_m[2][2]; + f_m[0][0] = Vec3( matrix[0][0].w, matrix[0][1].w, matrix[0][2].w ); + f_m[0][1] = Vec3( matrix[1][0].w, matrix[1][1].w, matrix[1][2].w ); + f_m[1][1] = Vec3( matrix[2][0].w, matrix[2][1].w, matrix[2][2].w ); + f_m[1][0] = Vec3( matrix[3][0].w, matrix[3][1].w, matrix[3][2].w ); + return GregoryPatch3fa::normal_t(*(Vec3fa_4x4*)&matrix,f_m,uu,vv); + } + + __forceinline void eval(const float u, const float v, + Vec3fa* P, Vec3fa* dPdu, Vec3fa* dPdv, Vec3fa* ddPdudu, Vec3fa* ddPdvdv, Vec3fa* ddPdudv, + const float dscale = 1.0f) const + { + __aligned(64) Vec3fa f_m[2][2]; extract_f_m(f_m); + if (P) { + *P = GregoryPatch3fa::eval(*(Vec3fa_4x4*)&matrix,f_m,u,v); + } + if (dPdu) { + assert(dPdu); *dPdu = GregoryPatch3fa::eval_du(*(Vec3fa_4x4*)&matrix,f_m,u,v)*dscale; + assert(dPdv); *dPdv = GregoryPatch3fa::eval_dv(*(Vec3fa_4x4*)&matrix,f_m,u,v)*dscale; + } + if (ddPdudu) { + assert(ddPdudu); *ddPdudu = GregoryPatch3fa::eval_dudu(*(Vec3fa_4x4*)&matrix,f_m,u,v)*sqr(dscale); + assert(ddPdvdv); *ddPdvdv = GregoryPatch3fa::eval_dvdv(*(Vec3fa_4x4*)&matrix,f_m,u,v)*sqr(dscale); + assert(ddPdudv); *ddPdudv = GregoryPatch3fa::eval_dudv(*(Vec3fa_4x4*)&matrix,f_m,u,v)*sqr(dscale); + } + } + + template + __forceinline void eval(const vbool& valid, const vfloat& uu, const vfloat& vv, float* P, float* dPdu, float* dPdv, const float dscale, const size_t dstride, const size_t N) const + { + __aligned(64) Vec3fa f_m[2][2]; extract_f_m(f_m); + GregoryPatch3fa::eval(matrix,f_m,valid,uu,vv,P,dPdu,dPdv,dscale,dstride,N); + } + + private: + Vec3ff matrix[4][4]; // f_p/m points are stored in 4th component + }; +} diff --git 
a/thirdparty/embree/kernels/subdiv/gridrange.h b/thirdparty/embree/kernels/subdiv/gridrange.h new file mode 100644 index 000000000000..4fd741c8790a --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/gridrange.h @@ -0,0 +1,96 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" + +namespace embree +{ + struct __aligned(16) GridRange + { + unsigned int u_start; + unsigned int u_end; + unsigned int v_start; + unsigned int v_end; + + __forceinline GridRange() {} + + __forceinline GridRange(unsigned int u_start, unsigned int u_end, unsigned int v_start, unsigned int v_end) + : u_start(u_start), u_end(u_end), v_start(v_start), v_end(v_end) {} + + __forceinline unsigned int width() const { + return u_end-u_start+1; + } + + __forceinline unsigned int height() const { + return v_end-v_start+1; + } + + __forceinline bool hasLeafSize() const + { + const unsigned int u_size = u_end-u_start+1; + const unsigned int v_size = v_end-v_start+1; + assert(u_size >= 1); + assert(v_size >= 1); + return u_size <= 3 && v_size <= 3; + } + + static __forceinline unsigned int split(unsigned int start,unsigned int end) + { + const unsigned int center = (start+end)/2; + assert (center > start); + assert (center < end); + return center; + } + + __forceinline void split(GridRange& r0, GridRange& r1) const + { + assert( hasLeafSize() == false ); + const unsigned int u_size = u_end-u_start+1; + const unsigned int v_size = v_end-v_start+1; + r0 = *this; + r1 = *this; + + if (u_size >= v_size) + { + const unsigned int u_mid = split(u_start,u_end); + r0.u_end = u_mid; + r1.u_start = u_mid; + } + else + { + const unsigned int v_mid = split(v_start,v_end); + r0.v_end = v_mid; + r1.v_start = v_mid; + } + } + + __forceinline unsigned int splitIntoSubRanges(GridRange r[4]) const + { + assert( !hasLeafSize() ); + unsigned int children = 0; + GridRange first,second; + split(first,second); + + if (first.hasLeafSize()) { + r[0] = first; + children++; + } + else { + first.split(r[0],r[1]); + children += 2; + } + + if (second.hasLeafSize()) { + r[children] = second; + children++; + } + else { + second.split(r[children+0],r[children+1]); + children += 2; + } + return children; + } + }; +} diff --git a/thirdparty/embree/kernels/subdiv/half_edge.h b/thirdparty/embree/kernels/subdiv/half_edge.h new file mode 100644 index 000000000000..fb350ca71fe3 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/half_edge.h @@ -0,0 +1,371 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_coefficients.h" + +namespace embree +{ + class __aligned(32) HalfEdge + { + friend class SubdivMesh; + public: + + enum PatchType : char { + BILINEAR_PATCH = 0, //!< a bilinear patch + REGULAR_QUAD_PATCH = 1, //!< a regular quad patch can be represented as a B-Spline + IRREGULAR_QUAD_PATCH = 2, //!< an irregular quad patch can be represented as a Gregory patch + COMPLEX_PATCH = 3 //!< these patches need subdivision and cannot be processed by the above fast code paths + }; + + enum VertexType : char { + REGULAR_VERTEX = 0, //!< regular vertex + NON_MANIFOLD_EDGE_VERTEX = 1, //!< vertex of a non-manifold edge + }; + + __forceinline friend PatchType max( const PatchType& ty0, const PatchType& ty1) { + return (PatchType) max((int)ty0,(int)ty1); + } + + struct Edge + { + /*! edge constructor */ + __forceinline Edge(const uint32_t v0, const uint32_t v1) + : v0(v0), v1(v1) {} + + /*! 
create an 64 bit identifier that is unique for the not oriented edge */ + __forceinline operator uint64_t() const + { + uint32_t p0 = v0, p1 = v1; + if (p0next(); } + __forceinline const HalfEdge* rotate() const { return opposite()->next(); } + + __forceinline unsigned int getStartVertexIndex() const { return vtx_index; } + __forceinline unsigned int getEndVertexIndex () const { return next()->vtx_index; } + __forceinline Edge getEdge () const { return Edge(getStartVertexIndex(),getEndVertexIndex()); } + + + /*! tests if the start vertex of the edge is regular */ + __forceinline PatchType vertexType() const + { + const HalfEdge* p = this; + size_t face_valence = 0; + bool hasBorder = false; + + do + { + /* we need subdivision to handle edge creases */ + if (p->hasOpposite() && p->edge_crease_weight > 0.0f) + return COMPLEX_PATCH; + + face_valence++; + + /* test for quad */ + const HalfEdge* pp = p; + pp = pp->next(); if (pp == p) return COMPLEX_PATCH; + pp = pp->next(); if (pp == p) return COMPLEX_PATCH; + pp = pp->next(); if (pp == p) return COMPLEX_PATCH; + pp = pp->next(); if (pp != p) return COMPLEX_PATCH; + + /* continue with next face */ + p = p->prev(); + if (likely(p->hasOpposite())) + p = p->opposite(); + + /* if there is no opposite go the long way to the other side of the border */ + else + { + face_valence++; + hasBorder = true; + p = this; + while (p->hasOpposite()) + p = p->rotate(); + } + } while (p != this); + + /* calculate vertex type */ + if (face_valence == 2 && hasBorder) { + if (vertex_crease_weight == 0.0f ) return REGULAR_QUAD_PATCH; + else if (vertex_crease_weight == float(inf)) return REGULAR_QUAD_PATCH; + else return COMPLEX_PATCH; + } + else if (vertex_crease_weight != 0.0f) return COMPLEX_PATCH; + else if (face_valence == 3 && hasBorder) return REGULAR_QUAD_PATCH; + else if (face_valence == 4 && !hasBorder) return REGULAR_QUAD_PATCH; + else return IRREGULAR_QUAD_PATCH; + } + + /*! tests if this edge is part of a bilinear patch */ + __forceinline bool bilinearVertex() const { + return vertex_crease_weight == float(inf) && edge_crease_weight == float(inf); + } + + /*! calculates the type of the patch */ + __forceinline PatchType patchType() const + { + const HalfEdge* p = this; + PatchType ret = REGULAR_QUAD_PATCH; + bool bilinear = true; + + ret = max(ret,p->vertexType()); + bilinear &= p->bilinearVertex(); + if ((p = p->next()) == this) return COMPLEX_PATCH; + + ret = max(ret,p->vertexType()); + bilinear &= p->bilinearVertex(); + if ((p = p->next()) == this) return COMPLEX_PATCH; + + ret = max(ret,p->vertexType()); + bilinear &= p->bilinearVertex(); + if ((p = p->next()) == this) return COMPLEX_PATCH; + + ret = max(ret,p->vertexType()); + bilinear &= p->bilinearVertex(); + if ((p = p->next()) != this) return COMPLEX_PATCH; + + if (bilinear) return BILINEAR_PATCH; + return ret; + } + + /*! tests if the face is a regular b-spline face */ + __forceinline bool isRegularFace() const { + return patch_type == REGULAR_QUAD_PATCH; + } + + /*! tests if the face can be diced (using bspline or gregory patch) */ + __forceinline bool isGregoryFace() const { + return patch_type == IRREGULAR_QUAD_PATCH || patch_type == REGULAR_QUAD_PATCH; + } + + /*! tests if the base vertex of this half edge is a corner vertex */ + __forceinline bool isCorner() const { + return !hasOpposite() && !prev()->hasOpposite(); + } + + /*! 
tests if the vertex is attached to any border */ + __forceinline bool vertexHasBorder() const + { + const HalfEdge* p = this; + do { + if (!p->hasOpposite()) return true; + p = p->rotate(); + } while (p != this); + return false; + } + + /*! tests if the face this half edge belongs to has some border */ + __forceinline bool faceHasBorder() const + { + const HalfEdge* p = this; + do { + if (p->vertexHasBorder()) return true; + p = p->next(); + } while (p != this); + return false; + } + + /*! calculates conservative bounds of a catmull clark subdivision face */ + __forceinline BBox3fa bounds(const BufferView& vertices) const + { + BBox3fa bounds = this->get1RingBounds(vertices); + for (const HalfEdge* p=this->next(); p!=this; p=p->next()) + bounds.extend(p->get1RingBounds(vertices)); + return bounds; + } + + /*! tests if this is a valid patch */ + __forceinline bool valid(const BufferView& vertices) const + { + size_t N = 1; + if (!this->validRing(vertices)) return false; + for (const HalfEdge* p=this->next(); p!=this; p=p->next(), N++) { + if (!p->validRing(vertices)) return false; + } + return N >= 3 && N <= MAX_PATCH_VALENCE; + } + + /*! counts number of polygon edges */ + __forceinline unsigned int numEdges() const + { + unsigned int N = 1; + for (const HalfEdge* p=this->next(); p!=this; p=p->next(), N++); + return N; + } + + /*! calculates face and edge valence */ + __forceinline void calculateFaceValenceAndEdgeValence(size_t& faceValence, size_t& edgeValence) const + { + faceValence = 0; + edgeValence = 0; + + const HalfEdge* p = this; + do + { + /* calculate bounds of current face */ + unsigned int numEdges = p->numEdges(); + assert(numEdges >= 3); + edgeValence += numEdges-2; + + faceValence++; + p = p->prev(); + + /* continue with next face */ + if (likely(p->hasOpposite())) + p = p->opposite(); + + /* if there is no opposite go the long way to the other side of the border */ + else { + faceValence++; + edgeValence++; + p = this; + while (p->hasOpposite()) + p = p->opposite()->next(); + } + + } while (p != this); + } + + /*! stream output */ + friend __forceinline std::ostream &operator<<(std::ostream &o, const HalfEdge &h) + { + return o << "{ " << + "vertex = " << h.vtx_index << ", " << //" -> " << h.next()->vtx_index << ", " << + "prev = " << h.prev_half_edge_ofs << ", " << + "next = " << h.next_half_edge_ofs << ", " << + "opposite = " << h.opposite_half_edge_ofs << ", " << + "edge_crease = " << h.edge_crease_weight << ", " << + "vertex_crease = " << h.vertex_crease_weight << ", " << + //"edge_level = " << h.edge_level << + " }"; + } + + private: + + /*! calculates the bounds of the face associated with the half-edge */ + __forceinline BBox3fa getFaceBounds(const BufferView& vertices) const + { + BBox3fa b = vertices[getStartVertexIndex()]; + for (const HalfEdge* p = next(); p!=this; p=p->next()) { + b.extend(vertices[p->getStartVertexIndex()]); + } + return b; + } + + /*! 
calculates the bounds of the 1-ring associated with the vertex of the half-edge */ + __forceinline BBox3fa get1RingBounds(const BufferView& vertices) const + { + BBox3fa bounds = empty; + const HalfEdge* p = this; + do + { + /* calculate bounds of current face */ + bounds.extend(p->getFaceBounds(vertices)); + p = p->prev(); + + /* continue with next face */ + if (likely(p->hasOpposite())) + p = p->opposite(); + + /* if there is no opposite go the long way to the other side of the border */ + else { + p = this; + while (p->hasOpposite()) + p = p->opposite()->next(); + } + + } while (p != this); + + return bounds; + } + + /*! tests if this is a valid face */ + __forceinline bool validFace(const BufferView& vertices, size_t& N) const + { + const Vec3fa v = vertices[getStartVertexIndex()]; + if (!isvalid(v)) return false; + size_t n = 1; + for (const HalfEdge* p = next(); p!=this; p=p->next(), n++) { + const Vec3fa v = vertices[p->getStartVertexIndex()]; + if (!isvalid(v)) return false; + } + N += n-2; + return n >= 3 && n <= MAX_PATCH_VALENCE; + } + + /*! tests if this is a valid ring */ + __forceinline bool validRing(const BufferView& vertices) const + { + size_t faceValence = 0; + size_t edgeValence = 0; + + const HalfEdge* p = this; + do + { + /* calculate bounds of current face */ + if (!p->validFace(vertices,edgeValence)) + return false; + + faceValence++; + p = p->prev(); + + /* continue with next face */ + if (likely(p->hasOpposite())) + p = p->opposite(); + + /* if there is no opposite go the long way to the other side of the border */ + else { + faceValence++; + edgeValence++; + p = this; + while (p->hasOpposite()) + p = p->opposite()->next(); + } + + } while (p != this); + + return faceValence <= MAX_RING_FACE_VALENCE && edgeValence <= MAX_RING_EDGE_VALENCE; + } + + private: + unsigned int vtx_index; //!< index of edge start vertex + int next_half_edge_ofs; //!< relative offset to next half edge of face + int prev_half_edge_ofs; //!< relative offset to previous half edge of face + int opposite_half_edge_ofs; //!< relative offset to opposite half edge + + public: + float edge_crease_weight; //!< crease weight attached to edge + float vertex_crease_weight; //!< crease weight attached to start vertex + float edge_level; //!< subdivision factor for edge + PatchType patch_type; //!< stores type of subdiv patch + VertexType vertex_type; //!< stores type of the start vertex + char align[2]; + }; +} diff --git a/thirdparty/embree/kernels/subdiv/hermite_curve.h b/thirdparty/embree/kernels/subdiv/hermite_curve.h new file mode 100644 index 000000000000..9fab79cf0c7f --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/hermite_curve.h @@ -0,0 +1,38 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../common/default.h" +#include "bezier_curve.h" + +namespace embree +{ + template + struct HermiteCurveT : BezierCurveT + { + __forceinline HermiteCurveT() {} + + __forceinline HermiteCurveT(const BezierCurveT& curve) + : BezierCurveT(curve) {} + + __forceinline HermiteCurveT(const Vertex& v0, const Vertex& t0, const Vertex& v1, const Vertex& t1) + : BezierCurveT(v0,madd(1.0f/3.0f,t0,v0),nmadd(1.0f/3.0f,t1,v1),v1) {} + + __forceinline HermiteCurveT xfm_pr(const LinearSpace3fa& space, const Vec3fa& p) const + { + const Vec3ff q0(xfmVector(space,this->v0-p), this->v0.w); + const Vec3ff q1(xfmVector(space,this->v1-p), this->v1.w); + const Vec3ff q2(xfmVector(space,this->v2-p), this->v2.w); + const Vec3ff q3(xfmVector(space,this->v3-p), 
this->v3.w); + return BezierCurveT(q0,q1,q2,q3); + } + }; + + __forceinline HermiteCurveT enlargeRadiusToMinWidth(const IntersectContext* context, const CurveGeometry* geom, const Vec3fa& ray_org, const HermiteCurveT& curve) { + return HermiteCurveT(enlargeRadiusToMinWidth(context,geom,ray_org,BezierCurveT(curve))); + } + + typedef HermiteCurveT HermiteCurve3fa; +} + diff --git a/thirdparty/embree/kernels/subdiv/linear_bezier_patch.h b/thirdparty/embree/kernels/subdiv/linear_bezier_patch.h new file mode 100644 index 000000000000..f4a854af7f32 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/linear_bezier_patch.h @@ -0,0 +1,403 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "bezier_curve.h" + +namespace embree +{ + namespace isa + { + template + struct TensorLinearQuadraticBezierSurface + { + QuadraticBezierCurve L; + QuadraticBezierCurve R; + + __forceinline TensorLinearQuadraticBezierSurface() {} + + __forceinline TensorLinearQuadraticBezierSurface(const TensorLinearQuadraticBezierSurface& curve) + : L(curve.L), R(curve.R) {} + + __forceinline TensorLinearQuadraticBezierSurface& operator= (const TensorLinearQuadraticBezierSurface& other) { + L = other.L; R = other.R; return *this; + } + + __forceinline TensorLinearQuadraticBezierSurface(const QuadraticBezierCurve& L, const QuadraticBezierCurve& R) + : L(L), R(R) {} + + __forceinline BBox bounds() const { + return merge(L.bounds(),R.bounds()); + } + }; + + template<> + struct TensorLinearQuadraticBezierSurface + { + QuadraticBezierCurve LR; + + __forceinline TensorLinearQuadraticBezierSurface() {} + + __forceinline TensorLinearQuadraticBezierSurface(const TensorLinearQuadraticBezierSurface& curve) + : LR(curve.LR) {} + + __forceinline TensorLinearQuadraticBezierSurface& operator= (const TensorLinearQuadraticBezierSurface& other) { + LR = other.LR; return *this; + } + + __forceinline TensorLinearQuadraticBezierSurface(const QuadraticBezierCurve& LR) + : LR(LR) {} + + __forceinline BBox bounds() const + { + const BBox b = LR.bounds(); + const BBox bl(Vec2fa(b.lower),Vec2fa(b.upper)); + const BBox br(Vec2fa(shuffle<2,3,2,3>(b.lower)),Vec2fa(shuffle<2,3,2,3>(b.upper))); + return merge(bl,br); + } + }; + + template + struct TensorLinearCubicBezierSurface + { + CubicBezierCurve L; + CubicBezierCurve R; + + __forceinline TensorLinearCubicBezierSurface() {} + + __forceinline TensorLinearCubicBezierSurface(const TensorLinearCubicBezierSurface& curve) + : L(curve.L), R(curve.R) {} + + __forceinline TensorLinearCubicBezierSurface& operator= (const TensorLinearCubicBezierSurface& other) { + L = other.L; R = other.R; return *this; + } + + __forceinline TensorLinearCubicBezierSurface(const CubicBezierCurve& L, const CubicBezierCurve& R) + : L(L), R(R) {} + + template class SourceCurve> + __forceinline static TensorLinearCubicBezierSurface fromCenterAndNormalCurve(const SourceCurve& center, const SourceCurve& normal) + { + SourceCurve vcurve = center; + SourceCurve ncurve = normal; + + /* here we construct a patch which follows the curve l(t) = + * p(t) +/- r(t)*normalize(cross(n(t),dp(t))) */ + + const Vec3ff p0 = vcurve.eval(0.0f); + const Vec3ff dp0 = vcurve.eval_du(0.0f); + const Vec3ff ddp0 = vcurve.eval_dudu(0.0f); + + const Vec3fa n0 = ncurve.eval(0.0f); + const Vec3fa dn0 = ncurve.eval_du(0.0f); + + const Vec3ff p1 = vcurve.eval(1.0f); + const Vec3ff dp1 = vcurve.eval_du(1.0f); + const Vec3ff ddp1 = vcurve.eval_dudu(1.0f); + + const Vec3fa n1 = ncurve.eval(1.0f); + const Vec3fa 
dn1 = ncurve.eval_du(1.0f); + + const Vec3fa bt0 = cross(n0,dp0); + const Vec3fa dbt0 = cross(dn0,dp0) + cross(n0,ddp0); + + const Vec3fa bt1 = cross(n1,dp1); + const Vec3fa dbt1 = cross(dn1,dp1) + cross(n1,ddp1); + + const Vec3fa k0 = normalize(bt0); + const Vec3fa dk0 = dnormalize(bt0,dbt0); + + const Vec3fa k1 = normalize(bt1); + const Vec3fa dk1 = dnormalize(bt1,dbt1); + + const Vec3fa l0 = p0 - p0.w*k0; + const Vec3fa dl0 = dp0 - (dp0.w*k0 + p0.w*dk0); + + const Vec3fa r0 = p0 + p0.w*k0; + const Vec3fa dr0 = dp0 + (dp0.w*k0 + p0.w*dk0); + + const Vec3fa l1 = p1 - p1.w*k1; + const Vec3fa dl1 = dp1 - (dp1.w*k1 + p1.w*dk1); + + const Vec3fa r1 = p1 + p1.w*k1; + const Vec3fa dr1 = dp1 + (dp1.w*k1 + p1.w*dk1); + + const float scale = 1.0f/3.0f; + CubicBezierCurve L(l0,l0+scale*dl0,l1-scale*dl1,l1); + CubicBezierCurve R(r0,r0+scale*dr0,r1-scale*dr1,r1); + return TensorLinearCubicBezierSurface(L,R); + } + + __forceinline BBox bounds() const { + return merge(L.bounds(),R.bounds()); + } + + __forceinline BBox3fa accurateBounds() const { + return merge(L.accurateBounds(),R.accurateBounds()); + } + + __forceinline CubicBezierCurve reduce_v() const { + return merge(CubicBezierCurve>(L),CubicBezierCurve>(R)); + } + + __forceinline LinearBezierCurve reduce_u() const { + return LinearBezierCurve(L.bounds(),R.bounds()); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const V& dx) const { + return TensorLinearCubicBezierSurface(L.xfm(dx),R.xfm(dx)); + } + + __forceinline TensorLinearCubicBezierSurface vxfm(const V& dx) const { + return TensorLinearCubicBezierSurface(L.vxfm(dx),R.vxfm(dx)); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const V& dx, const V& p) const { + return TensorLinearCubicBezierSurface(L.xfm(dx,p),R.xfm(dx,p)); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const LinearSpace3fa& space) const { + return TensorLinearCubicBezierSurface(L.xfm(space),R.xfm(space)); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const LinearSpace3fa& space, const Vec3fa& p) const { + return TensorLinearCubicBezierSurface(L.xfm(space,p),R.xfm(space,p)); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const LinearSpace3fa& space, const Vec3fa& p, const float s) const { + return TensorLinearCubicBezierSurface(L.xfm(space,p,s),R.xfm(space,p,s)); + } + + __forceinline TensorLinearCubicBezierSurface clip_u(const Interval1f& u) const { + return TensorLinearCubicBezierSurface(L.clip(u),R.clip(u)); + } + + __forceinline TensorLinearCubicBezierSurface clip_v(const Interval1f& v) const { + return TensorLinearCubicBezierSurface(clerp(L,R,V(v.lower)),clerp(L,R,V(v.upper))); + } + + __forceinline TensorLinearCubicBezierSurface clip(const Interval1f& u, const Interval1f& v) const { + return clip_v(v).clip_u(u); + } + + __forceinline void split_u(TensorLinearCubicBezierSurface& left, TensorLinearCubicBezierSurface& right, const float u = 0.5f) const + { + CubicBezierCurve L0,L1; L.split(L0,L1,u); + CubicBezierCurve R0,R1; R.split(R0,R1,u); + new (&left ) TensorLinearCubicBezierSurface(L0,R0); + new (&right) TensorLinearCubicBezierSurface(L1,R1); + } + + __forceinline TensorLinearCubicBezierSurface vsplit_u(vboolx& valid, const BBox1f& u) const { + valid = true; clear(valid,VSIZEX-1); + return TensorLinearCubicBezierSurface(L.split(u),R.split(u)); + } + + __forceinline V eval(const float u, const float v) const { + return clerp(L,R,V(v)).eval(u); + } + + __forceinline V eval_du(const float u, const float v) const { + return clerp(L,R,V(v)).eval_dt(u); + } + + 
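+      /* note: the surface is linear in v, so the v-derivative is simply the difference of the two boundary curves evaluated at u */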
__forceinline V eval_dv(const float u, const float v) const { + return (R-L).eval(u); + } + + __forceinline void eval(const float u, const float v, V& p, V& dpdu, V& dpdv) const + { + V p0, dp0du; L.eval(u,p0,dp0du); + V p1, dp1du; R.eval(u,p1,dp1du); + p = lerp(p0,p1,v); + dpdu = lerp(dp0du,dp1du,v); + dpdv = p1-p0; + } + + __forceinline TensorLinearQuadraticBezierSurface derivative_u() const { + return TensorLinearQuadraticBezierSurface(L.derivative(),R.derivative()); + } + + __forceinline CubicBezierCurve derivative_v() const { + return R-L; + } + + __forceinline V axis_u() const { + return (L.end()-L.begin())+(R.end()-R.begin()); + } + + __forceinline V axis_v() const { + return (R.begin()-L.begin())+(R.end()-L.end()); + } + + friend embree_ostream operator<<(embree_ostream cout, const TensorLinearCubicBezierSurface& a) + { + return cout << "TensorLinearCubicBezierSurface" << embree_endl + << "{" << embree_endl + << " L = " << a.L << ", " << embree_endl + << " R = " << a.R << embree_endl + << "}"; + } + + friend __forceinline TensorLinearCubicBezierSurface clerp(const TensorLinearCubicBezierSurface& a, const TensorLinearCubicBezierSurface& b, const float t) { + return TensorLinearCubicBezierSurface(clerp(a.L,b.L,V(t)), clerp(a.R,b.R,V(t))); + } + }; + + template<> + struct TensorLinearCubicBezierSurface + { + CubicBezierCurve LR; + + __forceinline TensorLinearCubicBezierSurface() {} + + __forceinline TensorLinearCubicBezierSurface(const TensorLinearCubicBezierSurface& curve) + : LR(curve.LR) {} + + __forceinline TensorLinearCubicBezierSurface& operator= (const TensorLinearCubicBezierSurface& other) { + LR = other.LR; return *this; + } + + __forceinline TensorLinearCubicBezierSurface(const CubicBezierCurve& LR) + : LR(LR) {} + + __forceinline TensorLinearCubicBezierSurface(const CubicBezierCurve& L, const CubicBezierCurve& R) + : LR(shuffle<0,1,0,1>(vfloat4(L.v0),vfloat4(R.v0)),shuffle<0,1,0,1>(vfloat4(L.v1),vfloat4(R.v1)),shuffle<0,1,0,1>(vfloat4(L.v2),vfloat4(R.v2)),shuffle<0,1,0,1>(vfloat4(L.v3),vfloat4(R.v3))) {} + + __forceinline CubicBezierCurve getL() const { + return CubicBezierCurve(Vec2fa(LR.v0),Vec2fa(LR.v1),Vec2fa(LR.v2),Vec2fa(LR.v3)); + } + + __forceinline CubicBezierCurve getR() const { + return CubicBezierCurve(Vec2fa(shuffle<2,3,2,3>(LR.v0)),Vec2fa(shuffle<2,3,2,3>(LR.v1)),Vec2fa(shuffle<2,3,2,3>(LR.v2)),Vec2fa(shuffle<2,3,2,3>(LR.v3))); + } + + __forceinline BBox bounds() const + { + const BBox b = LR.bounds(); + const BBox bl(Vec2fa(b.lower),Vec2fa(b.upper)); + const BBox br(Vec2fa(shuffle<2,3,2,3>(b.lower)),Vec2fa(shuffle<2,3,2,3>(b.upper))); + return merge(bl,br); + } + + __forceinline BBox1f bounds(const Vec2fa& axis) const + { + const CubicBezierCurve LRx = LR; + const CubicBezierCurve LRy(shuffle<1,0,3,2>(LR.v0),shuffle<1,0,3,2>(LR.v1),shuffle<1,0,3,2>(LR.v2),shuffle<1,0,3,2>(LR.v3)); + const CubicBezierCurve LRa = cmadd(shuffle<0>(vfloat4(axis)),LRx,shuffle<1>(vfloat4(axis))*LRy); + const BBox Lb = LRa.bounds(); + const BBox Rb(shuffle<3>(Lb.lower),shuffle<3>(Lb.upper)); + const BBox b = merge(Lb,Rb); + return BBox1f(b.lower[0],b.upper[0]); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const Vec2fa& dx) const + { + const CubicBezierCurve LRx = LR; + const CubicBezierCurve LRy(shuffle<1,0,3,2>(LR.v0),shuffle<1,0,3,2>(LR.v1),shuffle<1,0,3,2>(LR.v2),shuffle<1,0,3,2>(LR.v3)); + const CubicBezierCurve LRa = cmadd(shuffle<0>(vfloat4(dx)),LRx,shuffle<1>(vfloat4(dx))*LRy); + return 
TensorLinearCubicBezierSurface(CubicBezierCurve(LRa.v0[0],LRa.v1[0],LRa.v2[0],LRa.v3[0]), + CubicBezierCurve(LRa.v0[2],LRa.v1[2],LRa.v2[2],LRa.v3[2])); + } + + __forceinline TensorLinearCubicBezierSurface xfm(const Vec2fa& dx, const Vec2fa& p) const + { + const vfloat4 pxyxy = shuffle<0,1,0,1>(vfloat4(p)); + const CubicBezierCurve LRx = LR-pxyxy; + const CubicBezierCurve LRy(shuffle<1,0,3,2>(LR.v0),shuffle<1,0,3,2>(LR.v1),shuffle<1,0,3,2>(LR.v2),shuffle<1,0,3,2>(LR.v3)); + const CubicBezierCurve LRa = cmadd(shuffle<0>(vfloat4(dx)),LRx,shuffle<1>(vfloat4(dx))*LRy); + return TensorLinearCubicBezierSurface(CubicBezierCurve(LRa.v0[0],LRa.v1[0],LRa.v2[0],LRa.v3[0]), + CubicBezierCurve(LRa.v0[2],LRa.v1[2],LRa.v2[2],LRa.v3[2])); + } + + __forceinline TensorLinearCubicBezierSurface clip_u(const Interval1f& u) const { + return TensorLinearCubicBezierSurface(LR.clip(u)); + } + + __forceinline TensorLinearCubicBezierSurface clip_v(const Interval1f& v) const + { + const CubicBezierCurve LL(shuffle<0,1,0,1>(LR.v0),shuffle<0,1,0,1>(LR.v1),shuffle<0,1,0,1>(LR.v2),shuffle<0,1,0,1>(LR.v3)); + const CubicBezierCurve RR(shuffle<2,3,2,3>(LR.v0),shuffle<2,3,2,3>(LR.v1),shuffle<2,3,2,3>(LR.v2),shuffle<2,3,2,3>(LR.v3)); + return TensorLinearCubicBezierSurface(clerp(LL,RR,vfloat4(v.lower,v.lower,v.upper,v.upper))); + } + + __forceinline TensorLinearCubicBezierSurface clip(const Interval1f& u, const Interval1f& v) const { + return clip_v(v).clip_u(u); + } + + __forceinline void split_u(TensorLinearCubicBezierSurface& left, TensorLinearCubicBezierSurface& right, const float u = 0.5f) const + { + CubicBezierCurve LR0,LR1; LR.split(LR0,LR1,u); + new (&left ) TensorLinearCubicBezierSurface(LR0); + new (&right) TensorLinearCubicBezierSurface(LR1); + } + + __forceinline TensorLinearCubicBezierSurface vsplit_u(vboolx& valid, const BBox1f& u) const { + valid = true; clear(valid,VSIZEX-1); + return TensorLinearCubicBezierSurface(getL().split(u),getR().split(u)); + } + + __forceinline Vec2fa eval(const float u, const float v) const + { + const vfloat4 p = LR.eval(u); + return Vec2fa(lerp(shuffle<0,1,0,1>(p),shuffle<2,3,2,3>(p),v)); + } + + __forceinline Vec2fa eval_du(const float u, const float v) const + { + const vfloat4 dpdu = LR.eval_dt(u); + return Vec2fa(lerp(shuffle<0,1,0,1>(dpdu),shuffle<2,3,2,3>(dpdu),v)); + } + + __forceinline Vec2fa eval_dv(const float u, const float v) const + { + const vfloat4 p = LR.eval(u); + return Vec2fa(shuffle<2,3,2,3>(p)-shuffle<0,1,0,1>(p)); + } + + __forceinline void eval(const float u, const float v, Vec2fa& p, Vec2fa& dpdu, Vec2fa& dpdv) const + { + vfloat4 p0, dp0du; LR.eval(u,p0,dp0du); + p = Vec2fa(lerp(shuffle<0,1,0,1>(p0),shuffle<2,3,2,3>(p0),v)); + dpdu = Vec2fa(lerp(shuffle<0,1,0,1>(dp0du),shuffle<2,3,2,3>(dp0du),v)); + dpdv = Vec2fa(shuffle<2,3,2,3>(p0)-shuffle<0,1,0,1>(p0)); + } + + __forceinline TensorLinearQuadraticBezierSurface derivative_u() const { + return TensorLinearQuadraticBezierSurface(LR.derivative()); + } + + __forceinline CubicBezierCurve derivative_v() const { + return getR()-getL(); + } + + __forceinline Vec2fa axis_u() const + { + const CubicBezierCurve L = getL(); + const CubicBezierCurve R = getR(); + return (L.end()-L.begin())+(R.end()-R.begin()); + } + + __forceinline Vec2fa axis_v() const + { + const CubicBezierCurve L = getL(); + const CubicBezierCurve R = getR(); + return (R.begin()-L.begin())+(R.end()-L.end()); + } + + friend embree_ostream operator<<(embree_ostream cout, const TensorLinearCubicBezierSurface& a) + { + return cout << 
"TensorLinearCubicBezierSurface" << embree_endl + << "{" << embree_endl + << " L = " << a.getL() << ", " << embree_endl + << " R = " << a.getR() << embree_endl + << "}"; + } + }; + + typedef TensorLinearCubicBezierSurface TensorLinearCubicBezierSurface1f; + typedef TensorLinearCubicBezierSurface TensorLinearCubicBezierSurface2fa; + typedef TensorLinearCubicBezierSurface TensorLinearCubicBezierSurface3fa; + } +} diff --git a/thirdparty/embree/kernels/subdiv/patch.h b/thirdparty/embree/kernels/subdiv/patch.h new file mode 100644 index 000000000000..d58241b96d6e --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/patch.h @@ -0,0 +1,371 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "catmullclark_patch.h" +#include "bilinear_patch.h" +#include "bspline_patch.h" +#include "bezier_patch.h" +#include "gregory_patch.h" +#include "tessellation_cache.h" + +#if 1 +#define PATCH_DEBUG_SUBDIVISION(ptr,x,y,z) +#else +#define PATCH_DEBUG_SUBDIVISION(ptr,x,y,z) \ + { \ + size_t hex = (size_t)ptr; \ + for (size_t i=0; i<4; i++) hex = hex ^ (hex >> 8); \ + const float c = (float)(((hex >> 0) ^ (hex >> 4) ^ (hex >> 8) ^ (hex >> 12) ^ (hex >> 16))&0xf)/15.0f; \ + if (P) *P = Vertex(0.5f+0.5f*x,0.5f+0.5f*y,0.5f+0.5f*z,0.0f); \ + } +#endif + +#define PATCH_MAX_CACHE_DEPTH 2 +//#define PATCH_MIN_RESOLUTION 1 // FIXME: not yet completely implemented +#define PATCH_MAX_EVAL_DEPTH_IRREGULAR 10 // maximum evaluation depth at irregular vertices (has to be larger or equal than PATCH_MAX_CACHE_DEPTH) +#define PATCH_MAX_EVAL_DEPTH_CREASE 10 // maximum evaluation depth at crease features (has to be larger or equal than PATCH_MAX_CACHE_DEPTH) +#define PATCH_USE_GREGORY 1 // 0 = no gregory, 1 = fill, 2 = as early as possible + +#if PATCH_USE_GREGORY==2 +#define PATCH_USE_BEZIER_PATCH 1 // enable use of bezier instead of b-spline patches +#else +#define PATCH_USE_BEZIER_PATCH 0 // enable use of bezier instead of b-spline patches +#endif + +#if PATCH_USE_BEZIER_PATCH +# define RegularPatch BezierPatch +# define RegularPatchT BezierPatchT +#else +# define RegularPatch BSplinePatch +# define RegularPatchT BSplinePatchT +#endif + +#if PATCH_USE_GREGORY +#define IrregularFillPatch GregoryPatch +#define IrregularFillPatchT GregoryPatchT +#else +#define IrregularFillPatch BilinearPatch +#define IrregularFillPatchT BilinearPatchT +#endif + +namespace embree +{ + template + struct __aligned(64) PatchT + { + public: + + typedef GeneralCatmullClarkPatchT GeneralCatmullClarkPatch; + typedef CatmullClarkPatchT CatmullClarkPatch; + typedef CatmullClark1RingT CatmullClarkRing; + typedef BezierCurveT BezierCurve; + + enum Type { + INVALID_PATCH = 0, + BILINEAR_PATCH = 1, + BSPLINE_PATCH = 2, + BEZIER_PATCH = 3, + GREGORY_PATCH = 4, + SUBDIVIDED_GENERAL_PATCH = 7, + SUBDIVIDED_QUAD_PATCH = 8, + EVAL_PATCH = 9, + }; + + struct Ref + { + __forceinline Ref(void* p = nullptr) + : ptr((size_t)p) {} + + __forceinline operator bool() const { return ptr != 0; } + __forceinline operator size_t() const { return ptr; } + + __forceinline Ref (Type ty, void* in) + : ptr(((size_t)in)+ty) { assert((((size_t)in) & 0xF) == 0); } + + __forceinline Type type () const { return (Type)(ptr & 0xF); } + __forceinline void* object() const { return (void*) (ptr & ~0xF); } + + size_t ptr; + }; + + struct EvalPatch + { + /* creates EvalPatch from a CatmullClarkPatch */ + template + __noinline static Ref create(const Allocator& alloc, const CatmullClarkPatch& patch) + { + size_t ofs = 0, bytes = 
patch.bytes(); + void* ptr = alloc(bytes); + patch.serialize(ptr,ofs); + assert(ofs == bytes); + return Ref(EVAL_PATCH, ptr); + } + }; + + struct BilinearPatch + { + /* creates BilinearPatch from a CatmullClarkPatch */ + template + __noinline static Ref create(const Allocator& alloc, const CatmullClarkPatch& patch, + const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) { + return Ref(BILINEAR_PATCH, new (alloc(sizeof(BilinearPatch))) BilinearPatch(patch)); + } + + __forceinline BilinearPatch (const CatmullClarkPatch& patch) + : patch(patch) {} + + /* creates BilinearPatch from 4 vertices */ + template + __noinline static Ref create(const Allocator& alloc, const HalfEdge* edge, const char* vertices, size_t stride) { + return Ref(BILINEAR_PATCH, new (alloc(sizeof(BilinearPatch))) BilinearPatch(edge,vertices,stride)); + } + + __forceinline BilinearPatch (const HalfEdge* edge, const char* vertices, size_t stride) + : patch(edge,vertices,stride) {} + + public: + BilinearPatchT patch; + }; + + struct BSplinePatch + { + /* creates BSplinePatch from a half edge */ + template + __noinline static Ref create(const Allocator& alloc, const HalfEdge* edge, const char* vertices, size_t stride) { + return Ref(BSPLINE_PATCH, new (alloc(sizeof(BSplinePatch))) BSplinePatch(edge,vertices,stride)); + } + + __forceinline BSplinePatch (const HalfEdge* edge, const char* vertices, size_t stride) + : patch(edge,vertices,stride) {} + + /* creates BSplinePatch from a CatmullClarkPatch */ + template + __noinline static Ref create(const Allocator& alloc, const CatmullClarkPatch& patch, + const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) { + return Ref(BSPLINE_PATCH, new (alloc(sizeof(BSplinePatch))) BSplinePatch(patch,border0,border1,border2,border3)); + } + + __forceinline BSplinePatch (const CatmullClarkPatch& patch, const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) + : patch(patch,border0,border1,border2,border3) {} + + public: + BSplinePatchT patch; + }; + + struct BezierPatch + { + /* creates BezierPatch from a half edge */ + template + __noinline static Ref create(const Allocator& alloc, const HalfEdge* edge, const char* vertices, size_t stride) { + return Ref(BEZIER_PATCH, new (alloc(sizeof(BezierPatch))) BezierPatch(edge,vertices,stride)); + } + + __forceinline BezierPatch (const HalfEdge* edge, const char* vertices, size_t stride) + : patch(edge,vertices,stride) {} + + /* creates Bezier from a CatmullClarkPatch */ + template + __noinline static Ref create(const Allocator& alloc, const CatmullClarkPatch& patch, + const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) { + return Ref(BEZIER_PATCH, new (alloc(sizeof(BezierPatch))) BezierPatch(patch,border0,border1,border2,border3)); + } + + __forceinline BezierPatch (const CatmullClarkPatch& patch, const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) + : patch(patch,border0,border1,border2,border3) {} + + public: + BezierPatchT patch; + }; + + struct GregoryPatch + { + /* creates GregoryPatch from half edge */ + template + __noinline static Ref create(const Allocator& alloc, const HalfEdge* edge, const char* vertices, size_t stride) { + return Ref(GREGORY_PATCH, new (alloc(sizeof(GregoryPatch))) GregoryPatch(edge,vertices,stride)); + } + + __forceinline GregoryPatch 
(const HalfEdge* edge, const char* vertices, size_t stride) + : patch(CatmullClarkPatch(edge,vertices,stride)) {} + + /* creates GregoryPatch from CatmullClarkPatch */ + template + __noinline static Ref create(const Allocator& alloc, const CatmullClarkPatch& patch, + const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) { + return Ref(GREGORY_PATCH, new (alloc(sizeof(GregoryPatch))) GregoryPatch(patch,border0,border1,border2,border3)); + } + + __forceinline GregoryPatch (const CatmullClarkPatch& patch, const BezierCurve* border0, const BezierCurve* border1, const BezierCurve* border2, const BezierCurve* border3) + : patch(patch,border0,border1,border2,border3) {} + + public: + GregoryPatchT patch; + }; + + struct SubdividedQuadPatch + { + template + __noinline static Ref create(const Allocator& alloc, Ref children[4]) { + return Ref(SUBDIVIDED_QUAD_PATCH, new (alloc(sizeof(SubdividedQuadPatch))) SubdividedQuadPatch(children)); + } + + __forceinline SubdividedQuadPatch(Ref children[4]) { + for (size_t i=0; i<4; i++) child[i] = children[i]; + } + + public: + Ref child[4]; + }; + + struct SubdividedGeneralPatch + { + template + __noinline static Ref create(const Allocator& alloc, Ref* children, const unsigned N) { + return Ref(SUBDIVIDED_GENERAL_PATCH, new (alloc(sizeof(SubdividedGeneralPatch))) SubdividedGeneralPatch(children,N)); + } + + __forceinline SubdividedGeneralPatch(Ref* children, const unsigned N) : N(N) { + for (unsigned i=0; i + __noinline static Ref create(const Allocator& alloc, const HalfEdge* edge, const char* vertices, size_t stride) + { + if (PATCH_MAX_CACHE_DEPTH == 0) + return nullptr; + + Ref child(0); + switch (edge->patch_type) { + case HalfEdge::BILINEAR_PATCH: child = BilinearPatch::create(alloc,edge,vertices,stride); break; + case HalfEdge::REGULAR_QUAD_PATCH: child = RegularPatch::create(alloc,edge,vertices,stride); break; +#if PATCH_USE_GREGORY == 2 + case HalfEdge::IRREGULAR_QUAD_PATCH: child = GregoryPatch::create(alloc,edge,vertices,stride); break; +#endif + default: { + GeneralCatmullClarkPatch patch(edge,vertices,stride); + child = PatchT::create(alloc,patch,edge,vertices,stride,0); + } + } + return child; + } + + template + __noinline static Ref create(const Allocator& alloc, GeneralCatmullClarkPatch& patch, const HalfEdge* edge, const char* vertices, size_t stride, size_t depth) + { + /* convert into standard quad patch if possible */ + if (likely(patch.isQuadPatch())) + { + CatmullClarkPatch qpatch; patch.init(qpatch); + return PatchT::create(alloc,qpatch,edge,vertices,stride,depth); + } + + /* do only cache up to some depth */ + if (depth >= PATCH_MAX_CACHE_DEPTH) + return nullptr; + + /* subdivide patch */ + unsigned N; + array_t patches; + patch.subdivide(patches,N); + + if (N == 4) + { + Ref child[4]; +#if PATCH_USE_GREGORY == 2 + BezierCurve borders[GeneralCatmullClarkPatch::SIZE]; patch.getLimitBorder(borders); + BezierCurve border0l,border0r; borders[0].subdivide(border0l,border0r); + BezierCurve border1l,border1r; borders[1].subdivide(border1l,border1r); + BezierCurve border2l,border2r; borders[2].subdivide(border2l,border2r); + BezierCurve border3l,border3r; borders[3].subdivide(border3l,border3r); + GeneralCatmullClarkPatch::fix_quad_ring_order(patches); + child[0] = PatchT::create(alloc,patches[0],edge,vertices,stride,depth+1,&border0l,nullptr,nullptr,&border3r); + child[1] = PatchT::create(alloc,patches[1],edge,vertices,stride,depth+1,&border0r,&border1l,nullptr,nullptr); + child[2] = 
PatchT::create(alloc,patches[2],edge,vertices,stride,depth+1,nullptr,&border1r,&border2l,nullptr); + child[3] = PatchT::create(alloc,patches[3],edge,vertices,stride,depth+1,nullptr,nullptr,&border2r,&border3l); +#else + GeneralCatmullClarkPatch::fix_quad_ring_order(patches); + for (size_t i=0; i<4; i++) + child[i] = PatchT::create(alloc,patches[i],edge,vertices,stride,depth+1); +#endif + return SubdividedQuadPatch::create(alloc,child); + } + else + { + assert(N=max_eval_depth; +//#else + return depth>=max_eval_depth; +//#endif + } + + template + __noinline static Ref create(const Allocator& alloc, CatmullClarkPatch& patch, const HalfEdge* edge, const char* vertices, size_t stride, size_t depth, + const BezierCurve* border0 = nullptr, const BezierCurve* border1 = nullptr, const BezierCurve* border2 = nullptr, const BezierCurve* border3 = nullptr) + { + const typename CatmullClarkPatch::Type ty = patch.type(); + if (unlikely(final(patch,ty,depth))) { + if (ty & CatmullClarkRing::TYPE_REGULAR) return RegularPatch::create(alloc,patch,border0,border1,border2,border3); + else return IrregularFillPatch::create(alloc,patch,border0,border1,border2,border3); + } + else if (ty & CatmullClarkRing::TYPE_REGULAR_CREASES) { + assert(depth > 0); return RegularPatch::create(alloc,patch,border0,border1,border2,border3); + } +#if PATCH_USE_GREGORY == 2 + else if (ty & CatmullClarkRing::TYPE_GREGORY_CREASES) { + assert(depth > 0); return GregoryPatch::create(alloc,patch,border0,border1,border2,border3); + } +#endif + else if (depth >= PATCH_MAX_CACHE_DEPTH) { + return EvalPatch::create(alloc,patch); + } + + else + { + Ref child[4]; + array_t patches; + patch.subdivide(patches); + + for (size_t i=0; i<4; i++) + child[i] = PatchT::create(alloc,patches[i],edge,vertices,stride,depth+1); + return SubdividedQuadPatch::create(alloc,child); + } + } + }; + + typedef PatchT Patch3fa; +} diff --git a/thirdparty/embree/kernels/subdiv/patch_eval.h b/thirdparty/embree/kernels/subdiv/patch_eval.h new file mode 100644 index 000000000000..482d015fa312 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/patch_eval.h @@ -0,0 +1,129 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "patch.h" +#include "feature_adaptive_eval.h" + +namespace embree +{ + namespace isa + { + template + struct PatchEval + { + public: + + typedef PatchT Patch; + typedef typename Patch::Ref Ref; + typedef CatmullClarkPatchT CatmullClarkPatch; + + PatchEval (SharedLazyTessellationCache::CacheEntry& entry, size_t commitCounter, + const HalfEdge* edge, const char* vertices, size_t stride, const float u, const float v, + Vertex* P, Vertex* dPdu, Vertex* dPdv, Vertex* ddPdudu, Vertex* ddPdvdv, Vertex* ddPdudv) + : P(P), dPdu(dPdu), dPdv(dPdv), ddPdudu(ddPdudu), ddPdvdv(ddPdvdv), ddPdudv(ddPdudv) + { + /* conservative time for the very first allocation */ + auto time = SharedLazyTessellationCache::sharedLazyTessellationCache.getTime(commitCounter); + + Ref patch = SharedLazyTessellationCache::lookup(entry,commitCounter,[&] () { + auto alloc = [&](size_t bytes) { return SharedLazyTessellationCache::malloc(bytes); }; + return Patch::create(alloc,edge,vertices,stride); + },true); + + auto curTime = SharedLazyTessellationCache::sharedLazyTessellationCache.getTime(commitCounter); + const bool allAllocationsValid = SharedLazyTessellationCache::validTime(time,curTime); + + if (patch && allAllocationsValid && eval(patch,u,v,1.0f,0)) { + SharedLazyTessellationCache::unlock(); + return; + } + 
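+        /* cache lookup failed, an allocation was invalidated, or the cached patch could not be evaluated: unlock the cache and evaluate the patch directly via feature-adaptive subdivision */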
SharedLazyTessellationCache::unlock(); + FeatureAdaptiveEval(edge,vertices,stride,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv); + PATCH_DEBUG_SUBDIVISION(edge,c,-1,-1); + } + + __forceinline bool eval_quad(const typename Patch::SubdividedQuadPatch* This, const float u, const float v, const float dscale, const size_t depth) + { + if (v < 0.5f) { + if (u < 0.5f) return eval(This->child[0],2.0f*u,2.0f*v,2.0f*dscale,depth+1); + else return eval(This->child[1],2.0f*u-1.0f,2.0f*v,2.0f*dscale,depth+1); + } else { + if (u > 0.5f) return eval(This->child[2],2.0f*u-1.0f,2.0f*v-1.0f,2.0f*dscale,depth+1); + else return eval(This->child[3],2.0f*u,2.0f*v-1.0f,2.0f*dscale,depth+1); + } + } + + bool eval_general(const typename Patch::SubdividedGeneralPatch* This, const float U, const float V, const size_t depth) + { + const unsigned l = (unsigned) floor(0.5f*U); const float u = 2.0f*frac(0.5f*U)-0.5f; + const unsigned h = (unsigned) floor(0.5f*V); const float v = 2.0f*frac(0.5f*V)-0.5f; + const unsigned i = 4*h+l; assert(iN); + return eval(This->child[i],u,v,1.0f,depth+1); + } + + bool eval(Ref This, const float& u, const float& v, const float dscale, const size_t depth) + { + if (!This) return false; + //PRINT(depth); + //PRINT2(u,v); + + switch (This.type()) + { + case Patch::BILINEAR_PATCH: { + //PRINT("bilinear"); + ((typename Patch::BilinearPatch*)This.object())->patch.eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(This,-1,c,c); + return true; + } + case Patch::BSPLINE_PATCH: { + //PRINT("bspline"); + ((typename Patch::BSplinePatch*)This.object())->patch.eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(This,-1,c,-1); + return true; + } + case Patch::BEZIER_PATCH: { + //PRINT("bezier"); + ((typename Patch::BezierPatch*)This.object())->patch.eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(This,-1,c,-1); + return true; + } + case Patch::GREGORY_PATCH: { + //PRINT("gregory"); + ((typename Patch::GregoryPatch*)This.object())->patch.eval(u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale); + PATCH_DEBUG_SUBDIVISION(This,-1,-1,c); + return true; + } + case Patch::SUBDIVIDED_QUAD_PATCH: { + //PRINT("subdivided quad"); + return eval_quad(((typename Patch::SubdividedQuadPatch*)This.object()),u,v,dscale,depth); + } + case Patch::SUBDIVIDED_GENERAL_PATCH: { + //PRINT("general_patch"); + assert(dscale == 1.0f); + return eval_general(((typename Patch::SubdividedGeneralPatch*)This.object()),u,v,depth); + } + case Patch::EVAL_PATCH: { + //PRINT("eval_patch"); + CatmullClarkPatch patch; patch.deserialize(This.object()); + FeatureAdaptiveEval(patch,u,v,dscale,depth,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv); + return true; + } + default: + assert(false); + return false; + } + } + + private: + Vertex* const P; + Vertex* const dPdu; + Vertex* const dPdv; + Vertex* const ddPdudu; + Vertex* const ddPdvdv; + Vertex* const ddPdudv; + }; + } +} + diff --git a/thirdparty/embree/kernels/subdiv/patch_eval_grid.h b/thirdparty/embree/kernels/subdiv/patch_eval_grid.h new file mode 100644 index 000000000000..c05db55f4cde --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/patch_eval_grid.h @@ -0,0 +1,245 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "patch.h" +#include "feature_adaptive_eval_grid.h" + +namespace embree +{ + namespace isa + { + struct PatchEvalGrid + { + typedef Patch3fa Patch; + typedef Patch::Ref Ref; + typedef GeneralCatmullClarkPatch3fa GeneralCatmullClarkPatch; + 
typedef CatmullClarkPatch3fa CatmullClarkPatch; + typedef BSplinePatch3fa BSplinePatch; + typedef BezierPatch3fa BezierPatch; + typedef GregoryPatch3fa GregoryPatch; + typedef BilinearPatch3fa BilinearPatch; + + private: + const unsigned x0,x1; + const unsigned y0,y1; + const unsigned swidth,sheight; + const float rcp_swidth, rcp_sheight; + float* const Px; + float* const Py; + float* const Pz; + float* const U; + float* const V; + float* const Nx; + float* const Ny; + float* const Nz; + const unsigned dwidth,dheight; + unsigned count; + + public: + + PatchEvalGrid (Ref patch, unsigned subPatch, + const unsigned x0, const unsigned x1, const unsigned y0, const unsigned y1, const unsigned swidth, const unsigned sheight, + float* Px, float* Py, float* Pz, float* U, float* V, + float* Nx, float* Ny, float* Nz, + const unsigned dwidth, const unsigned dheight) + : x0(x0), x1(x1), y0(y0), y1(y1), swidth(swidth), sheight(sheight), rcp_swidth(1.0f/(swidth-1.0f)), rcp_sheight(1.0f/(sheight-1.0f)), + Px(Px), Py(Py), Pz(Pz), U(U), V(V), Nx(Nx), Ny(Ny), Nz(Nz), dwidth(dwidth), dheight(dheight), count(0) + { + assert(swidth < (2<<20) && sheight < (2<<20)); + const BBox2f srange(Vec2f(0.0f,0.0f),Vec2f(float(swidth-1),float(sheight-1))); + const BBox2f erange(Vec2f(float(x0),float(y0)),Vec2f((float)x1,(float)y1)); + bool done MAYBE_UNUSED = eval(patch,subPatch,srange,erange); + assert(done); + assert(count == (x1-x0+1)*(y1-y0+1)); + } + + template + __forceinline void evalLocalGrid(const Patch* patch, const BBox2f& srange, const int lx0, const int lx1, const int ly0, const int ly1) + { + const float scale_x = rcp(srange.upper.x-srange.lower.x); + const float scale_y = rcp(srange.upper.y-srange.lower.y); + count += (lx1-lx0)*(ly1-ly0); + +#if 0 + for (unsigned iy=ly0; iypatch.eval(lu,lv); + const float u = float(ix)*rcp_swidth; + const float v = float(iy)*rcp_sheight; + const int ofs = (iy-y0)*dwidth+(ix-x0); + Px[ofs] = p.x; + Py[ofs] = p.y; + Pz[ofs] = p.z; + U[ofs] = u; + V[ofs] = v; + } + } +#else + foreach2(lx0,lx1,ly0,ly1,[&](const vboolx& valid, const vintx& ix, const vintx& iy) { + const vfloatx lu = select(ix == swidth -1, vfloatx(1.0f), (vfloatx(ix)-srange.lower.x)*scale_x); + const vfloatx lv = select(iy == sheight-1, vfloatx(1.0f), (vfloatx(iy)-srange.lower.y)*scale_y); + const Vec3vfx p = patch->patch.eval(lu,lv); + Vec3vfx n = zero; + if (unlikely(Nx != nullptr)) n = normalize_safe(patch->patch.normal(lu,lv)); + const vfloatx u = vfloatx(ix)*rcp_swidth; + const vfloatx v = vfloatx(iy)*rcp_sheight; + const vintx ofs = (iy-y0)*dwidth+(ix-x0); + if (likely(all(valid)) && all(iy==iy[0])) { + const unsigned ofs2 = ofs[0]; + vfloatx::storeu(Px+ofs2,p.x); + vfloatx::storeu(Py+ofs2,p.y); + vfloatx::storeu(Pz+ofs2,p.z); + vfloatx::storeu(U+ofs2,u); + vfloatx::storeu(V+ofs2,v); + if (unlikely(Nx != nullptr)) { + vfloatx::storeu(Nx+ofs2,n.x); + vfloatx::storeu(Ny+ofs2,n.y); + vfloatx::storeu(Nz+ofs2,n.z); + } + } else { + foreach_unique_index(valid,iy,[&](const vboolx& valid, const int iy0, const int j) { + const unsigned ofs2 = ofs[j]-j; + vfloatx::storeu(valid,Px+ofs2,p.x); + vfloatx::storeu(valid,Py+ofs2,p.y); + vfloatx::storeu(valid,Pz+ofs2,p.z); + vfloatx::storeu(valid,U+ofs2,u); + vfloatx::storeu(valid,V+ofs2,v); + if (unlikely(Nx != nullptr)) { + vfloatx::storeu(valid,Nx+ofs2,n.x); + vfloatx::storeu(valid,Ny+ofs2,n.y); + vfloatx::storeu(valid,Nz+ofs2,n.z); + } + }); + } + }); +#endif + } + + bool eval(Ref This, const BBox2f& srange, const BBox2f& erange, const unsigned depth) + { + if 
(erange.empty()) + return true; + + const int lx0 = (int) ceilf(erange.lower.x); + const int lx1 = (int) ceilf(erange.upper.x) + (erange.upper.x == x1 && (srange.lower.x < erange.upper.x || erange.upper.x == 0)); + const int ly0 = (int) ceilf(erange.lower.y); + const int ly1 = (int) ceilf(erange.upper.y) + (erange.upper.y == y1 && (srange.lower.y < erange.upper.y || erange.upper.y == 0)); + if (lx0 >= lx1 || ly0 >= ly1) + return true; + + if (!This) + return false; + + switch (This.type()) + { + case Patch::BILINEAR_PATCH: { + evalLocalGrid((Patch::BilinearPatch*)This.object(),srange,lx0,lx1,ly0,ly1); + return true; + } + case Patch::BSPLINE_PATCH: { + evalLocalGrid((Patch::BSplinePatch*)This.object(),srange,lx0,lx1,ly0,ly1); + return true; + } + case Patch::BEZIER_PATCH: { + evalLocalGrid((Patch::BezierPatch*)This.object(),srange,lx0,lx1,ly0,ly1); + return true; + } + case Patch::GREGORY_PATCH: { + evalLocalGrid((Patch::GregoryPatch*)This.object(),srange,lx0,lx1,ly0,ly1); + return true; + } + case Patch::SUBDIVIDED_QUAD_PATCH: + { + const Vec2f c = srange.center(); + const BBox2f srange0(srange.lower,c); + const BBox2f srange1(Vec2f(c.x,srange.lower.y),Vec2f(srange.upper.x,c.y)); + const BBox2f srange2(c,srange.upper); + const BBox2f srange3(Vec2f(srange.lower.x,c.y),Vec2f(c.x,srange.upper.y)); + + Patch::SubdividedQuadPatch* patch = (Patch::SubdividedQuadPatch*)This.object(); + eval(patch->child[0],srange0,intersect(srange0,erange),depth+1); + eval(patch->child[1],srange1,intersect(srange1,erange),depth+1); + eval(patch->child[2],srange2,intersect(srange2,erange),depth+1); + eval(patch->child[3],srange3,intersect(srange3,erange),depth+1); + return true; + } + case Patch::EVAL_PATCH: { + CatmullClarkPatch patch; patch.deserialize(This.object()); + FeatureAdaptiveEvalGrid(patch,srange,erange,depth,x0,x1,y0,y1,swidth,sheight,Px,Py,Pz,U,V,Nx,Ny,Nz,dwidth,dheight); + count += (lx1-lx0)*(ly1-ly0); + return true; + } + default: + assert(false); + return false; + } + } + + bool eval(Ref This, unsigned subPatch, const BBox2f& srange, const BBox2f& erange) + { + if (!This) + return false; + + switch (This.type()) + { + case Patch::SUBDIVIDED_GENERAL_PATCH: { + Patch::SubdividedGeneralPatch* patch = (Patch::SubdividedGeneralPatch*)This.object(); + assert(subPatch < patch->N); + return eval(patch->child[subPatch],srange,erange,1); + } + default: + assert(subPatch == 0); + return eval(This,srange,erange,0); + } + } + }; + + __forceinline unsigned patch_eval_subdivision_count (const HalfEdge* h) + { + const unsigned N = h->numEdges(); + if (N == 4) return 1; + else return N; + } + + template + inline void patch_eval_subdivision (const HalfEdge* h, Tessellator tessellator) + { + const unsigned N = h->numEdges(); + int neighborSubdiv[GeneralCatmullClarkPatch3fa::SIZE]; // FIXME: use array_t + float levels[GeneralCatmullClarkPatch3fa::SIZE]; + for (unsigned i=0; ihasOpposite() ? 
h->opposite()->numEdges() != 4 : 0; + levels[i] = h->edge_level; + h = h->next(); + } + if (N == 4) + { + const Vec2f uv[4] = { Vec2f(0.0f,0.0f), Vec2f(1.0f,0.0f), Vec2f(1.0f,1.0f), Vec2f(0.0f,1.0f) }; + tessellator(uv,neighborSubdiv,levels,0); + } + else + { + for (unsigned i=0; i 16"); + const int h = (i >> 2) & 3, l = i & 3; + const Vec2f subPatchID((float)l,(float)h); + const Vec2f uv[4] = { 2.0f*subPatchID + (0.5f+Vec2f(0.0f,0.0f)), + 2.0f*subPatchID + (0.5f+Vec2f(1.0f,0.0f)), + 2.0f*subPatchID + (0.5f+Vec2f(1.0f,1.0f)), + 2.0f*subPatchID + (0.5f+Vec2f(0.0f,1.0f)) }; + const int neighborSubdiv1[4] = { 0,0,0,0 }; + const float levels1[4] = { 0.5f*levels[(i+0)%N], 0.5f*levels[(i+0)%N], 0.5f*levels[(i+N-1)%N], 0.5f*levels[(i+N-1)%N] }; + tessellator(uv,neighborSubdiv1,levels1,i); + } + } + } + } +} + diff --git a/thirdparty/embree/kernels/subdiv/patch_eval_simd.h b/thirdparty/embree/kernels/subdiv/patch_eval_simd.h new file mode 100644 index 000000000000..28016d9e20d9 --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/patch_eval_simd.h @@ -0,0 +1,127 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "patch.h" +#include "feature_adaptive_eval_simd.h" + +namespace embree +{ + namespace isa + { + template + struct PatchEvalSimd + { + public: + + typedef PatchT Patch; + typedef typename Patch::Ref Ref; + typedef CatmullClarkPatchT CatmullClarkPatch; + + PatchEvalSimd (SharedLazyTessellationCache::CacheEntry& entry, size_t commitCounter, + const HalfEdge* edge, const char* vertices, size_t stride, const vbool& valid0, const vfloat& u, const vfloat& v, + float* P, float* dPdu, float* dPdv, float* ddPdudu, float* ddPdvdv, float* ddPdudv, const size_t dstride, const size_t N) + : P(P), dPdu(dPdu), dPdv(dPdv), ddPdudu(ddPdudu), ddPdvdv(ddPdvdv), ddPdudv(ddPdudv), dstride(dstride), N(N) + { + /* conservative time for the very first allocation */ + auto time = SharedLazyTessellationCache::sharedLazyTessellationCache.getTime(commitCounter); + + Ref patch = SharedLazyTessellationCache::lookup(entry,commitCounter,[&] () { + auto alloc = [](size_t bytes) { return SharedLazyTessellationCache::malloc(bytes); }; + return Patch::create(alloc,edge,vertices,stride); + }, true); + + auto curTime = SharedLazyTessellationCache::sharedLazyTessellationCache.getTime(commitCounter); + const bool allAllocationsValid = SharedLazyTessellationCache::validTime(time,curTime); + + patch = allAllocationsValid ? patch : nullptr; + + /* use cached data structure for calculations */ + const vbool valid1 = patch ? 
eval(valid0,patch,u,v,1.0f,0) : vbool(false); + SharedLazyTessellationCache::unlock(); + const vbool valid2 = valid0 & !valid1; + if (any(valid2)) { + FeatureAdaptiveEvalSimd(edge,vertices,stride,valid2,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dstride,N); + } + } + + vbool eval_quad(const vbool& valid, const typename Patch::SubdividedQuadPatch* This, const vfloat& u, const vfloat& v, const float dscale, const size_t depth) + { + vbool ret = false; + const vbool u0_mask = u < 0.5f, u1_mask = u >= 0.5f; + const vbool v0_mask = v < 0.5f, v1_mask = v >= 0.5f; + const vbool u0v0_mask = valid & u0_mask & v0_mask; + const vbool u0v1_mask = valid & u0_mask & v1_mask; + const vbool u1v0_mask = valid & u1_mask & v0_mask; + const vbool u1v1_mask = valid & u1_mask & v1_mask; + if (any(u0v0_mask)) ret |= eval(u0v0_mask,This->child[0],2.0f*u,2.0f*v,2.0f*dscale,depth+1); + if (any(u1v0_mask)) ret |= eval(u1v0_mask,This->child[1],2.0f*u-1.0f,2.0f*v,2.0f*dscale,depth+1); + if (any(u1v1_mask)) ret |= eval(u1v1_mask,This->child[2],2.0f*u-1.0f,2.0f*v-1.0f,2.0f*dscale,depth+1); + if (any(u0v1_mask)) ret |= eval(u0v1_mask,This->child[3],2.0f*u,2.0f*v-1.0f,2.0f*dscale,depth+1); + return ret; + } + + vbool eval_general(const vbool& valid, const typename Patch::SubdividedGeneralPatch* patch, const vfloat& U, const vfloat& V, const size_t depth) + { + vbool ret = false; + const vint l = (vint)floor(0.5f*U); const vfloat u = 2.0f*frac(0.5f*U)-0.5f; + const vint h = (vint)floor(0.5f*V); const vfloat v = 2.0f*frac(0.5f*V)-0.5f; + const vint i = (h<<2)+l; assert(all(valid,iN)); + foreach_unique(valid,i,[&](const vbool& valid, const int i) { + ret |= eval(valid,patch->child[i],u,v,1.0f,depth+1); + }); + return ret; + } + + vbool eval(const vbool& valid, Ref This, const vfloat& u, const vfloat& v, const float dscale, const size_t depth) + { + if (!This) return false; + switch (This.type()) + { + case Patch::BILINEAR_PATCH: { + ((typename Patch::BilinearPatch*)This.object())->patch.eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + return valid; + } + case Patch::BSPLINE_PATCH: { + ((typename Patch::BSplinePatch*)This.object())->patch.eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + return valid; + } + case Patch::BEZIER_PATCH: { + ((typename Patch::BezierPatch*)This.object())->patch.eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + return valid; + } + case Patch::GREGORY_PATCH: { + ((typename Patch::GregoryPatch*)This.object())->patch.eval(valid,u,v,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dscale,dstride,N); + return valid; + } + case Patch::SUBDIVIDED_QUAD_PATCH: { + return eval_quad(valid,((typename Patch::SubdividedQuadPatch*)This.object()),u,v,dscale,depth); + } + case Patch::SUBDIVIDED_GENERAL_PATCH: { + assert(dscale == 1.0f); + return eval_general(valid,((typename Patch::SubdividedGeneralPatch*)This.object()),u,v,depth); + } + case Patch::EVAL_PATCH: { + CatmullClarkPatch patch; patch.deserialize(This.object()); + FeatureAdaptiveEvalSimd(patch,valid,u,v,dscale,depth,P,dPdu,dPdv,ddPdudu,ddPdvdv,ddPdudv,dstride,N); + return valid; + } + default: + assert(false); + return false; + } + } + + private: + float* const P; + float* const dPdu; + float* const dPdv; + float* const ddPdudu; + float* const ddPdvdv; + float* const ddPdudv; + const size_t dstride; + const size_t N; + }; + } +} diff --git a/thirdparty/embree/kernels/subdiv/subdivpatch1base.h b/thirdparty/embree/kernels/subdiv/subdivpatch1base.h new file mode 100644 index 000000000000..d5bc403cca8c --- 
/dev/null +++ b/thirdparty/embree/kernels/subdiv/subdivpatch1base.h @@ -0,0 +1,156 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../geometry/primitive.h" +#include "bspline_patch.h" +#include "bezier_patch.h" +#include "gregory_patch.h" +#include "gregory_patch_dense.h" +#include "tessellation.h" +#include "tessellation_cache.h" +#include "gridrange.h" +#include "patch_eval_grid.h" +#include "feature_adaptive_eval_grid.h" +#include "../common/scene_subdiv_mesh.h" + +namespace embree +{ + struct __aligned(64) SubdivPatch1Base + { + public: + + enum Type { + INVALID_PATCH = 0, + BSPLINE_PATCH = 1, + BEZIER_PATCH = 2, + GREGORY_PATCH = 3, + EVAL_PATCH = 5, + BILINEAR_PATCH = 6, + }; + + enum Flags { + TRANSITION_PATCH = 16, + }; + + /*! Default constructor. */ + __forceinline SubdivPatch1Base () {} + + SubdivPatch1Base (const unsigned int gID, + const unsigned int pID, + const unsigned int subPatch, + const SubdivMesh *const mesh, + const size_t time, + const Vec2f uv[4], + const float edge_level[4], + const int subdiv[4], + const int simd_width); + + __forceinline bool needsStitching() const { + return flags & TRANSITION_PATCH; + } + + __forceinline Vec2f getUV(const size_t i) const { + return Vec2f((float)u[i],(float)v[i]) * (8.0f/0x10000); + } + + static void computeEdgeLevels(const float edge_level[4], const int subdiv[4], float level[4]); + static Vec2i computeGridSize(const float level[4]); + bool updateEdgeLevels(const float edge_level[4], const int subdiv[4], const SubdivMesh *const mesh, const int simd_width); + + public: + + __forceinline size_t getGridBytes() const { + const size_t grid_size_xyzuv = (grid_size_simd_blocks * VSIZEX) * 4; + return 64*((grid_size_xyzuv+15) / 16); + } + + __forceinline void write_lock() { mtx.lock(); } + __forceinline void write_unlock() { mtx.unlock(); } + __forceinline bool try_write_lock() { return mtx.try_lock(); } + //__forceinline bool try_read_lock() { return mtx.try_read_lock(); } + + __forceinline void resetRootRef() { + //assert( mtx.hasInitialState() ); + root_ref = SharedLazyTessellationCache::Tag(); + } + + __forceinline SharedLazyTessellationCache::CacheEntry& entry() { + return (SharedLazyTessellationCache::CacheEntry&) root_ref; + } + + public: + __forceinline unsigned int geomID() const { + return geom; + } + + __forceinline unsigned int primID() const { + return prim; + } + + public: + SharedLazyTessellationCache::Tag root_ref; + SpinLock mtx; + + unsigned short u[4]; //!< 16bit discretized u,v coordinates + unsigned short v[4]; + float level[4]; + + unsigned char flags; + unsigned char type; + unsigned short grid_u_res; + unsigned int geom; //!< geometry ID of the subdivision mesh this patch belongs to + unsigned int prim; //!< primitive ID of this subdivision patch + unsigned short grid_v_res; + + unsigned short grid_size_simd_blocks; + unsigned int time_; + + struct PatchHalfEdge { + const HalfEdge* edge; + unsigned subPatch; + }; + + Vec3fa patch_v[4][4]; + + const HalfEdge *edge() const { return ((PatchHalfEdge*)patch_v)->edge; } + unsigned time() const { return time_; } + unsigned subPatch() const { return ((PatchHalfEdge*)patch_v)->subPatch; } + + void set_edge(const HalfEdge *h) const { ((PatchHalfEdge*)patch_v)->edge = h; } + void set_subPatch(const unsigned s) const { ((PatchHalfEdge*)patch_v)->subPatch = s; } + }; + + namespace isa + { + Vec3fa patchEval(const SubdivPatch1Base& patch, const float uu, const float vv); + Vec3fa patchNormal(const 
SubdivPatch1Base& patch, const float uu, const float vv); + + template + Vec3 patchEval(const SubdivPatch1Base& patch, const simdf& uu, const simdf& vv); + + template + Vec3 patchNormal(const SubdivPatch1Base& patch, const simdf& uu, const simdf& vv); + + + /* eval grid over patch and stich edges when required */ + void evalGrid(const SubdivPatch1Base& patch, + const unsigned x0, const unsigned x1, + const unsigned y0, const unsigned y1, + const unsigned swidth, const unsigned sheight, + float *__restrict__ const grid_x, + float *__restrict__ const grid_y, + float *__restrict__ const grid_z, + float *__restrict__ const grid_u, + float *__restrict__ const grid_v, + const SubdivMesh* const geom); + + /* eval grid over patch and stich edges when required */ + BBox3fa evalGridBounds(const SubdivPatch1Base& patch, + const unsigned x0, const unsigned x1, + const unsigned y0, const unsigned y1, + const unsigned swidth, const unsigned sheight, + const SubdivMesh* const geom); + } +} diff --git a/thirdparty/embree/kernels/subdiv/tessellation.h b/thirdparty/embree/kernels/subdiv/tessellation.h new file mode 100644 index 000000000000..bda1e2d5591e --- /dev/null +++ b/thirdparty/embree/kernels/subdiv/tessellation.h @@ -0,0 +1,161 @@ +// Copyright 2009-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace embree +{ + /* adjust discret tessellation level for feature-adaptive pre-subdivision */ + __forceinline float adjustTessellationLevel(float l, const size_t sublevel) + { + for (size_t i=0; i= 2); + + const float inv_low_rate = rcp((float)(low_rate-1)); + const unsigned int dy = low_rate - 1; + const unsigned int dx = high_rate - 1; + + int p = 2*dy-dx; + + unsigned int offset = 0; + unsigned int y = 0; + float value = 0.0f; + for(unsigned int x=0;x data; + }; + + static __forceinline size_t extractCommitIndex(const int64_t v) { return v >> SharedLazyTessellationCache::COMMIT_INDEX_SHIFT; } + + struct CacheEntry + { + Tag tag; + SpinLock mutex; + }; + + private: + + float *data; + bool hugepages; + size_t size; + size_t maxBlocks; + ThreadWorkState *threadWorkState; + + __aligned(64) std::atomic localTime; + __aligned(64) std::atomic next_block; + __aligned(64) SpinLock reset_state; + __aligned(64) SpinLock linkedlist_mtx; + __aligned(64) std::atomic switch_block_threshold; + __aligned(64) std::atomic numRenderThreads; + + + public: + + + SharedLazyTessellationCache(); + ~SharedLazyTessellationCache(); + + void getNextRenderThreadWorkState(); + + __forceinline size_t maxAllocSize() const { + return switch_block_threshold; + } + + __forceinline size_t getCurrentIndex() { return localTime.load(); } + __forceinline void addCurrentIndex(const size_t i=1) { localTime.fetch_add(i); } + + __forceinline size_t getTime(const size_t globalTime) { + return localTime.load()+NUM_CACHE_SEGMENTS*globalTime; + } + + + __forceinline size_t lockThread (ThreadWorkState *const t_state, const ssize_t plus=1) { return t_state->counter.fetch_add(plus); } + __forceinline size_t unlockThread(ThreadWorkState *const t_state, const ssize_t plus=-1) { assert(isLocked(t_state)); return t_state->counter.fetch_add(plus); } + + __forceinline bool isLocked(ThreadWorkState *const t_state) { return t_state->counter.load() != 0; } + + static __forceinline void lock () { sharedLazyTessellationCache.lockThread(threadState()); } + static __forceinline void unlock() { sharedLazyTessellationCache.unlockThread(threadState()); } + static __forceinline bool isLocked() { return 
sharedLazyTessellationCache.isLocked(threadState()); } + static __forceinline size_t getState() { return threadState()->counter.load(); } + static __forceinline void lockThreadLoop() { sharedLazyTessellationCache.lockThreadLoop(threadState()); } + + static __forceinline size_t getTCacheTime(const size_t globalTime) { + return sharedLazyTessellationCache.getTime(globalTime); + } + + /* per thread lock */ + __forceinline void lockThreadLoop (ThreadWorkState *const t_state) + { + while(1) + { + size_t lock = SharedLazyTessellationCache::sharedLazyTessellationCache.lockThread(t_state,1); + if (unlikely(lock >= THREAD_BLOCK_ATOMIC_ADD)) + { + /* lock failed wait until sync phase is over */ + sharedLazyTessellationCache.unlockThread(t_state,-1); + sharedLazyTessellationCache.waitForUsersLessEqual(t_state,0); + } + else + break; + } + } + + static __forceinline void* lookup(CacheEntry& entry, size_t globalTime) + { + const int64_t subdiv_patch_root_ref = entry.tag.get(); + CACHE_STATS(SharedTessellationCacheStats::cache_accesses++); + + if (likely(subdiv_patch_root_ref != 0)) + { + const size_t subdiv_patch_root = (subdiv_patch_root_ref & REF_TAG_MASK) + (size_t)sharedLazyTessellationCache.getDataPtr(); + const size_t subdiv_patch_cache_index = extractCommitIndex(subdiv_patch_root_ref); + + if (likely( sharedLazyTessellationCache.validCacheIndex(subdiv_patch_cache_index,globalTime) )) + { + CACHE_STATS(SharedTessellationCacheStats::cache_hits++); + return (void*) subdiv_patch_root; + } + } + CACHE_STATS(SharedTessellationCacheStats::cache_misses++); + return nullptr; + } + + template + static __forceinline auto lookup (CacheEntry& entry, size_t globalTime, const Constructor constructor, const bool before=false) -> decltype(constructor()) + { + ThreadWorkState *t_state = SharedLazyTessellationCache::threadState(); + + while (true) + { + sharedLazyTessellationCache.lockThreadLoop(t_state); + void* patch = SharedLazyTessellationCache::lookup(entry,globalTime); + if (patch) return (decltype(constructor())) patch; + + if (entry.mutex.try_lock()) + { + if (!validTag(entry.tag,globalTime)) + { + auto timeBefore = sharedLazyTessellationCache.getTime(globalTime); + auto ret = constructor(); // thread is locked here! + assert(ret); + /* this should never return nullptr */ + auto timeAfter = sharedLazyTessellationCache.getTime(globalTime); + auto time = before ? 
timeBefore : timeAfter; + __memory_barrier(); + entry.tag = SharedLazyTessellationCache::Tag(ret,time); + __memory_barrier(); + entry.mutex.unlock(); + return ret; + } + entry.mutex.unlock(); + } + SharedLazyTessellationCache::sharedLazyTessellationCache.unlockThread(t_state); + } + } + + __forceinline bool validCacheIndex(const size_t i, const size_t globalTime) + { +#if FORCE_SIMPLE_FLUSH == 1 + return i == getTime(globalTime); +#else + return i+(NUM_CACHE_SEGMENTS-1) >= getTime(globalTime); +#endif + } + + static __forceinline bool validTime(const size_t oldtime, const size_t newTime) + { + return oldtime+(NUM_CACHE_SEGMENTS-1) >= newTime; + } + + + static __forceinline bool validTag(const Tag& tag, size_t globalTime) + { + const int64_t subdiv_patch_root_ref = tag.get(); + if (subdiv_patch_root_ref == 0) return false; + const size_t subdiv_patch_cache_index = extractCommitIndex(subdiv_patch_root_ref); + return sharedLazyTessellationCache.validCacheIndex(subdiv_patch_cache_index,globalTime); + } + + void waitForUsersLessEqual(ThreadWorkState *const t_state, + const unsigned int users); + + __forceinline size_t alloc(const size_t blocks) + { + if (unlikely(blocks >= switch_block_threshold)) + throw_RTCError(RTC_ERROR_INVALID_OPERATION,"allocation exceeds size of tessellation cache segment"); + + assert(blocks < switch_block_threshold); + size_t index = next_block.fetch_add(blocks); + if (unlikely(index + blocks >= switch_block_threshold)) return (size_t)-1; + return index; + } + + static __forceinline void* malloc(const size_t bytes) + { + size_t block_index = -1; + ThreadWorkState *const t_state = threadState(); + while (true) + { + block_index = sharedLazyTessellationCache.alloc((bytes+BLOCK_SIZE-1)/BLOCK_SIZE); + if (block_index == (size_t)-1) + { + sharedLazyTessellationCache.unlockThread(t_state); + sharedLazyTessellationCache.allocNextSegment(); + sharedLazyTessellationCache.lockThread(t_state); + continue; + } + break; + } + return sharedLazyTessellationCache.getBlockPtr(block_index); + } + + __forceinline void *getBlockPtr(const size_t block_index) + { + assert(block_index < maxBlocks); + assert(data); + assert(block_index*16 <= size); + return (void*)&data[block_index*16]; + } + + __forceinline void* getDataPtr() { return data; } + __forceinline size_t getNumUsedBytes() { return next_block * BLOCK_SIZE; } + __forceinline size_t getMaxBlocks() { return maxBlocks; } + __forceinline size_t getSize() { return size; } + + void allocNextSegment(); + void realloc(const size_t newSize); + + void reset(); + + static SharedLazyTessellationCache sharedLazyTessellationCache; + }; +} diff --git a/thirdparty/embree/pathces/godot-changes.patch b/thirdparty/embree/pathces/godot-changes.patch new file mode 100644 index 000000000000..82b3da44ac6d --- /dev/null +++ b/thirdparty/embree/pathces/godot-changes.patch @@ -0,0 +1,192 @@ +diff --git a/common/math/math.h b/common/math/math.h +index 5af0691a2..1982c27c1 100644 +--- a/common/math/math.h ++++ b/common/math/math.h +@@ -12,7 +12,7 @@ + #include + #include + +-#if defined(__WIN32__) ++#if defined(__WIN32__) && !defined(__MINGW32__) + #if (__MSV_VER <= 1700) + namespace std + { +diff --git a/common/sys/intrinsics.h b/common/sys/intrinsics.h +index 3f0619cac..58f5c3bb4 100644 +--- a/common/sys/intrinsics.h ++++ b/common/sys/intrinsics.h +@@ -11,6 +11,12 @@ + + #include + ++// -- GODOT start -- ++#if defined(__WIN32__) && defined(__MINGW32__) ++#include ++#endif ++// -- GODOT end -- ++ + #if defined(__BMI__) && defined(__GNUC__) && 
!defined(__INTEL_COMPILER) + #if !defined(_tzcnt_u32) + #define _tzcnt_u32 __tzcnt_u32 +@@ -30,8 +36,14 @@ + #endif + + #if defined(__WIN32__) +-# define NOMINMAX +-# include ++// -- GODOT start -- ++#if !defined(NOMINMAX) ++// -- GODOT end -- ++#define NOMINMAX ++// -- GODOT start -- ++#endif ++#include "windows.h" ++// -- GODOT end -- + #endif + + /* normally defined in pmmintrin.h, but we always need this */ +@@ -413,8 +425,16 @@ namespace embree + + __forceinline void pause_cpu(const size_t N = 8) + { ++// -- GODOT start -- + for (size_t i=0; i +#include + +namespace oidn { + + class Barrier + { + private: + std::mutex m; + std::condition_variable cv; + volatile int count; + + public: + Barrier(int count) : count(count) {} + + void wait() + { + std::unique_lock lk(m); + count--; + + if (count == 0) + { + lk.unlock(); + cv.notify_all(); + } + else + { + cv.wait(lk, [&]{ return count == 0; }); + } + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/exception.h b/thirdparty/oidn/common/exception.h new file mode 100644 index 000000000000..18069c6a7dcf --- /dev/null +++ b/thirdparty/oidn/common/exception.h @@ -0,0 +1,45 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include "platform.h" + +namespace oidn { + + class Exception : public std::exception + { + private: + Error error; + const char* message; + + public: + Exception(Error error, const char* message) + : error(error), message(message) {} + + Error code() const noexcept + { + return error; + } + + const char* what() const noexcept override + { + return message; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.cpp b/thirdparty/oidn/common/platform.cpp new file mode 100644 index 000000000000..59a14ff47c1f --- /dev/null +++ b/thirdparty/oidn/common/platform.cpp @@ -0,0 +1,114 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
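// A minimal usage sketch for the Barrier class defined above, not part of the
// imported sources: each worker calls wait() once, and all of them unblock when
// the last of the expected `count` arrivals happens. The thread count and the
// printed output are illustrative.
#include <iostream>
#include <thread>
#include <vector>
#include "barrier.h"   // thirdparty/oidn/common/barrier.h

void barrierExample()
{
  const int numThreads = 4;
  oidn::Barrier barrier(numThreads);   // one-shot: expects exactly numThreads waiters
  std::vector<std::thread> workers;

  for (int i = 0; i < numThreads; ++i)
  {
    workers.emplace_back([&, i]
    {
      // ... per-thread setup would go here ...
      barrier.wait();                  // nobody proceeds until all threads have arrived
      std::cout << "thread " << i << " released\n";
    });
  }

  for (auto& t : workers)
    t.join();
}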
// +// ======================================================================== // + +#include "platform.h" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + void* alignedMalloc(size_t size, size_t alignment) + { + if (size == 0) + return nullptr; + + assert((alignment & (alignment-1)) == 0); + void* ptr = _mm_malloc(size, alignment); + + if (ptr == nullptr) + throw std::bad_alloc(); + + return ptr; + } + + void alignedFree(void* ptr) + { + if (ptr) + _mm_free(ptr); + } + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName() + { + std::string name; + + #if defined(__linux__) + name = "Linux"; + #elif defined(__FreeBSD__) + name = "FreeBSD"; + #elif defined(__CYGWIN__) + name = "Cygwin"; + #elif defined(_WIN32) + name = "Windows"; + #elif defined(__APPLE__) + name = "macOS"; + #elif defined(__unix__) + name = "Unix"; + #else + return "Unknown"; + #endif + + #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__) + name += " (64-bit)"; + #else + name += " (32-bit)"; + #endif + + return name; + } + + std::string getCompilerName() + { + #if defined(__INTEL_COMPILER) + int mayor = __INTEL_COMPILER / 100 % 100; + int minor = __INTEL_COMPILER % 100; + std::string version = "Intel Compiler "; + version += toString(mayor); + version += "." + toString(minor); + #if defined(__INTEL_COMPILER_UPDATE) + version += "." + toString(__INTEL_COMPILER_UPDATE); + #endif + return version; + #elif defined(__clang__) + return "Clang " __clang_version__; + #elif defined(__GNUC__) + return "GCC " __VERSION__; + #elif defined(_MSC_VER) + std::string version = toString(_MSC_FULL_VER); + version.insert(4, "."); + version.insert(9, "."); + version.insert(2, "."); + return "Visual C++ Compiler " + version; + #else + return "Unknown"; + #endif + } + + std::string getBuildName() + { + #if defined(NDEBUG) + return "Release"; + #else + return "Debug"; + #endif + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.h b/thirdparty/oidn/common/platform.h new file mode 100644 index 000000000000..9373b617b57c --- /dev/null +++ b/thirdparty/oidn/common/platform.h @@ -0,0 +1,131 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
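// A hedged sketch, not part of the imported sources: alignedMalloc()/alignedFree()
// defined above are the raw allocation primitives (_mm_malloc underneath,
// std::bad_alloc on failure). The RAII wrapper below shows one way a caller might
// pair them; the AlignedBuffer type itself is made up for illustration.
#include <cstddef>
#include "platform.h"   // declares oidn::alignedMalloc / oidn::alignedFree

namespace oidn_example {

  class AlignedBuffer
  {
  private:
    void* ptr;

  public:
    explicit AlignedBuffer(size_t size, size_t alignment = 64)
      : ptr(oidn::alignedMalloc(size, alignment)) {}   // throws std::bad_alloc on failure

    ~AlignedBuffer() { oidn::alignedFree(ptr); }       // alignedFree tolerates nullptr

    // Non-copyable: the buffer owns the allocation exclusively
    AlignedBuffer(const AlignedBuffer&) = delete;
    AlignedBuffer& operator =(const AlignedBuffer&) = delete;

    void* get() const { return ptr; }
  };

} // namespace oidn_example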
// +// ======================================================================== // + +#pragma once + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include +#elif defined(__APPLE__) + #include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "include/OpenImageDenoise/oidn.hpp" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Macros + // ---------------------------------------------------------------------------- + + #if defined(_WIN32) + // Windows + #if !defined(__noinline) + #define __noinline __declspec(noinline) + #endif + #else + // Unix + #if !defined(__forceinline) + #define __forceinline inline __attribute__((always_inline)) + #endif + #if !defined(__noinline) + #define __noinline __attribute__((noinline)) + #endif + #endif + + #ifndef UNUSED + #define UNUSED(x) ((void)x) + #endif + #ifndef MAYBE_UNUSED + #define MAYBE_UNUSED(x) UNUSED(x) + #endif + + // ---------------------------------------------------------------------------- + // Error handling and debugging + // ---------------------------------------------------------------------------- + + struct Verbose + { + int verbose; + + Verbose(int v = 0) : verbose(v) {} + __forceinline bool isVerbose(int v = 1) const { return v <= verbose; } + }; + + #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; } + #define OIDN_FATAL(message) throw std::runtime_error(message); + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + using std::min; + using std::max; + + template + __forceinline T clamp(const T& value, const T& minValue, const T& maxValue) + { + return min(max(value, minValue), maxValue); + } + + void* alignedMalloc(size_t size, size_t alignment); + void alignedFree(void* ptr); + + template + inline std::string toString(const T& a) + { + std::stringstream sm; + sm << a; + return sm.str(); + } + +#if defined(__APPLE__) + template + bool getSysctl(const char* name, T& value) + { + int64_t result = 0; + size_t size = sizeof(result); + + if (sysctlbyname(name, &result, &size, nullptr, 0) != 0) + return false; + + value = T(result); + return true; + } +#endif + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName(); + std::string getCompilerName(); + std::string getBuildName(); + +} // namespace oidn diff --git a/thirdparty/oidn/common/ref.h b/thirdparty/oidn/common/ref.h new file mode 100644 index 000000000000..de44603af205 --- /dev/null +++ b/thirdparty/oidn/common/ref.h @@ -0,0 +1,163 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
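// A hedged sketch, not part of the imported sources: OIDN_WARNING above expands
// to a call to isVerbose(), so it is intended to be used from members of classes
// that derive from Verbose, with warnings gated by the verbosity level passed to
// the base class. The Stage type below is made up for illustration.
#include "platform.h"

namespace oidn_example {

  struct Stage : public oidn::Verbose
  {
    explicit Stage(int verbosity) : Verbose(verbosity) {}

    void run()
    {
      // Printed to std::cerr only when the verbosity level is at least 1
      OIDN_WARNING("falling back to the reference code path");
    }
  };

} // namespace oidn_example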
// +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +namespace oidn { + + class RefCount + { + private: + std::atomic count; + + public: + __forceinline RefCount(int count = 0) noexcept : count(count) {} + + __forceinline size_t incRef() noexcept + { + return count.fetch_add(1) + 1; + } + + __forceinline size_t decRef() + { + const size_t newCount = decRefKeep(); + if (newCount == 0) + destroy(); + return newCount; + } + + __forceinline size_t decRefKeep() noexcept + { + return count.fetch_add(-1) - 1; + } + + __forceinline void destroy() + { + delete this; + } + + protected: + // Disable copying + RefCount(const RefCount&) = delete; + RefCount& operator =(const RefCount&) = delete; + + virtual ~RefCount() noexcept = default; + }; + + template + class Ref + { + private: + T* ptr; + + public: + __forceinline Ref() noexcept : ptr(nullptr) {} + __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {} + __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); } + __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; } + __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + template + __forceinline Ref(const Ref& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); } + + template + __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + __forceinline ~Ref() { if (ptr) ptr->decRef(); } + + __forceinline Ref& operator =(const Ref& other) + { + if (other.ptr) + other.ptr->incRef(); + if (ptr) + ptr->decRef(); + ptr = other.ptr; + return *this; + } + + __forceinline Ref& operator =(Ref&& other) + { + if (ptr) + ptr->decRef(); + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + __forceinline Ref& operator =(T* other) + { + if (other) + other->incRef(); + if (ptr) + ptr->decRef(); + ptr = other; + return *this; + } + + __forceinline Ref& operator =(std::nullptr_t) + { + if (ptr) + ptr->decRef(); + ptr = nullptr; + return *this; + } + + __forceinline operator bool() const noexcept { return ptr != nullptr; } + + __forceinline T& operator *() const noexcept { return *ptr; } + __forceinline T* operator ->() const noexcept { return ptr; } + + __forceinline T* get() const noexcept { return ptr; } + + __forceinline T* detach() noexcept + { + T* res = ptr; + ptr = nullptr; + return res; + } + }; + + template __forceinline bool operator < (const Ref& a, const Ref& b) noexcept { return a.ptr < b.ptr; } + + template __forceinline bool operator ==(const Ref& a, std::nullptr_t) noexcept { return a.ptr == nullptr; } + template __forceinline bool operator ==(std::nullptr_t, const Ref& b) noexcept { return nullptr == b.ptr; } + template __forceinline bool operator ==(const Ref& a, const Ref& b) noexcept { return a.ptr == b.ptr; } + + template __forceinline bool operator !=(const Ref& a, std::nullptr_t) noexcept { return a.ptr != nullptr; } + template __forceinline bool operator !=(std::nullptr_t, const Ref& b) noexcept { return nullptr != b.ptr; } + template __forceinline bool operator !=(const Ref& a, const Ref& b) noexcept { return a.ptr != b.ptr; } + + template + __forceinline Ref makeRef(Args&&... 
args) + { + return Ref(new T(std::forward(args)...)); + } + + template + __forceinline Ref staticRefCast(const Ref& a) + { + return Ref(static_cast(a.get())); + } + + template + __forceinline Ref dynamicRefCast(const Ref& a) + { + return Ref(dynamic_cast(a.get())); + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.cpp b/thirdparty/oidn/common/tensor.cpp new file mode 100644 index 000000000000..0249f2e14193 --- /dev/null +++ b/thirdparty/oidn/common/tensor.cpp @@ -0,0 +1,83 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "exception.h" +#include "tensor.h" + +namespace oidn { + + std::map parseTensors(void* buffer) + { + char* input = (char*)buffer; + + // Parse the magic value + const int magic = *(unsigned short*)input; + if (magic != 0x41D7) + throw Exception(Error::InvalidOperation, "invalid tensor archive"); + input += sizeof(unsigned short); + + // Parse the version + const int majorVersion = *(unsigned char*)input++; + const int minorVersion = *(unsigned char*)input++; + UNUSED(minorVersion); + if (majorVersion > 1) + throw Exception(Error::InvalidOperation, "unsupported tensor archive version"); + + // Parse the number of tensors + const int numTensors = *(int*)input; + input += sizeof(int); + + // Parse the tensors + std::map tensorMap; + for (int i = 0; i < numTensors; ++i) + { + Tensor tensor; + + // Parse the name + const int nameLen = *(unsigned char*)input++; + std::string name(input, nameLen); + input += nameLen; + + // Parse the number of dimensions + const int ndims = *(unsigned char*)input++; + + // Parse the shape of the tensor + tensor.dims.resize(ndims); + for (int i = 0; i < ndims; ++i) + tensor.dims[i] = ((int*)input)[i]; + input += ndims * sizeof(int); + + // Parse the format of the tensor + tensor.format = std::string(input, input + ndims); + input += ndims; + + // Parse the data type of the tensor + const char type = *(unsigned char*)input++; + if (type != 'f') // only float32 is supported + throw Exception(Error::InvalidOperation, "unsupported tensor data type"); + + // Skip the data + tensor.data = (float*)input; + input += tensor.size() * sizeof(float); + + // Add the tensor to the map + tensorMap.emplace(name, std::move(tensor)); + } + + return tensorMap; + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h new file mode 100644 index 000000000000..48e7d1123d10 --- /dev/null +++ b/thirdparty/oidn/common/tensor.h @@ -0,0 +1,66 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
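// A minimal usage sketch for the intrusive smart pointer defined in ref.h above,
// not part of the imported sources: the managed type derives from RefCount and
// makeRef<T>() wraps a freshly allocated object. The Blob type is made up.
#include <cstddef>
#include <vector>
#include "ref.h"   // thirdparty/oidn/common/ref.h

namespace oidn_example {

  struct Blob : public oidn::RefCount
  {
    std::vector<float> values;
    explicit Blob(size_t n) : values(n, 0.0f) {}
  };

  void refExample()
  {
    // makeRef leaves the reference count at 1; copies bump it, destruction drops it
    oidn::Ref<Blob> a = oidn::makeRef<Blob>(16);
    oidn::Ref<Blob> b = a;          // shared ownership, count is now 2
    b->values[0] = 1.0f;
    a = nullptr;                    // 'b' still keeps the Blob alive
  }                                 // last Ref goes out of scope -> Blob is deleted

} // namespace oidn_example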
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include +#include + +namespace oidn { + + template + using shared_vector = std::shared_ptr>; + + // Generic tensor + struct Tensor + { + float* data; + std::vector dims; + std::string format; + shared_vector buffer; // optional, only for reference counting + + __forceinline Tensor() : data(nullptr) {} + + __forceinline Tensor(const std::vector& dims, const std::string& format) + : dims(dims), + format(format) + { + buffer = std::make_shared>(size() * sizeof(float)); + data = (float*)buffer->data(); + } + + __forceinline operator bool() const { return data != nullptr; } + + __forceinline int ndims() const { return (int)dims.size(); } + + // Returns the number of values + __forceinline size_t size() const + { + size_t size = 1; + for (int i = 0; i < ndims(); ++i) + size *= dims[i]; + return size; + } + + __forceinline float& operator [](size_t i) { return data[i]; } + __forceinline const float& operator [](size_t i) const { return data[i]; } + }; + + // Parses tensors from a buffer + std::map parseTensors(void* buffer); + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.cpp b/thirdparty/oidn/common/thread.cpp new file mode 100644 index 000000000000..48c489c57b61 --- /dev/null +++ b/thirdparty/oidn/common/thread.cpp @@ -0,0 +1,297 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
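// A hedged sketch, not part of the imported sources: it serializes a single
// float32 tensor in the archive layout that parseTensors() above expects
// (magic, version, tensor count, then name / dims / format / type / data per
// tensor). Field order and sizes are taken directly from the parser; the helper
// itself and the caller-supplied contents are illustrative.
#include <cstddef>
#include <string>
#include <vector>

static void appendBytes(std::vector<char>& out, const void* p, size_t n)
{
  out.insert(out.end(), (const char*)p, (const char*)p + n);
}

std::vector<char> writeTensorArchive(const std::string& name,
                                     const std::vector<int>& dims,
                                     const std::string& format,   // one char per dimension
                                     const std::vector<float>& data) // product(dims) floats
{
  std::vector<char> out;
  const unsigned short magic = 0x41D7;             // magic value checked by parseTensors
  appendBytes(out, &magic, sizeof(magic));
  out.push_back(1);                                // major version
  out.push_back(0);                                // minor version (ignored by the parser)
  const int numTensors = 1;
  appendBytes(out, &numTensors, sizeof(numTensors));
  out.push_back((char)name.size());                // name length (u8)
  appendBytes(out, name.data(), name.size());
  out.push_back((char)dims.size());                // number of dimensions (u8)
  appendBytes(out, dims.data(), dims.size() * sizeof(int));
  appendBytes(out, format.data(), format.size());  // format string, ndims characters
  out.push_back('f');                              // data type: float32 is the only one supported
  appendBytes(out, data.data(), data.size() * sizeof(float));
  return out;
}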
// +// ======================================================================== // + +#if defined(_MSC_VER) + #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned +#endif + +#if defined(__APPLE__) + #include + #include +#endif + +#include "thread.h" +#include + +namespace oidn { + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + HMODULE hLib = GetModuleHandle(TEXT("kernel32")); + pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx"); + pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity"); + + if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity) + { + // Get logical processor information + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr; + DWORD bufferSize = 0; + + // First call the function with an empty buffer to get the required buffer size + BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + return; + } + + // Allocate the buffer + buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize); + if (!buffer) + { + OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed"); + return; + } + + // Call again the function but now with the properly sized buffer + result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (!result) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + free(buffer); + return; + } + + // Iterate over the logical processor information structures + // There should be one structure for each physical core + char* ptr = (char*)buffer; + while (ptr < (char*)buffer + bufferSize) + { + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr; + if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0) + { + // Iterate over the groups + int numThreads = 0; + for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group) + { + GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group]; + while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore)) + { + // Extract the next set bit/thread from the mask + GROUP_AFFINITY threadAffinity = coreAffinity; + threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask; + + // Push the affinity for this thread + affinities.push_back(threadAffinity); + oldAffinities.push_back(threadAffinity); + numThreads++; + + // Remove this bit/thread from the mask + coreAffinity.Mask ^= threadAffinity.Mask; + } + } + } + + // Next structure + ptr += item->Size; + } + + // Free the buffer + free(buffer); + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Save the current affinity and set the new one + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex])) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= 
(int)affinities.size()) + return; + + // Restore the original affinity + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr)) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + std::vector threadIds; + + // Parse the thread/CPU topology + for (int cpuId = 0; ; cpuId++) + { + std::fstream fs; + std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list"); + fs.open(cpu.c_str(), std::fstream::in); + if (fs.fail()) break; + + int i; + int j = 0; + while ((j < numThreadsPerCore) && (fs >> i)) + { + if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; })) + threadIds.push_back(i); + + if (fs.peek() == ',') + fs.ignore(); + j++; + } + + fs.close(); + } + + #if 0 + for (size_t i = 0; i < thread_ids.size(); ++i) + std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl; + #endif + + // Create the affinity structures + affinities.resize(threadIds.size()); + oldAffinities.resize(threadIds.size()); + + for (size_t i = 0; i < threadIds.size(); ++i) + { + cpu_set_t affinity; + CPU_ZERO(&affinity); + CPU_SET(threadIds[i], &affinity); + + affinities[i] = affinity; + oldAffinities[i] = affinity; + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Save the current affinity + if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + { + OIDN_WARNING("pthread_getaffinity_np failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Restore the original affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + // Query the thread/CPU topology + int numPhysicalCpus; + int numLogicalCpus; + + if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus)) + { + OIDN_WARNING("sysctlbyname failed"); + return; + } + + if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1)) + return; // this shouldn't happen + const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus; + + // Create the affinity structures + // macOS doesn't support binding a thread to a specific core, but we can at least group threads which + // should be on the same core together + for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1! 
+ { + thread_affinity_policy affinity; + affinity.affinity_tag = core; + + for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread) + { + affinities.push_back(affinity); + oldAffinities.push_back(affinity); + } + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Save the current affinity + mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT; + boolean_t getDefault = FALSE; + if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS) + { + OIDN_WARNING("thread_policy_get failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Restore the original affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.h b/thirdparty/oidn/common/thread.h new file mode 100644 index 000000000000..2c731367da3b --- /dev/null +++ b/thirdparty/oidn/common/thread.h @@ -0,0 +1,202 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
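// A hedged sketch, not part of the imported sources: the three ThreadAffinity
// implementations above share one pattern, where set(i) pins the calling thread
// to slot i after saving its previous affinity and restore(i) puts the saved
// affinity back. A minimal worker-pool illustration; the per-thread work is made up.
#include <thread>
#include <vector>
#include "thread.h"   // thirdparty/oidn/common/thread.h

void affinityExample()
{
  oidn::ThreadAffinity affinity(1);               // at most one thread per physical core
  const int numThreads = affinity.getNumThreads();

  std::vector<std::thread> workers;
  for (int i = 0; i < numThreads; ++i)
  {
    workers.emplace_back([&affinity, i]
    {
      affinity.set(i);                            // pin this thread to slot i
      // ... do the per-thread work here ...
      affinity.restore(i);                        // undo the pinning before exiting
    });
  }

  for (auto& t : workers)
    t.join();
}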
// +// ======================================================================== // + +#pragma once + +#include "platform.h" + +#if !defined(_WIN32) + #include + #include + #if defined(__APPLE__) + #include + #endif +#endif + +#include +#include + +namespace oidn { + + // -------------------------------------------------------------------------- + // ThreadLocal + // -------------------------------------------------------------------------- + + // Wrapper which makes any variable thread-local + template + class ThreadLocal : public Verbose + { + private: + #if defined(_WIN32) + DWORD key; + #else + pthread_key_t key; + #endif + + std::vector instances; + std::mutex mutex; + + public: + ThreadLocal(int verbose = 0) + : Verbose(verbose) + { + #if defined(_WIN32) + key = TlsAlloc(); + if (key == TLS_OUT_OF_INDEXES) + OIDN_FATAL("TlsAlloc failed"); + #else + if (pthread_key_create(&key, nullptr) != 0) + OIDN_FATAL("pthread_key_create failed"); + #endif + } + + ~ThreadLocal() + { + std::lock_guard lock(mutex); + for (T* ptr : instances) + delete ptr; + + #if defined(_WIN32) + if (!TlsFree(key)) + OIDN_WARNING("TlsFree failed"); + #else + if (pthread_key_delete(key) != 0) + OIDN_WARNING("pthread_key_delete failed"); + #endif + } + + T& get() + { + #if defined(_WIN32) + T* ptr = (T*)TlsGetValue(key); + #else + T* ptr = (T*)pthread_getspecific(key); + #endif + + if (ptr) + return *ptr; + + ptr = new T; + std::lock_guard lock(mutex); + instances.push_back(ptr); + + #if defined(_WIN32) + if (!TlsSetValue(key, ptr)) + OIDN_FATAL("TlsSetValue failed"); + #else + if (pthread_setspecific(key, ptr) != 0) + OIDN_FATAL("pthread_setspecific failed"); + #endif + + return *ptr; + } + }; + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP, + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, + PDWORD); + + typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE, + CONST GROUP_AFFINITY*, + PGROUP_AFFINITY); + + GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr; + SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr; + + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int 
threadIndex); + }; + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/timer.h b/thirdparty/oidn/common/timer.h new file mode 100644 index 000000000000..62aaaa1c33fc --- /dev/null +++ b/thirdparty/oidn/common/timer.h @@ -0,0 +1,49 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include + +namespace oidn { + + class Timer + { + private: + using clock = std::chrono::high_resolution_clock; + + std::chrono::time_point start; + + public: + Timer() + { + reset(); + } + + void reset() + { + start = clock::now(); + } + + double query() const + { + auto end = clock::now(); + return std::chrono::duration_cast>(end - start).count(); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/api.cpp b/thirdparty/oidn/core/api.cpp new file mode 100644 index 000000000000..7353fe4e2528 --- /dev/null +++ b/thirdparty/oidn/core/api.cpp @@ -0,0 +1,408 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#ifdef _WIN32 +# define OIDN_API extern "C" __declspec(dllexport) +#else +# define OIDN_API extern "C" __attribute__ ((visibility ("default"))) +#endif + +// Locks the device that owns the specified object +// Use *only* inside OIDN_TRY/CATCH! 
+#define OIDN_LOCK(obj) \ + std::lock_guard lock(obj->getDevice()->getMutex()); + +// Try/catch for converting exceptions to errors +#define OIDN_TRY \ + try { + +#define OIDN_CATCH(obj) \ + } catch (Exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \ + } catch (std::bad_alloc&) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + } catch (mkldnn::error& e) { \ + if (e.status == mkldnn_out_of_memory) \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + else \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \ + } catch (std::exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \ + } catch (...) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \ + } + +#include "device.h" +#include "filter.h" +#include + +namespace oidn { + + namespace + { + __forceinline void checkHandle(void* handle) + { + if (handle == nullptr) + throw Exception(Error::InvalidArgument, "invalid handle"); + } + + template + __forceinline void retainObject(T* obj) + { + if (obj) + { + obj->incRef(); + } + else + { + OIDN_TRY + checkHandle(obj); + OIDN_CATCH(obj) + } + } + + template + __forceinline void releaseObject(T* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + OIDN_LOCK(obj); + obj->destroy(); + OIDN_CATCH(obj) + } + } + + template<> + __forceinline void releaseObject(Device* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + // Do NOT lock the device because it owns the mutex + obj->destroy(); + OIDN_CATCH(obj) + } + } + } + + OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type) + { + Ref device = nullptr; + OIDN_TRY + if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT) + device = makeRef(); + else + throw Exception(Error::InvalidArgument, "invalid device type"); + OIDN_CATCH(device) + return (OIDNDevice)device.detach(); + } + + OIDN_API void oidnRetainDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + retainObject(device); + } + + OIDN_API void oidnReleaseDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + releaseObject(device); + } + + OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return false; + } + + OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return 0; + } + + OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->setErrorFunction((ErrorFunction)func, userPtr); 
+ OIDN_CATCH(device) + } + + OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage) + { + Device* device = (Device*)hDevice; + OIDN_TRY + return (OIDNError)Device::getError(device, outMessage); + OIDN_CATCH(device) + if (outMessage) *outMessage = ""; + return OIDN_ERROR_UNKNOWN; + } + + OIDN_API void oidnCommitDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->commit(); + OIDN_CATCH(device) + } + + OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref buffer = device->newBuffer(byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref buffer = device->newBuffer(ptr, byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + retainObject(buffer); + } + + OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + releaseObject(buffer); + } + + OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->map(byteOffset, byteSize); + OIDN_CATCH(buffer) + return nullptr; + } + + OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->unmap(mappedPtr); + OIDN_CATCH(buffer) + } + + OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref filter = device->newFilter(type); + return (OIDNFilter)filter.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + retainObject(filter); + } + + OIDN_API void oidnReleaseFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + releaseObject(filter); + } + + OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name, + OIDNBuffer hBuffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + checkHandle(hBuffer); + OIDN_LOCK(filter); + Ref buffer = (Buffer*)hBuffer; + if (buffer->getDevice() != filter->getDevice()) + throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices"); + Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + 
filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, int(value)); + OIDN_CATCH(filter) + } + + OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return false; + } + + OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, value); + OIDN_CATCH(filter) + } + + OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1f(name, value); + OIDN_CATCH(filter) + } + + OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1f(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->setProgressMonitorFunction(func, userPtr); + OIDN_CATCH(filter) + } + + OIDN_API void oidnCommitFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->commit(); + OIDN_CATCH(filter) + } + + OIDN_API void oidnExecuteFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->execute(); + OIDN_CATCH(filter) + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.cpp b/thirdparty/oidn/core/autoencoder.cpp new file mode 100644 index 000000000000..d8da684cb894 --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.cpp @@ -0,0 +1,535 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
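// A hedged sketch of the typical call sequence for the C API defined in api.cpp
// above, not part of the imported sources. The filter type string "RT", the
// OIDN_FORMAT_FLOAT3 and OIDN_ERROR_NONE constants, and the include path are
// assumptions taken from the public OpenImageDenoise API rather than from this
// hunk; the image pointers and dimensions are illustrative.
#include <cstddef>
#include <cstdio>
#include <OpenImageDenoise/oidn.h>

void denoiseExample(float* colorPixels, float* outputPixels, size_t width, size_t height)
{
  OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_DEFAULT);
  oidnCommitDevice(device);

  OIDNFilter filter = oidnNewFilter(device, "RT");    // generic ray-tracing denoiser
  oidnSetSharedFilterImage(filter, "color",  colorPixels,  OIDN_FORMAT_FLOAT3,
                           width, height, 0, 0, 0);   // zero strides = tightly packed
  oidnSetSharedFilterImage(filter, "output", outputPixels, OIDN_FORMAT_FLOAT3,
                           width, height, 0, 0, 0);
  oidnSetFilter1b(filter, "hdr", true);               // lightmap bakes carry HDR values
  oidnCommitFilter(filter);
  oidnExecuteFilter(filter);

  const char* message;
  if (oidnGetDeviceError(device, &message) != OIDN_ERROR_NONE)
    std::printf("OIDN error: %s\n", message);

  oidnReleaseFilter(filter);
  oidnReleaseDevice(device);
}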
// +// ======================================================================== // + +#include "autoencoder.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter + // -------------------------------------------------------------------------- + + AutoencoderFilter::AutoencoderFilter(const Ref& device) + : Filter(device) + { + } + + void AutoencoderFilter::setImage(const std::string& name, const Image& data) + { + if (name == "color") + color = data; + else if (name == "albedo") + albedo = data; + else if (name == "normal") + normal = data; + else if (name == "output") + output = data; + + dirty = true; + } + + void AutoencoderFilter::set1i(const std::string& name, int value) + { + if (name == "hdr") + hdr = value; + else if (name == "srgb") + srgb = value; + else if (name == "maxMemoryMB") + maxMemoryMB = value; + + dirty = true; + } + + int AutoencoderFilter::get1i(const std::string& name) + { + if (name == "hdr") + return hdr; + else if (name == "srgb") + return srgb; + else if (name == "maxMemoryMB") + return maxMemoryMB; + else if (name == "alignment") + return alignment; + else if (name == "overlap") + return overlap; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::set1f(const std::string& name, float value) + { + if (name == "hdrScale") + hdrScale = value; + + dirty = true; + } + + float AutoencoderFilter::get1f(const std::string& name) + { + if (name == "hdrScale") + return hdrScale; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::commit() + { + if (!dirty) + return; + + // -- GODOT start -- + //device->executeTask([&]() + //{ + // GODOT end -- + + if (mayiuse(avx512_common)) + net = buildNet<16>(); + else + net = buildNet<8>(); + + // GODOT start -- + //}); + // GODOT end -- + + dirty = false; + } + + void AutoencoderFilter::execute() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the filter are not committed"); + + if (!net) + return; + // -- GODOT start -- + //device->executeTask([&]() + //{ + // -- GODOT end -- + Progress progress; + progress.func = progressFunc; + progress.userPtr = progressUserPtr; + progress.taskCount = tileCountH * tileCountW; + + // Iterate over the tiles + int tileIndex = 0; + + for (int i = 0; i < tileCountH; ++i) + { + const int h = i * (tileH - 2*overlap); // input tile position (including overlap) + const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top + const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom + const int tileH1 = min(H - h, tileH); // input tile size (including overlap) + const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size + const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer + + for (int j = 0; j < tileCountW; ++j) + { + const int w = j * (tileW - 2*overlap); // input tile position (including overlap) + const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left + const int overlapEndW = j < tileCountW-1 ? 
overlap : 0; // overlap on the right + const int tileW1 = min(W - w, tileW); // input tile size (including overlap) + const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size + const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer + + // Set the input tile + inputReorder->setTile(h, w, + alignOffsetH, alignOffsetW, + tileH1, tileW1); + + // Set the output tile + outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW, + h + overlapBeginH, w + overlapBeginW, + tileH2, tileW2); + + //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2); + + // Denoise the tile + net->execute(progress, tileIndex); + + // Next tile + tileIndex++; + } + } + // -- GODOT start -- + //}); + // -- GODOT end -- + } + + void AutoencoderFilter::computeTileSize() + { + const int minTileSize = 3*overlap; + const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8; + const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel; + + tileCountH = 1; + tileCountW = 1; + tileH = roundUp(H, alignment); + tileW = roundUp(W, alignment); + + // Divide the image into tiles until the tile size gets below the threshold + while (int64_t(tileH) * tileW > maxTilePixels) + { + if (tileH > minTileSize && tileH > tileW) + { + tileCountH++; + tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize); + } + else if (tileW > minTileSize) + { + tileCountW++; + tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize); + } + else + break; + } + + // Compute the final number of tiles + tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1; + tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1; + + if (device->isVerbose(2)) + { + std::cout << "Tile size : " << tileW << "x" << tileH << std::endl; + std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl; + } + } + + template + std::shared_ptr AutoencoderFilter::buildNet() + { + H = color.height; + W = color.width; + + // Configure the network + int inputC; + void* weightPtr; + + if (srgb && hdr) + throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time"); + + if (color && !albedo && !normal && weightData.hdr) + { + inputC = 3; + weightPtr = hdr ? weightData.hdr : weightData.ldr; + } + else if (color && albedo && !normal && weightData.hdr_alb) + { + inputC = 6; + weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb; + } + else if (color && albedo && normal && weightData.hdr_alb_nrm) + { + inputC = 9; + weightPtr = hdr ? 
weightData.hdr_alb_nrm : weightData.ldr_alb_nrm; + } + else + { + throw Exception(Error::InvalidOperation, "unsupported combination of input features"); + } + + if (!output) + throw Exception(Error::InvalidOperation, "output image not specified"); + + if ((color.format != Format::Float3) + || (albedo && albedo.format != Format::Float3) + || (normal && normal.format != Format::Float3) + || (output.format != Format::Float3)) + throw Exception(Error::InvalidOperation, "unsupported image format"); + + if ((albedo && (albedo.width != W || albedo.height != H)) + || (normal && (normal.width != W || normal.height != H)) + || (output.width != W || output.height != H)) + throw Exception(Error::InvalidOperation, "image size mismatch"); + + // Compute the tile size + computeTileSize(); + + // If the image size is zero, there is nothing else to do + if (H <= 0 || W <= 0) + return nullptr; + + // Parse the weights + const auto weightMap = parseTensors(weightPtr); + + // Create the network + std::shared_ptr> net = std::make_shared>(device, weightMap); + + // Compute the tensor sizes + const auto inputDims = memory::dims({1, inputC, tileH, tileW}); + const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0 + + const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0 + const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1 + const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1 + const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0 + const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2 + const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0 + const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3 + const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0 + const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4 + const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0 + const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1 + const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4 + const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims); + const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0 + const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1 + const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3 + const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims); + const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0 + const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1 + const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2 + const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims); + const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0 + const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1 + const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1 + const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims); + const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0 + const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1 + const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0 + const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims); + const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0 + const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1 + 
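+ // Taken together, the dims above trace a U-Net: the conv/pool stages halve
+ // the spatial resolution on the way down, and each upsample/concat pair on
+ // the way up joins the decoder tensor with the matching encoder output
+ // (a skip connection), which is why the concat*Dims combine upsample*Dims
+ // with the corresponding pool*Dims.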
const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0 + + const auto outputDims = memory::dims({1, 3, tileH, tileW}); + + // Allocate two temporary ping-pong buffers to decrease memory usage + const auto temp0Dims = getMaxTensorDims({ + conv1Dims, + conv2Dims, + conv3Dims, + conv4Dims, + conv5Dims, + conv6Dims, + conv7Dims, + conv8Dims, + conv9Dims, + conv10Dims, + conv11Dims + }); + + const auto temp1Dims = getMaxTensorDims({ + conv1bDims, + pool5Dims, + conv6bDims, + conv7bDims, + conv8bDims, + conv9bDims, + conv10bDims, + }); + + auto temp0 = net->allocTensor(temp0Dims); + auto temp1 = net->allocTensor(temp1Dims); + + // Allocate enough memory to hold the concat outputs. Then use the first + // half to hold the previous conv output and the second half to hold the + // pool/orig image output. This works because everything is C dimension + // outermost, padded to K floats, and all the concats are on the C dimension. + auto concat0Dst = net->allocTensor(concat0Dims); + auto concat1Dst = net->allocTensor(concat1Dims); + auto concat2Dst = net->allocTensor(concat2Dims); + auto concat3Dst = net->allocTensor(concat3Dims); + auto concat4Dst = net->allocTensor(concat4Dims); + + // Transfer function + std::shared_ptr transferFunc = makeTransferFunc(); + + // Autoexposure + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + { + if (isnan(hdrScale)) + net->addAutoexposure(color, tf); + else + tf->setExposure(hdrScale); + } + + // Input reorder + auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims); + inputReorder = net->addInputReorder(color, albedo, normal, + transferFunc, + alignment, inputReorderDst); + + // conv1 + auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0); + + // conv1b + auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1); + + // pool1 + // Adjust pointer for pool1 to eliminate concat1 + auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims); + auto pool1 = net->addPool(conv1b->getDst(), pool1Dst); + + // conv2 + auto conv2 = net->addConv("conv2", pool1->getDst(), temp0); + + // pool2 + // Adjust pointer for pool2 to eliminate concat2 + auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims); + auto pool2 = net->addPool(conv2->getDst(), pool2Dst); + + // conv3 + auto conv3 = net->addConv("conv3", pool2->getDst(), temp0); + + // pool3 + // Adjust pointer for pool3 to eliminate concat3 + auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims); + auto pool3 = net->addPool(conv3->getDst(), pool3Dst); + + // conv4 + auto conv4 = net->addConv("conv4", pool3->getDst(), temp0); + + // pool4 + // Adjust pointer for pool4 to eliminate concat4 + auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims); + auto pool4 = net->addPool(conv4->getDst(), pool4Dst); + + // conv5 + auto conv5 = net->addConv("conv5", pool4->getDst(), temp0); + + // pool5 + auto pool5 = net->addPool(conv5->getDst(), temp1); + + // upsample4 + auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst); + auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst); + + // conv6 + auto conv6 = net->addConv("conv6", concat4Dst, temp0); + + // conv6b + auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1); + + // upsample3 + auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst); + auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst); + + // conv7 + auto conv7 = net->addConv("conv7", concat3Dst, temp0); + + // conv7b + auto conv7b = net->addConv("conv7b", 
conv7->getDst(), temp1); + + // upsample2 + auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst); + auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst); + + // conv8 + auto conv8 = net->addConv("conv8", concat2Dst, temp0); + + // conv8b + auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1); + + // upsample1 + auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst); + auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst); + + // conv9 + auto conv9 = net->addConv("conv9", concat1Dst, temp0); + + // conv9b + auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1); + + // upsample0 + auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst); + auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst); + + // conv10 + auto conv10 = net->addConv("conv10", concat0Dst, temp0); + + // conv10b + auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1); + + // conv11 + auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */); + + // Output reorder + outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output); + + net->finalize(); + return net; + } + + std::shared_ptr AutoencoderFilter::makeTransferFunc() + { + if (hdr) + return std::make_shared(); + else if (srgb) + return std::make_shared(); + else + return std::make_shared(); + } + +// -- GODOT start -- +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 +// -- GODOT end -- + + // -------------------------------------------------------------------------- + // RTFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // LDR + extern unsigned char rt_ldr[]; // color + extern unsigned char rt_ldr_alb[]; // color, albedo + extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal + + // HDR + extern unsigned char rt_hdr[]; // color + extern unsigned char rt_hdr_alb[]; // color, albedo + extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal + } + + RTFilter::RTFilter(const Ref& device) + : AutoencoderFilter(device) + { + weightData.ldr = weights::rt_ldr; + weightData.ldr_alb = weights::rt_ldr_alb; + weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm; + weightData.hdr = weights::rt_hdr; + weightData.hdr_alb = weights::rt_hdr_alb; + weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; + } +// -- GODOT start -- +#endif +// -- GODOT end -- + + // -------------------------------------------------------------------------- + // RTLightmapFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // HDR + extern unsigned char rtlightmap_hdr[]; // color + } + + RTLightmapFilter::RTLightmapFilter(const Ref& device) + : AutoencoderFilter(device) + { + weightData.hdr = weights::rtlightmap_hdr; + + hdr = true; + } + + std::shared_ptr RTLightmapFilter::makeTransferFunc() + { + return std::make_shared(); + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.h b/thirdparty/oidn/core/autoencoder.h new file mode 100644 index 000000000000..98b610844ed7 --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.h @@ -0,0 +1,120 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "filter.h" +#include "network.h" +#include "transfer_function.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter - Direct-predicting autoencoder + // -------------------------------------------------------------------------- + + class AutoencoderFilter : public Filter + { + protected: + static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary) + static constexpr int receptiveField = 222; // receptive field in pixels + static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels + + static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage + static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8 + static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16 + + Image color; + Image albedo; + Image normal; + Image output; + bool hdr = false; + float hdrScale = std::numeric_limits::quiet_NaN(); + bool srgb = false; + int maxMemoryMB = 6000; // approximate maximum memory usage in MBs + + int H = 0; // image height + int W = 0; // image width + int tileH = 0; // tile height + int tileW = 0; // tile width + int tileCountH = 1; // number of tiles in H dimension + int tileCountW = 1; // number of tiles in W dimension + + std::shared_ptr net; + std::shared_ptr inputReorder; + std::shared_ptr outputReorder; + + struct + { + void* ldr = nullptr; + void* ldr_alb = nullptr; + void* ldr_alb_nrm = nullptr; + void* hdr = nullptr; + void* hdr_alb = nullptr; + void* hdr_alb_nrm = nullptr; + } weightData; + + explicit AutoencoderFilter(const Ref& device); + virtual std::shared_ptr makeTransferFunc(); + + public: + void setImage(const std::string& name, const Image& data) override; + void set1i(const std::string& name, int value) override; + int get1i(const std::string& name) override; + void set1f(const std::string& name, float value) override; + float get1f(const std::string& name) override; + + void commit() override; + void execute() override; + + private: + void computeTileSize(); + + template + std::shared_ptr buildNet(); + + bool isCommitted() const { return bool(net); } + }; + + // -------------------------------------------------------------------------- + // RTFilter - Generic ray tracing denoiser + // -------------------------------------------------------------------------- + +// -- GODOT start -- +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. 
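+// (The generic RTFilter declaration below is therefore compiled out in this
+// build; only RTLightmapFilter, used for denoising baked lightmaps, stays active.)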
+#if 0 +// -- GODOT end -- + class RTFilter : public AutoencoderFilter + { + public: + explicit RTFilter(const Ref& device); + }; +// -- GODOT start -- +#endif +// -- GODOT end -- + + // -------------------------------------------------------------------------- + // RTLightmapFilter - Ray traced lightmap denoiser + // -------------------------------------------------------------------------- + + class RTLightmapFilter : public AutoencoderFilter + { + public: + explicit RTLightmapFilter(const Ref& device); + std::shared_ptr makeTransferFunc() override; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/buffer.h b/thirdparty/oidn/core/buffer.h new file mode 100644 index 000000000000..b95109152e72 --- /dev/null +++ b/thirdparty/oidn/core/buffer.h @@ -0,0 +1,75 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" + +namespace oidn { + + class Device; + + // Buffer which may or may not own its data + class Buffer : public RefCount + { + private: + char* ptr; + size_t byteSize; + bool shared; + Ref device; + + public: + __forceinline Buffer(const Ref& device, size_t size) + : ptr((char*)alignedMalloc(size, 64)), + byteSize(size), + shared(false), + device(device) {} + + __forceinline Buffer(const Ref& device, void* data, size_t size) + : ptr((char*)data), + byteSize(size), + shared(true), + device(device) + { + if (data == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + } + + __forceinline ~Buffer() + { + if (!shared) + alignedFree(ptr); + } + + __forceinline char* data() { return ptr; } + __forceinline const char* data() const { return ptr; } + __forceinline size_t size() const { return byteSize; } + + void* map(size_t offset, size_t size) + { + if (offset + size > byteSize) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + + return ptr + offset; + } + + void unmap(void* mappedPtr) {} + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h new file mode 100644 index 000000000000..a35dd908b46c --- /dev/null +++ b/thirdparty/oidn/core/common.h @@ -0,0 +1,136 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +#include "mkl-dnn/include/mkldnn.hpp" +#include "mkl-dnn/include/mkldnn_debug.h" +#include "mkl-dnn/src/common/mkldnn_thread.hpp" +#include "mkl-dnn/src/common/type_helpers.hpp" +#include "mkl-dnn/src/cpu/jit_generator.hpp" + +#include "common/ref.h" +#include "common/exception.h" +#include "common/thread.h" +// -- GODOT start -- +//#include "common/tasking.h" +// -- GODOT end -- +#include "math.h" + +namespace oidn { + + using namespace mkldnn; + using namespace mkldnn::impl::cpu; + using mkldnn::impl::parallel_nd; + using mkldnn::impl::memory_desc_matches_tag; + + + inline size_t getFormatBytes(Format format) + { + switch (format) + { + case Format::Undefined: return 1; + case Format::Float: return sizeof(float); + case Format::Float2: return sizeof(float)*2; + case Format::Float3: return sizeof(float)*3; + case Format::Float4: return sizeof(float)*4; + } + assert(0); + return 0; + } + + + inline memory::dims getTensorDims(const std::shared_ptr& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]); + } + + inline memory::data_type getTensorType(const std::shared_ptr& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::data_type(desc.data_type); + } + + // Returns the number of values in a tensor + inline size_t getTensorSize(const memory::dims& dims) + { + size_t res = 1; + for (int i = 0; i < (int)dims.size(); ++i) + res *= dims[i]; + return res; + } + + inline memory::dims getMaxTensorDims(const std::vector& dims) + { + memory::dims result; + size_t maxSize = 0; + + for (const auto& d : dims) + { + const size_t size = getTensorSize(d); + if (size > maxSize) + { + result = d; + maxSize = size; + } + } + + return result; + } + + inline size_t getTensorSize(const std::shared_ptr& mem) + { + return getTensorSize(getTensorDims(mem)); + } + + + template + inline int getPadded(int dim) + { + return (dim + (K-1)) & ~(K-1); + } + + template + inline memory::dims getPadded_nchw(const memory::dims& dims) + { + assert(dims.size() == 4); + memory::dims padDims = dims; + padDims[1] = getPadded(dims[1]); // pad C + return padDims; + } + + + template + struct BlockedFormat; + + template<> + struct BlockedFormat<8> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o; + }; + + template<> + struct BlockedFormat<16> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.cpp b/thirdparty/oidn/core/device.cpp new file mode 100644 index 000000000000..3cd658b9c808 --- /dev/null +++ b/thirdparty/oidn/core/device.cpp @@ -0,0 +1,238 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation 
// +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "device.h" +#include "autoencoder.h" + +namespace oidn { + + thread_local Device::ErrorState Device::globalError; + + Device::Device() + { + if (!mayiuse(sse41)) + throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum"); + } + + Device::~Device() + { + // -- GODOT start -- + //observer.reset(); + // -- GODOT end -- + } + + void Device::setError(Device* device, Error code, const std::string& message) + { + // Update the stored error only if the previous error was queried + if (device) + { + ErrorState& curError = device->error.get(); + + if (curError.code == Error::None) + { + curError.code = code; + curError.message = message; + } + + // Print the error message in verbose mode + if (device->isVerbose()) + std::cerr << "Error: " << message << std::endl; + + // Call the error callback function + ErrorFunction errorFunc; + void* errorUserPtr; + + { + std::lock_guard lock(device->mutex); + errorFunc = device->errorFunc; + errorUserPtr = device->errorUserPtr; + } + + if (errorFunc) + errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str()); + } + else + { + if (globalError.code == Error::None) + { + globalError.code = code; + globalError.message = message; + } + } + } + + Error Device::getError(Device* device, const char** outMessage) + { + // Return and clear the stored error code, but keep the error message so pointers to it will + // remain valid until the next getError call + if (device) + { + ErrorState& curError = device->error.get(); + const Error code = curError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : curError.message.c_str(); + curError.code = Error::None; + return code; + } + else + { + const Error code = globalError.code; + if (outMessage) + *outMessage = (code == Error::None) ? 
nullptr : globalError.message.c_str(); + globalError.code = Error::None; + return code; + } + } + + void Device::setErrorFunction(ErrorFunction func, void* userPtr) + { + errorFunc = func; + errorUserPtr = userPtr; + } + + int Device::get1i(const std::string& name) + { + if (name == "numThreads") + return numThreads; + else if (name == "setAffinity") + return setAffinity; + else if (name == "verbose") + return verbose; + else if (name == "version") + return OIDN_VERSION; + else if (name == "versionMajor") + return OIDN_VERSION_MAJOR; + else if (name == "versionMinor") + return OIDN_VERSION_MINOR; + else if (name == "versionPatch") + return OIDN_VERSION_PATCH; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void Device::set1i(const std::string& name, int value) + { + if (name == "numThreads") + numThreads = value; + else if (name == "setAffinity") + setAffinity = value; + else if (name == "verbose") + { + verbose = value; + error.verbose = value; + } + + dirty = true; + } + + void Device::commit() + { + if (isCommitted()) + throw Exception(Error::InvalidOperation, "device can be committed only once"); + + // -- GODOT start -- + #if 0 + // -- GODOT end -- + // Get the optimal thread affinities + if (setAffinity) + { + affinity = std::make_shared(1, verbose); // one thread per core + if (affinity->getNumThreads() == 0) + affinity.reset(); + } + + // Create the task arena + const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency(); + numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads; + arena = std::make_shared(numThreads); + + // Automatically set the thread affinities + if (affinity) + observer = std::make_shared(affinity, *arena); + // -- GODOT start -- + #endif + numThreads = 1; + // -- GODOT end -- + dirty = false; + + if (isVerbose()) + print(); + } + + void Device::checkCommitted() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the device are not committed"); + } + + Ref Device::newBuffer(size_t byteSize) + { + checkCommitted(); + return makeRef(Ref(this), byteSize); + } + + Ref Device::newBuffer(void* ptr, size_t byteSize) + { + checkCommitted(); + return makeRef(Ref(this), ptr, byteSize); + } + + Ref Device::newFilter(const std::string& type) + { + checkCommitted(); + + if (isVerbose()) + std::cout << "Filter: " << type << std::endl; + + Ref filter; + +// -- GODOT start -- +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 +// -- GODOT end -- + if (type == "RT") + filter = makeRef(Ref(this)); +// -- GODOT start -- +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#endif + if (type == "RTLightmap") +// -- GODOT end -- + filter = makeRef(Ref(this)); + else + throw Exception(Error::InvalidArgument, "unknown filter type"); + + return filter; + } + + void Device::print() + { + std::cout << std::endl; + + std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl; + std::cout << " Compiler: " << getCompilerName() << std::endl; + std::cout << " Build : " << getBuildName() << std::endl; + std::cout << " Platform: " << getPlatformName() << std::endl; + +// -- GODOT start -- +// std::cout << " Tasking :"; +// std::cout << " TBB" << TBB_VERSION_MAJOR << "." 
<< TBB_VERSION_MINOR; +// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version(); +// std::cout << std::endl; +// -- GODOT end -- + std::cout << std::endl; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.h b/thirdparty/oidn/core/device.h new file mode 100644 index 000000000000..d9cfd8541abf --- /dev/null +++ b/thirdparty/oidn/core/device.h @@ -0,0 +1,102 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" + +namespace oidn { + + class Buffer; + class Filter; + + class Device : public RefCount, public Verbose + { + private: + // Thread-safety + std::mutex mutex; + + // Error handling + struct ErrorState + { + Error code = Error::None; + std::string message; + }; + + static thread_local ErrorState globalError; + ThreadLocal error; + ErrorFunction errorFunc = nullptr; + void* errorUserPtr = nullptr; + +// -- GODOT start -- +// // Tasking +// std::shared_ptr arena; +// std::shared_ptr observer; +// std::shared_ptr affinity; +// -- GODOT end -- + + // Parameters + int numThreads = 0; // autodetect by default + bool setAffinity = true; + + bool dirty = true; + + public: + Device(); + ~Device(); + + static void setError(Device* device, Error code, const std::string& message); + static Error getError(Device* device, const char** outMessage); + + void setErrorFunction(ErrorFunction func, void* userPtr); + + int get1i(const std::string& name); + void set1i(const std::string& name, int value); + + void commit(); + +// -- GODOT start -- +// template +// void executeTask(F& f) +// { +// arena->execute(f); +// } + +// template +// void executeTask(const F& f) +// { +// arena->execute(f); +// } +// -- GODOT end -- + + Ref newBuffer(size_t byteSize); + Ref newBuffer(void* ptr, size_t byteSize); + Ref newFilter(const std::string& type); + + __forceinline Device* getDevice() { return this; } + __forceinline std::mutex& getMutex() { return mutex; } + + private: +// -- GODOT start -- + //bool isCommitted() const { return bool(arena); } + bool isCommitted() const { return false; } +// -- GODOT end -- + void checkCommitted(); + + void print(); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.cpp b/thirdparty/oidn/core/filter.cpp new file mode 100644 index 000000000000..ec1f10af871f --- /dev/null +++ b/thirdparty/oidn/core/filter.cpp @@ -0,0 +1,27 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "filter.h" + +namespace oidn { + + void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr) + { + progressFunc = func; + progressUserPtr = userPtr; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.h b/thirdparty/oidn/core/filter.h new file mode 100644 index 000000000000..935fa202f45c --- /dev/null +++ b/thirdparty/oidn/core/filter.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" +#include "image.h" + +namespace oidn { + + class Filter : public RefCount + { + protected: + Ref device; + + ProgressMonitorFunction progressFunc = nullptr; + void* progressUserPtr = nullptr; + + bool dirty = true; + + public: + explicit Filter(const Ref& device) : device(device) {} + + virtual void setImage(const std::string& name, const Image& data) = 0; + virtual void set1i(const std::string& name, int value) = 0; + virtual int get1i(const std::string& name) = 0; + virtual void set1f(const std::string& name, float value) = 0; + virtual float get1f(const std::string& name) = 0; + + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr); + + virtual void commit() = 0; + virtual void execute() = 0; + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/image.h b/thirdparty/oidn/core/image.h new file mode 100644 index 000000000000..748f49c4e5c4 --- /dev/null +++ b/thirdparty/oidn/core/image.h @@ -0,0 +1,111 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
// +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "buffer.h" + +namespace oidn { + + struct Image + { + static constexpr int maxSize = 65536; + + char* ptr; // pointer to the first pixel + int width; // width in number of pixels + int height; // height in number of pixels + size_t bytePixelStride; // pixel stride in number of *bytes* + size_t rowStride; // row stride in number of *pixel strides* + Format format; // pixel format + Ref buffer; // buffer containing the image data + + Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {} + + Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + if (ptr == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + + init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + } + + Image(const Ref& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + + if (byteOffset + height * rowStride * bytePixelStride > buffer->size()) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + } + + void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride) + { + assert(width >= 0); + assert(height >= 0); + if (width > maxSize || height > maxSize) + throw Exception(Error::InvalidArgument, "image size too large"); + + this->ptr = ptr; + this->width = width; + this->height = height; + + const size_t pixelSize = getFormatBytes(format); + if (inBytePixelStride != 0) + { + if (inBytePixelStride < pixelSize) + throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size"); + + this->bytePixelStride = inBytePixelStride; + } + else + { + this->bytePixelStride = pixelSize; + } + + if (inByteRowStride != 0) + { + if (inByteRowStride < width * this->bytePixelStride) + throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride"); + if (inByteRowStride % this->bytePixelStride != 0) + throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride"); + + this->rowStride = inByteRowStride / this->bytePixelStride; + } + else + { + this->rowStride = width; + } + + this->format = format; + } + + __forceinline char* get(int y, int x) + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + __forceinline const char* get(int y, int x) const + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + operator bool() const + { + return ptr != nullptr; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/input_reorder.h b/thirdparty/oidn/core/input_reorder.h new file mode 100644 index 000000000000..966856afe9e5 --- /dev/null +++ b/thirdparty/oidn/core/input_reorder.h @@ -0,0 +1,232 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Input reorder node + template + class InputReorderNode : public Node + { + private: + // Source + Image color; + Image albedo; + Image normal; + + // Destination + std::shared_ptr dst; + float* dstPtr; + int C2; + int H2; + int W2; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr transferFunc; + + public: + InputReorderNode(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& dst, + const std::shared_ptr& transferFunc) + : color(color), albedo(albedo), normal(normal), + dst(dst), + h1Begin(0), w1Begin(0), + H(color.height), W(color.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(dstDesc.ndims == 4); + assert(dstDesc.data_type == memory::data_type::f32); + assert(dstDesc.dims[0] == 1); + //assert(dstDesc.dims[1] >= getPadded(C1)); + + dstPtr = (float*)dst->get_data_handle(); + C2 = dstDesc.dims[1]; + H2 = dstDesc.dims[2]; + W2 = dstDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(H + h1Begin <= color.height); + assert(W + w1Begin <= color.width); + assert(H + h2Begin <= H2); + assert(W + w2Begin <= W2); + + parallel_nd(H2, [&](int h2) + { + const int h = h2 - h2Begin; + + if (h >= 0 && h < H) + { + const int h1 = h + h1Begin; + + // Zero pad + for (int w2 = 0; w2 < w2Begin; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Reorder + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + + int c = 0; + storeColor(h2, w2, c, (float*)color.get(h1, w1)); + if (albedo) + storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1)); + if (normal) + storeNormal(h2, w2, c, (float*)normal.get(h1, w1)); + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Zero pad + for (int w2 = W + w2Begin; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + else + { + // Zero pad + for (int w2 = 0; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + }); + } + + std::shared_ptr getDst() const override { return dst; } + + private: + // Stores a single value + __forceinline void store(int h, int w, int& c, float value) + { + // Destination is in nChwKc format + float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K); + *dst_c = value; + c++; + } + + // Stores a color + __forceinline void storeColor(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = maxSafe(x, 0.f); + + // Apply the transfer function + x = transferFunc->forward(x); 
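+ // forward() compresses the (possibly HDR) value into the domain the
+ // network was trained on; the matching inverse() is applied later by
+ // the output reorder node to restore the original range.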
+ + // Store the value + store(h, w, c, x); + } + } + + // Stores an albedo + __forceinline void storeAlbedo(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = clampSafe(x, 0.f, 1.f); + + // Store the value + store(h, w, c, x); + } + } + + // Stores a normal + __forceinline void storeNormal(int h, int w, int& c, const float* values) + { + // Load the normal + float x = values[0]; + float y = values[1]; + float z = values[2]; + + // Compute the length of the normal + const float lengthSqr = sqr(x) + sqr(y) + sqr(z); + + // Normalize the normal and transform it to [0..1] + if (isfinite(lengthSqr)) + { + const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f; + + const float scale = invLength * 0.5f; + const float offset = 0.5f; + + x = x * scale + offset; + y = y * scale + offset; + z = z * scale + offset; + } + else + { + x = 0.f; + y = 0.f; + z = 0.f; + } + + // Store the normal + store(h, w, c, x); + store(h, w, c, y); + store(h, w, c, z); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/math.h b/thirdparty/oidn/core/math.h new file mode 100644 index 000000000000..a844ef0d1dc0 --- /dev/null +++ b/thirdparty/oidn/core/math.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +namespace oidn { + + constexpr float minVectorLength = 1e-10f; + constexpr float minVectorLengthSqr = minVectorLength * minVectorLength; + + using std::log; + using std::log2; + using std::exp; + using std::exp2; + using std::pow; + using std::isfinite; + using std::isnan; + + __forceinline float sqr(float x) + { + return x * x; + } + + __forceinline float rcp(float x) + { + __m128 r = _mm_rcp_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x)))); + } + + __forceinline float rsqrt(float x) + { + __m128 r = _mm_rsqrt_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r), + _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r)))); + } + + __forceinline float maxSafe(float value, float minValue) + { + return isfinite(value) ? max(value, minValue) : minValue; + } + + __forceinline float clampSafe(float value, float minValue, float maxValue) + { + return isfinite(value) ? 
clamp(value, minValue, maxValue) : minValue; + } + + // Returns ceil(a / b) for non-negative integers + template + __forceinline constexpr Int ceilDiv(Int a, Int b) + { + //assert(a >= 0); + //assert(b > 0); + return (a + b - 1) / b; + } + + // Returns a rounded up to multiple of b + template + __forceinline constexpr Int roundUp(Int a, Int b) + { + return ceilDiv(a, b) * b; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.cpp b/thirdparty/oidn/core/network.cpp new file mode 100644 index 000000000000..ed8328c954dc --- /dev/null +++ b/thirdparty/oidn/core/network.cpp @@ -0,0 +1,436 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "upsample.h" +#include "weights_reorder.h" +#include "network.h" +// -- GODOT start -- +#include +// -- GODOT end -- + +namespace oidn { + + template + Network::Network(const Ref& device, const std::map& weightMap) + : device(device), + eng(engine::cpu, 0), + sm(eng), + weightMap(weightMap) + { + } + + template + void Network::execute(const Progress& progress, int taskIndex) + { + if (progress.func) + { + const double value = double(taskIndex) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + + for (size_t i = 0; i < nodes.size(); ++i) + { + nodes[i]->execute(sm); + + if (progress.func) + { + const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + } + } + + template + std::shared_ptr Network::allocTensor(const memory::dims& dims, + memory::format_tag format, + void* data) + { + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + if (data == nullptr) + { + const size_t bytes = getTensorSize(dims) * sizeof(float); + if (format == BlockedFormat::nChwKc) + activationAllocBytes += bytes; + totalAllocBytes += bytes; + + return std::make_shared(desc, eng); + } + else + { + return std::make_shared(desc, eng, data); + } + } + + template + std::shared_ptr Network::castTensor(const memory::dims& dims, + const std::shared_ptr& src, + size_t srcOffset, + memory::format_tag format) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(srcDesc.data_type == memory::data_type::f32); + assert(getTensorSize(src) >= srcOffset + getTensorSize(dims)); + + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat::nChwKc; + else if (dims.size() == 1) + format = 
memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + float* srcPtr = (float*)src->get_data_handle() + srcOffset; + return std::make_shared(desc, eng, srcPtr); + } + + template + std::shared_ptr Network::castTensor(const memory::dims& dims, + const std::shared_ptr& src, + const memory::dims& srcOffset) + { + return castTensor(dims, src, getTensorSize(srcOffset)); + } + + template + void Network::zeroTensor(const std::shared_ptr& dst) + { + assert(getTensorType(dst) == memory::data_type::f32); + memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float)); + } + + template + memory::dims Network::getInputReorderDims(const memory::dims& srcDims, int alignment) + { + memory::dims dstDims = srcDims; + dstDims[1] = getPadded(srcDims[1]); // round up C + dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H + dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W + return dstDims; + } + + template + std::shared_ptr Network::addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& transferFunc, + int alignment, + const std::shared_ptr& userDst) + { + assert(color); + int inputC = 3; + if (albedo) inputC += 3; + if (normal) inputC += 3; + + memory::dims srcDims = {1, inputC, color.height, color.width}; + memory::dims dstDims = getInputReorderDims(srcDims, alignment); + + // Allocate padded memory + auto dst = userDst; + if (!dst) + dst = allocTensor(dstDims); + + // Push node + std::shared_ptr node; + + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template + std::shared_ptr Network::addOutputReorder(const std::shared_ptr& src, + const std::shared_ptr& transferFunc, + const Image& output) + { + memory::dims srcDims = getTensorDims(src); + assert(srcDims[1] == K); + + // Push node + std::shared_ptr node; + + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getConvDims(const std::string& name, const memory::dims& srcDims) + { + auto b = weightMap[name + "/b"]; + memory::dims dstDims = srcDims; + dstDims[1] = getPadded(b.dims[0]); // dstDims[C] = getPadded(OC) + return dstDims; + } + + template + std::shared_ptr Network::addConv(const std::string& name, + const std::shared_ptr& src, + const std::shared_ptr& userDst, + bool relu) + { + const memory::dims strides = {1, 1}; + const memory::dims padding = {1, 1}; + + memory::dims srcDims = getTensorDims(src); + + // Get the weights + const auto& W = weightMap[name + "/W"]; + if (W.ndims() != 4 || W.format != "oihw") + throw 
Exception(Error::InvalidOperation, "invalid convolution weights"); + memory::dims weightsDims = W.dims; + auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data); + + // Pad the weights + memory::dims weightsPadDims = weightsDims; + weightsPadDims[1] = getPadded(weightsDims[1]); // IC + weightsPadDims[0] = getPadded(weightsDims[0]); // OC + assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC] + auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw); + WeightsReorderNode(userWeights, weightsPad).execute(sm); + + // Get the biases + const auto& b = weightMap[name + "/b"]; + if (b.ndims() != 1) + throw Exception(Error::InvalidOperation, "invalid convolution biases"); + memory::dims biasDims = b.dims; + + // Copy/pad the biases + memory::dims biasPadDims = {getPadded(biasDims[0])}; + auto bias = allocTensor(biasPadDims); + if (biasDims[0] != biasPadDims[0]) + memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float)); + memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float)); + + // Allocate memory for destination + memory::dims dstDims = srcDims; + dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC] + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create a convolution + // Let the convolution primitive choose the weights format + auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any); + + auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct; + auto convDesc = convolution_forward::desc( + prop_kind::forward_inference, convAlgo, + src->get_desc(), + weightsDesc, + bias->get_desc(), + dst->get_desc(), + strides, padding, padding, padding_kind::zero); + + // Incorporate relu + mkldnn::primitive_attr convAttr; + if (relu) + { + mkldnn::post_ops ops; + ops.append_eltwise( + 1.f, // scale factor, not used + algorithm::eltwise_relu, + 0.f, // max with + 0.f // unused + ); + convAttr.set_post_ops(ops); + } + convAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng); + + // Reorder the weights to the final format, if necessary + auto weights = weightsPad; + if (convPrimDesc.weights_desc() != weightsPad->get_desc()) + { + weights = std::make_shared(convPrimDesc.weights_desc(), eng); + ReorderNode(weightsPad, weights).execute(sm); + } + + // Create convolution node and add it to the net + auto node = std::make_shared(convPrimDesc, src, weights, bias, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getPoolDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] /= 2; // H/2 + dstDims[3] /= 2; // W/2 + return dstDims; + } + + template + std::shared_ptr Network::addPool(const std::shared_ptr& src, + const std::shared_ptr& userDst) + { + const memory::dims kernel = {2, 2}; + const memory::dims strides = {2, 2}; + const memory::dims padding = {0, 0}; + + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getPoolDims(srcDims); + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + auto poolDesc = pooling_forward::desc( + prop_kind::forward_inference, pooling_max, + src->get_desc(), + dst->get_desc(), + strides, kernel, padding, 
padding, padding_kind::zero); + + mkldnn::primitive_attr poolAttr; + poolAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng); + + auto node = std::make_shared(poolPrimDesc, src, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getUpsampleDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] *= 2; // H*2 + dstDims[3] *= 2; // W*2 + return dstDims; + } + + template + std::shared_ptr Network::addUpsample(const std::shared_ptr& src, + const std::shared_ptr& userDst) + { + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getUpsampleDims(srcDims); + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create upsampling node and add it to net + auto node = std::make_shared>(src, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims) + { + assert(src1Dims[0] == src2Dims[0]); // N + assert(src1Dims[2] == src2Dims[2]); // H + assert(src1Dims[3] == src2Dims[3]); // W + + memory::dims dstDims = src1Dims; + dstDims[1] += src2Dims[1]; // C + return dstDims; + } + + template + std::shared_ptr Network::addAutoexposure(const Image& color, + const std::shared_ptr& transferFunc) + { + auto node = std::make_shared(color, transferFunc); + nodes.push_back(node); + return node; + } + + template + void Network::finalize() + { + // Compute the size of the scratchpad + size_t scratchpadSize = 0; + for (const auto& node : nodes) + scratchpadSize = max(scratchpadSize, node->getScratchpadSize()); + + // Allocate the scratchpad + memory::dims scratchpadDims = { memory::dim(scratchpadSize) }; + memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x); + auto scratchpad = std::make_shared(scratchpadDesc, eng); + activationAllocBytes += scratchpadSize; + totalAllocBytes += scratchpadSize; + + // Set the scratchpad for the nodes + for (auto& node : nodes) + node->setScratchpad(scratchpad); + + // Free the weights + weightMap.clear(); + + // Print statistics + if (device->isVerbose(2)) + { + std::cout << "Activation bytes: " << activationAllocBytes << std::endl; + std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl; + std::cout << "Total bytes : " << totalAllocBytes << std::endl; + } + } + + template class Network<8>; + template class Network<16>; + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h new file mode 100644 index 000000000000..7a696fd3550d --- /dev/null +++ b/thirdparty/oidn/core/network.h @@ -0,0 +1,112 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
// +// ======================================================================== // + +#include "common/tensor.h" +#include "image.h" +#include "node.h" +#include "input_reorder.h" +#include "output_reorder.h" +#include "transfer_function.h" + +#pragma once + +namespace oidn { + + // Progress state + struct Progress + { + ProgressMonitorFunction func; + void* userPtr; + int taskCount; + }; + + class Executable + { + public: + virtual ~Executable() {} + virtual void execute(const Progress& progress, int taskIndex) = 0; + }; + + template + class Network : public Executable + { + public: + Network(const Ref& device, const std::map& weightMap); + + void execute(const Progress& progress, int taskIndex) override; + + std::shared_ptr allocTensor(const memory::dims& dims, + memory::format_tag format = memory::format_tag::any, + void* data = nullptr); + + std::shared_ptr castTensor(const memory::dims& dims, + const std::shared_ptr& src, + size_t srcOffset = 0, + memory::format_tag format = memory::format_tag::any); + + std::shared_ptr castTensor(const memory::dims& dims, + const std::shared_ptr& src, + const memory::dims& srcOffset); + + void zeroTensor(const std::shared_ptr& dst); + + memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); + + std::shared_ptr addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& transferFunc, + int alignment, + const std::shared_ptr& userDst = nullptr); + + std::shared_ptr addOutputReorder(const std::shared_ptr& src, + const std::shared_ptr& transferFunc, + const Image& output); + + memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); + std::shared_ptr addConv(const std::string& name, + const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr, + bool relu = true); + + memory::dims getPoolDims(const memory::dims& srcDims); + std::shared_ptr addPool(const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr); + + memory::dims getUpsampleDims(const memory::dims& srcDims); + std::shared_ptr addUpsample(const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr); + + memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); + + std::shared_ptr addAutoexposure(const Image& color, + const std::shared_ptr& transferFunc); + + void finalize(); + + private: + Ref device; + engine eng; + stream sm; + std::vector> nodes; + std::map weightMap; + + // Memory allocation statistics + size_t activationAllocBytes = 0; // number of allocated activation bytes + size_t totalAllocBytes = 0; // total number of allocated bytes + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h new file mode 100644 index 000000000000..b9ffe906dff5 --- /dev/null +++ b/thirdparty/oidn/core/node.h @@ -0,0 +1,142 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include + +namespace oidn { + + class Node + { + public: + virtual ~Node() = default; + + virtual void execute(stream& sm) = 0; + + virtual std::shared_ptr getDst() const { return nullptr; } + + virtual size_t getScratchpadSize() const { return 0; } + virtual void setScratchpad(const std::shared_ptr& mem) {} + + virtual void setTile(int h1, int w1, int h2, int w2, int H, int W) + { + assert(0); // not supported + } + }; + + // Node wrapping an MKL-DNN primitive + class MklNode : public Node + { + private: + primitive prim; + std::unordered_map args; + std::shared_ptr scratchpad; + + public: + MklNode(const primitive& prim, const std::unordered_map& args) + : prim(prim), + args(args) + {} + + size_t getScratchpadSize() const override + { + const auto primDesc = prim.get_primitive_desc(); + const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0); + if (scratchpadDesc == nullptr) + return 0; + return mkldnn_memory_desc_get_size(scratchpadDesc); + } + + void setScratchpad(const std::shared_ptr& mem) override + { + scratchpad = mem; + args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad)); + } + + void execute(stream& sm) override + { + prim.execute(sm, args); + } + }; + + // Convolution node + class ConvNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr weights; + std::shared_ptr bias; + std::shared_ptr dst; + + public: + ConvNode(const convolution_forward::primitive_desc& desc, + const std::shared_ptr& src, + const std::shared_ptr& weights, + const std::shared_ptr& bias, + const std::shared_ptr& dst) + : MklNode(convolution_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_WEIGHTS, *weights }, + { MKLDNN_ARG_BIAS, *bias }, + { MKLDNN_ARG_DST, *dst } }), + src(src), weights(weights), bias(bias), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + + // Pooling node + class PoolNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + PoolNode(const pooling_forward::primitive_desc& desc, + const std::shared_ptr& src, + const std::shared_ptr& dst) + : MklNode(pooling_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + + // Reorder node + class ReorderNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + ReorderNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : MklNode(reorder(reorder::primitive_desc(*src, *dst)), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/output_reorder.h b/thirdparty/oidn/core/output_reorder.h new file mode 100644 index 000000000000..7918d48e1508 --- /dev/null +++ b/thirdparty/oidn/core/output_reorder.h @@ -0,0 +1,126 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Output reorder node + template + class OutputReorderNode : public Node + { + private: + // Source + std::shared_ptr src; + const float* srcPtr; + int H1; + int W1; + + // Destination + Image output; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr transferFunc; + + public: + OutputReorderNode(const std::shared_ptr& src, + const Image& output, + const std::shared_ptr& transferFunc) + : src(src), + output(output), + h1Begin(0), w1Begin(0), + h2Begin(0), w2Begin(0), + H(output.height), W(output.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(srcDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + // We assume output data is <= K OC + assert(srcDesc.dims[1] == K); + + srcPtr = (float*)src->get_data_handle(); + H1 = srcDesc.dims[2]; + W1 = srcDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(h1Begin + H <= H1); + assert(w1Begin + W <= W1); + assert(h2Begin + H <= output.height); + assert(w2Begin + W <= output.width); + + const int C1 = K; + + parallel_nd(H, [&](int h) + { + const int h1 = h + h1Begin; + const int h2 = h + h2Begin; + + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + float* dstPtr_C = (float*)output.get(h2, w2); + + // Source is in nChwKc format. In this case C is 1 so this is really nhwc + const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1; + + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = srcPtr_C[i]; + + // The CNN output may contain negative values or even NaNs, so it must be sanitized + x = maxSafe(x, 0.f); + + // Apply the inverse transfer function + x = transferFunc->inverse(x); + + // Sanitize and store the final value + dstPtr_C[i] = max(x, 0.f); + } + } + }); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.cpp b/thirdparty/oidn/core/transfer_function.cpp new file mode 100644 index 000000000000..ce5deca56bc7 --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.cpp @@ -0,0 +1,103 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "transfer_function.h" + +namespace oidn { + + const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f); + const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale); + + float AutoexposureNode::autoexposure(const Image& color) + { + assert(color.format == Format::Float3); + + constexpr float key = 0.18f; + constexpr float eps = 1e-8f; + constexpr int K = 16; // downsampling amount + + // Downsample the image to minimize sensitivity to noise + const int H = color.height; // original height + const int W = color.width; // original width + const int HK = (H + K/2) / K; // downsampled height + const int WK = (W + K/2) / K; // downsampled width + + // Compute the average log luminance of the downsampled image + using Sum = std::pair; + + // -- GODOT start -- + // Sum sum = + // tbb::parallel_reduce( + // tbb::blocked_range2d(0, HK, 0, WK), + // Sum(0.f, 0), + // [&](const tbb::blocked_range2d& r, Sum sum) -> Sum + // { + // // Iterate over blocks + // for (int i = r.rows().begin(); i != r.rows().end(); ++i) + // { + // for (int j = r.cols().begin(); j != r.cols().end(); ++j) + // { + + Sum sum = Sum(0.0f, 0); + + for (int i = 0; i != HK; ++i) + { + for (int j = 0; j != WK; ++j) + { + // Compute the average luminance in the current block + const int beginH = int(ptrdiff_t(i) * H / HK); + const int beginW = int(ptrdiff_t(j) * W / WK); + const int endH = int(ptrdiff_t(i+1) * H / HK); + const int endW = int(ptrdiff_t(j+1) * W / WK); + + float L = 0.f; + + for (int h = beginH; h < endH; ++h) + { + for (int w = beginW; w < endW; ++w) + { + const float* rgb = (const float*)color.get(h, w); + + const float r = maxSafe(rgb[0], 0.f); + const float g = maxSafe(rgb[1], 0.f); + const float b = maxSafe(rgb[2], 0.f); + + L += luminance(r, g, b); + } + } + + L /= (endH - beginH) * (endW - beginW); + + // Accumulate the log luminance + if (L > eps) + { + sum.first += log2(L); + sum.second++; + } + } + } + + // return sum; + // }, + // [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, + // tbb::static_partitioner() + // ); + // -- GODOT end -- + + return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.h b/thirdparty/oidn/core/transfer_function.h new file mode 100644 index 000000000000..35f28330925c --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.h @@ -0,0 +1,201 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. 
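The autoexposure routine above boils down to scaling the image so that its geometric-mean luminance lands on a mid-grey key of 0.18; the 16x block downsampling only makes that mean less sensitive to noise. A minimal, non-downsampled sketch of the same computation (the helper names and the packed-float3 layout are illustrative, not taken from this patch):

#include <cmath>
#include <cstddef>

// Rec. 709 luminance, matching the luminance() helper in transfer_function.h.
static inline float lum(float r, float g, float b)
{
    return 0.212671f * r + 0.715160f * g + 0.072169f * b;
}

// 'rgb' is a hypothetical packed float3 buffer holding 'count' pixels.
float autoexposureSketch(const float* rgb, size_t count)
{
    const float key = 0.18f;   // target mid-grey
    const float eps = 1e-8f;   // ignore (near-)black pixels
    double sumLog2 = 0.0;
    size_t n = 0;
    for (size_t i = 0; i < count; ++i)
    {
        const float L = lum(rgb[3*i], rgb[3*i+1], rgb[3*i+2]);
        if (L > eps) { sumLog2 += std::log2(L); ++n; }
    }
    // key divided by the geometric-mean luminance; 1 for an all-black image.
    return (n > 0) ? key / std::exp2(float(sumLog2 / double(n))) : 1.f;
}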
// +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "image.h" +#include "node.h" + +namespace oidn { + + __forceinline float luminance(float r, float g, float b) + { + return 0.212671f * r + 0.715160f * g + 0.072169f * b; + } + + // Color transfer function base class + class TransferFunction + { + public: + virtual ~TransferFunction() = default; + + virtual float forward(float y) const = 0; + virtual float inverse(float x) const = 0; + }; + + // HDR transfer function base class + class HDRTransferFunction : public TransferFunction + { + protected: + static constexpr float yMax = 65504.f; + + float exposure; + float rcpExposure; + + public: + HDRTransferFunction(float exposure = 1.f) + { + setExposure(exposure); + } + + void setExposure(float exposure) + { + this->exposure = exposure; + this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f; + } + }; + + // Linear transfer function (LDR) + class LinearTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(y, 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(x, 1.f); + } + }; + + // 2.2 gamma transfer function (LDR) + class GammaTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(pow(y, 1.f/2.2f), 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(pow(x, 2.2f), 1.f); + } + }; + + // Logarithmic transfer function (HDR) + // Compresses [0..65504] to [0..1] + class LogTransferFunction : public HDRTransferFunction + { + private: + static const float xScale; + + public: + LogTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return log(y * exposure + 1.f) * xScale; + } + + __forceinline float inverse(float x) const override + { + return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure; + } + }; + + // PQX transfer function (HDR) + // Compresses [0..65504] to [0..1] + class PQXTransferFunction : public HDRTransferFunction + { + private: + static constexpr float m1 = 2610.f / 4096.f / 4.f; + static constexpr float m2 = 2523.f / 4096.f * 128.f; + static constexpr float c1 = 3424.f / 4096.f; + static constexpr float c2 = 2413.f / 4096.f * 32.f; + static constexpr float c3 = 2392.f / 4096.f * 32.f; + static constexpr float a = 3711.f / 4096.f / 8.f; + + static constexpr float yScale = 100.f / 10000.f; + static const float xScale; + + public: + PQXTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return pqxForward(y * exposure * yScale) * xScale; + } + + __forceinline float inverse(float x) const override + { + return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure; + } + + private: + static __forceinline float pqForward(float y) + { + const float yp = pow(y, m1); + return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2); + } + + static __forceinline 
float pqxForward(float y) + { + if (y <= 1.f) + return pqForward(y); + else + return a * log(y) + 1.f; + } + + static __forceinline float pqInverse(float x) + { + const float xp = pow(x, 1.f/m2); + return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1); + } + + static __forceinline float pqxInverse(float x) + { + if (x <= 1.f) + return pqInverse(x); + else + return exp((x - 1.f) * (1.f/a)); + } + }; + + // Autoexposure node + class AutoexposureNode : public Node + { + private: + Image color; + std::shared_ptr transferFunc; + + public: + AutoexposureNode(const Image& color, + const std::shared_ptr& transferFunc) + : color(color), + transferFunc(transferFunc) + {} + + void execute(stream& sm) override + { + const float exposure = autoexposure(color); + //printf("exposure = %f\n", exposure); + transferFunc->setExposure(exposure); + } + + private: + static float autoexposure(const Image& color); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/upsample.h b/thirdparty/oidn/core/upsample.h new file mode 100644 index 000000000000..f6cace44cd2d --- /dev/null +++ b/thirdparty/oidn/core/upsample.h @@ -0,0 +1,92 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
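A small sketch of what the LogTransferFunction above guarantees, assuming exposure = 1: with xScale = 1 / log(yMax + 1), forward() maps [0, 65504] onto [0, 1] and inverse() undoes it exactly up to floating-point error (the free-function names here are illustrative):

#include <cassert>
#include <cmath>

static const float yMax   = 65504.f;                     // largest half-float value
static const float xScale = 1.f / std::log(yMax + 1.f);

float logForward(float y) { return std::log(y + 1.f) * xScale; }   // HDR -> [0, 1]
float logInverse(float x) { return std::exp(x / xScale) - 1.f; }   // [0, 1] -> HDR

int main()
{
    assert(std::fabs(logForward(yMax) - 1.f) < 1e-6f);          // top of the range maps to 1
    const float y = 123.456f;
    assert(std::fabs(logInverse(logForward(y)) - y) < 1e-2f);   // round trip recovers y
    return 0;
}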
// +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // 2x2 nearest-neighbor upsampling node + template + class UpsampleNode : public Node + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + UpsampleNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + assert(dstDesc.dims[0] == 1); + // 2x2 upsampling + assert(dstDesc.dims[2] == srcDesc.dims[2] * 2); + assert(dstDesc.dims[3] == srcDesc.dims[3] * 2); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int C = srcDesc.dims[1]; + const int H = srcDesc.dims[2]; + const int W = srcDesc.dims[3]; + const int CK = C / K; + + parallel_nd(CK, H, [&](int ck, int h) + { + const size_t offset = ck*H*W*K + h*W*K; + const float* srcPtr_line = srcPtr + offset; + float* dstPtr_line0 = dstPtr + offset * 4; + float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line + + for (int w = 0; w < W; ++w) + { + #pragma unroll + for (int k = 0; k < K; k += 4) + { + const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]); + + _mm_stream_ps(&dstPtr_line0[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m); + } + } + }); + } + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/weights_reorder.h b/thirdparty/oidn/core/weights_reorder.h new file mode 100644 index 000000000000..6c5dacb8aaa2 --- /dev/null +++ b/thirdparty/oidn/core/weights_reorder.h @@ -0,0 +1,99 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
// +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // Reorders weights from oihw to padded oihw format + template + class WeightsReorderNode : public Node + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + WeightsReorderNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(getPadded(srcDesc.dims[0]) == dstDesc.dims[0]); // OC + assert(getPadded(srcDesc.dims[1]) == dstDesc.dims[1]); // IC + assert(srcDesc.dims[2] == dstDesc.dims[2]); + assert(srcDesc.dims[3] == dstDesc.dims[3]); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int OC1 = srcDesc.dims[0]; + const int OC2 = dstDesc.dims[0]; + const int IC1 = srcDesc.dims[1]; + const int IC2 = dstDesc.dims[1]; + const int H = dstDesc.dims[2]; + const int W = dstDesc.dims[3]; + + for (int oc = 0; oc < OC2; ++oc) + { + for (int ic = 0; ic < IC2; ++ic) + { + for (int h = 0; h < H; ++h) + { + for (int w = 0; w < W; ++w) + { + // Output is in oihw format + float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w; + + if (oc < OC1 && ic < IC1) + { + // Input is in oihw format + const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w; + *dstPtr_c = *srcPtr_c; + } + else + { + // padding + *dstPtr_c = 0; + } + } + } + } + } + } + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.h b/thirdparty/oidn/include/OpenImageDenoise/oidn.h new file mode 100644 index 000000000000..57ba6baa213e --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.h @@ -0,0 +1,214 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
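A sketch of what the padded-oihw weights reorder above amounts to. getPadded() itself is not part of this excerpt; it is assumed here to round a channel count up to the next multiple of the block size K, which is consistent with the asserts and zero-padding in WeightsReorderNode:

#include <cstring>
#include <vector>

// Assumed behaviour of getPadded<K>(): round a channel count up to a multiple of K.
template <int K>
int getPaddedSketch(int dim) { return (dim + K - 1) / K * K; }

// Copy oihw weights into a zero-initialized buffer whose OC/IC are padded to K.
template <int K>
std::vector<float> padWeightsOIHW(const float* src, int OC, int IC, int H, int W)
{
    const int OC2 = getPaddedSketch<K>(OC);
    const int IC2 = getPaddedSketch<K>(IC);
    std::vector<float> dst(size_t(OC2) * IC2 * H * W, 0.f);     // padding stays zero
    for (int oc = 0; oc < OC; ++oc)
        for (int ic = 0; ic < IC; ++ic)
            std::memcpy(&dst[(size_t(oc) * IC2 + ic) * H * W],  // dst: oc*IC2*H*W + ic*H*W
                        &src[(size_t(oc) * IC  + ic) * H * W],  // src: oc*IC*H*W  + ic*H*W
                        size_t(H) * W * sizeof(float));
    return dst;
}
// Example: with K = 8, weights of shape (32, 9, 3, 3) (e.g. for the 9-channel
// color+albedo+normal input) become (32, 16, 3, 3), with the extra channels zeroed.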
// +// ======================================================================== // + +#pragma once + +#include +#include +#include + +#include "version.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#ifndef OIDN_API +#if defined(_WIN32) && !defined(OIDN_STATIC_LIB) +# define OIDN_API __declspec(dllimport) +#else +# define OIDN_API +#endif +#endif + +// ---------------------------------------------------------------------------- +// Device +// ---------------------------------------------------------------------------- + +// Device types +typedef enum +{ + OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically + + OIDN_DEVICE_TYPE_CPU = 1, // CPU device +} OIDNDeviceType; + +// Error codes +typedef enum +{ + OIDN_ERROR_NONE = 0, // no error occurred + OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred + OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified + OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed + OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation + OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported + OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user +} OIDNError; + +// Error callback function +typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message); + +// Device handle +typedef struct OIDNDeviceImpl* OIDNDevice; + +// Creates a new device. +OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type); + +// Retains the device (increments the reference count). +OIDN_API void oidnRetainDevice(OIDNDevice device); + +// Releases the device (decrements the reference count). +OIDN_API void oidnReleaseDevice(OIDNDevice device); + +// Sets a boolean parameter of the device. +OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value); + +// Sets an integer parameter of the device. +OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value); + +// Gets a boolean parameter of the device. +OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name); + +// Gets an integer parameter of the device (e.g. "version"). +OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name); + +// Sets the error callback function of the device. +OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr); + +// Returns the first unqueried error code stored in the device for the current +// thread, optionally also returning a string message (if not NULL), and clears +// the stored error. Can be called with a NULL device as well to check why a +// device creation failed. +OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage); + +// Commits all previous changes to the device. +// Must be called before first using the device (e.g. creating filters). 
+OIDN_API void oidnCommitDevice(OIDNDevice device); + +// ---------------------------------------------------------------------------- +// Buffer +// ---------------------------------------------------------------------------- + +// Formats for images and other data stored in buffers +typedef enum +{ + OIDN_FORMAT_UNDEFINED = 0, + + // 32-bit single-precision floating point scalar and vector formats + OIDN_FORMAT_FLOAT = 1, + OIDN_FORMAT_FLOAT2 = 2, + OIDN_FORMAT_FLOAT3 = 3, + OIDN_FORMAT_FLOAT4 = 4, +} OIDNFormat; + +// Access modes for mapping buffers +typedef enum +{ + OIDN_ACCESS_READ = 0, // read-only access + OIDN_ACCESS_WRITE = 1, // write-only access + OIDN_ACCESS_READ_WRITE = 2, // read and write access + OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded +} OIDNAccess; + +// Buffer handle +typedef struct OIDNBufferImpl* OIDNBuffer; + +// Creates a new buffer (data allocated and owned by the device). +OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize); + +// Creates a new shared buffer (data allocated and owned by the user). +OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize); + +// Maps a region of the buffer to host memory. +// If byteSize is 0, the maximum available amount of memory will be mapped. +OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize); + +// Unmaps a region of the buffer. +// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer. +OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr); + +// Retains the buffer (increments the reference count). +OIDN_API void oidnRetainBuffer(OIDNBuffer buffer); + +// Releases the buffer (decrements the reference count). +OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer); + +// ---------------------------------------------------------------------------- +// Filter +// ---------------------------------------------------------------------------- + +// Progress monitor callback function +typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n); + +// Filter handle +typedef struct OIDNFilterImpl* OIDNFilter; + +// Creates a new filter of the specified type (e.g. "RT"). +OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type); + +// Retains the filter (increments the reference count). +OIDN_API void oidnRetainFilter(OIDNFilter filter); + +// Releases the filter (decrements the reference count). +OIDN_API void oidnReleaseFilter(OIDNFilter filter); + +// Sets an image parameter of the filter (stored in a buffer). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name, + OIDNBuffer buffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets an image parameter of the filter (owned by the user). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets a boolean parameter of the filter. +OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value); + +// Gets a boolean parameter of the filter. 
+OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name); + +// Sets an integer parameter of the filter. +OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value); + +// Gets an integer parameter of the filter. +OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name); + +// Sets a float parameter of the filter. +OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value); + +// Gets a float parameter of the filter. +OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name); + +// Sets the progress monitor callback function of the filter. +OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr); + +// Commits all previous changes to the filter. +// Must be called before first executing the filter. +OIDN_API void oidnCommitFilter(OIDNFilter filter); + +// Executes the filter. +OIDN_API void oidnExecuteFilter(OIDNFilter filter); + +#if defined(__cplusplus) +} +#endif diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp new file mode 100644 index 000000000000..9f95a56fe17f --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp @@ -0,0 +1,468 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. 
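For orientation, a hedged end-to-end sketch of driving the C API declared above (these are not lines from the patch). The "RT" filter type is named in the comments above; the "color"/"output" image names and the "hdr" flag follow the RT filter's documented parameters and are assumptions here:

#include <OpenImageDenoise/oidn.h>
#include <cstdio>

void denoiseImage(float* colorPtr, float* outputPtr, size_t width, size_t height)
{
    // Create and commit a device; this version only provides a CPU backend.
    OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_DEFAULT);
    oidnCommitDevice(device);

    // Create the generic ray-tracing denoising filter ("RT").
    OIDNFilter filter = oidnNewFilter(device, "RT");

    // Bind user-owned float3 images; zero strides are computed automatically.
    oidnSetSharedFilterImage(filter, "color", colorPtr, OIDN_FORMAT_FLOAT3,
                             width, height, 0, 0, 0);
    oidnSetSharedFilterImage(filter, "output", outputPtr, OIDN_FORMAT_FLOAT3,
                             width, height, 0, 0, 0);
    oidnSetFilter1b(filter, "hdr", true);  // assumed RT parameter: input is HDR

    oidnCommitFilter(filter);
    oidnExecuteFilter(filter);

    // Check for errors stored on the device.
    const char* message;
    if (oidnGetDeviceError(device, &message) != OIDN_ERROR_NONE)
        std::printf("OIDN error: %s\n", message);

    oidnReleaseFilter(filter);
    oidnReleaseDevice(device);
}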
// +// ======================================================================== // + +#pragma once + +#include +#include "oidn.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // Buffer + // -------------------------------------------------------------------------- + + // Formats for images and other data stored in buffers + enum class Format + { + Undefined = OIDN_FORMAT_UNDEFINED, + + // 32-bit single-precision floating point scalar and vector formats + Float = OIDN_FORMAT_FLOAT, + Float2 = OIDN_FORMAT_FLOAT2, + Float3 = OIDN_FORMAT_FLOAT3, + Float4 = OIDN_FORMAT_FLOAT4, + }; + + // Access modes for mapping buffers + enum class Access + { + Read = OIDN_ACCESS_READ, // read-only access + Write = OIDN_ACCESS_WRITE, // write-only access + ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access + WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded + }; + + // Buffer object with automatic reference counting + class BufferRef + { + private: + OIDNBuffer handle; + + public: + BufferRef() : handle(nullptr) {} + BufferRef(OIDNBuffer handle) : handle(handle) {} + + BufferRef(const BufferRef& other) : handle(other.handle) + { + if (handle) + oidnRetainBuffer(handle); + } + + BufferRef(BufferRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + BufferRef& operator =(const BufferRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainBuffer(other.handle); + if (handle) + oidnReleaseBuffer(handle); + handle = other.handle; + } + return *this; + } + + BufferRef& operator =(BufferRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + BufferRef& operator =(OIDNBuffer other) + { + if (other) + oidnRetainBuffer(other); + if (handle) + oidnReleaseBuffer(handle); + handle = other; + return *this; + } + + ~BufferRef() + { + if (handle) + oidnReleaseBuffer(handle); + } + + OIDNBuffer getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Maps a region of the buffer to host memory. + // If byteSize is 0, the maximum available amount of memory will be mapped. + void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0) + { + return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize); + } + + // Unmaps a region of the buffer. + // mappedPtr must be a pointer returned by a previous call to map. 
+ void unmap(void* mappedPtr) + { + oidnUnmapBuffer(handle, mappedPtr); + } + }; + + // -------------------------------------------------------------------------- + // Filter + // -------------------------------------------------------------------------- + + // Progress monitor callback function + typedef bool (*ProgressMonitorFunction)(void* userPtr, double n); + + // Filter object with automatic reference counting + class FilterRef + { + private: + OIDNFilter handle; + + public: + FilterRef() : handle(nullptr) {} + FilterRef(OIDNFilter handle) : handle(handle) {} + + FilterRef(const FilterRef& other) : handle(other.handle) + { + if (handle) + oidnRetainFilter(handle); + } + + FilterRef(FilterRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + FilterRef& operator =(const FilterRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainFilter(other.handle); + if (handle) + oidnReleaseFilter(handle); + handle = other.handle; + } + return *this; + } + + FilterRef& operator =(FilterRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + FilterRef& operator =(OIDNFilter other) + { + if (other) + oidnRetainFilter(other); + if (handle) + oidnReleaseFilter(handle); + handle = other; + return *this; + } + + ~FilterRef() + { + if (handle) + oidnReleaseFilter(handle); + } + + OIDNFilter getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets an image parameter of the filter (stored in a buffer). + void setImage(const char* name, + const BufferRef& buffer, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetFilterImage(handle, name, + buffer.getHandle(), (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets an image parameter of the filter (owned by the user). + void setImage(const char* name, + void* ptr, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetSharedFilterImage(handle, name, + ptr, (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets a boolean parameter of the filter. + void set(const char* name, bool value) + { + oidnSetFilter1b(handle, name, value); + } + + // Sets an integer parameter of the filter. + void set(const char* name, int value) + { + oidnSetFilter1i(handle, name, value); + } + + // Sets a float parameter of the filter. + void set(const char* name, float value) + { + oidnSetFilter1f(handle, name, value); + } + + // Gets a parameter of the filter. + template + T get(const char* name); + + // Sets the progress monitor callback function of the filter. + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr) + { + oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr); + } + + // Commits all previous changes to the filter. + void commit() + { + oidnCommitFilter(handle); + } + + // Executes the filter. + void execute() + { + oidnExecuteFilter(handle); + } + }; + + // Gets a boolean parameter of the filter. + template<> + inline bool FilterRef::get(const char* name) + { + return oidnGetFilter1b(handle, name); + } + + // Gets an integer parameter of the filter. + template<> + inline int FilterRef::get(const char* name) + { + return oidnGetFilter1i(handle, name); + } + + // Gets a float parameter of the filter. 
+ template<> + inline float FilterRef::get(const char* name) + { + return oidnGetFilter1f(handle, name); + } + + // -------------------------------------------------------------------------- + // Device + // -------------------------------------------------------------------------- + + // Device types + enum class DeviceType + { + Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically + + CPU = OIDN_DEVICE_TYPE_CPU, // CPU device + }; + + // Error codes + enum class Error + { + None = OIDN_ERROR_NONE, // no error occurred + Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred + InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified + InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed + OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation + UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported + Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user + }; + + // Error callback function + typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message); + + // Device object with automatic reference counting + class DeviceRef + { + private: + OIDNDevice handle; + + public: + DeviceRef() : handle(nullptr) {} + DeviceRef(OIDNDevice handle) : handle(handle) {} + + DeviceRef(const DeviceRef& other) : handle(other.handle) + { + if (handle) + oidnRetainDevice(handle); + } + + DeviceRef(DeviceRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + DeviceRef& operator =(const DeviceRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainDevice(other.handle); + if (handle) + oidnReleaseDevice(handle); + handle = other.handle; + } + return *this; + } + + DeviceRef& operator =(DeviceRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + DeviceRef& operator =(OIDNDevice other) + { + if (other) + oidnRetainDevice(other); + if (handle) + oidnReleaseDevice(handle); + handle = other; + return *this; + } + + ~DeviceRef() + { + if (handle) + oidnReleaseDevice(handle); + } + + OIDNDevice getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets a boolean parameter of the device. + void set(const char* name, bool value) + { + oidnSetDevice1b(handle, name, value); + } + + // Sets an integer parameter of the device. + void set(const char* name, int value) + { + oidnSetDevice1i(handle, name, value); + } + + // Gets a parameter of the device. + template + T get(const char* name); + + // Sets the error callback function of the device. + void setErrorFunction(ErrorFunction func, void* userPtr = nullptr) + { + oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr); + } + + // Returns the first unqueried error code and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError() + { + return (Error)oidnGetDeviceError(handle, nullptr); + } + + // Returns the first unqueried error code and string message, and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError(const char*& outMessage) + { + return (Error)oidnGetDeviceError(handle, &outMessage); + } + + // Commits all previous changes to the device. + // Must be called before first using the device (e.g. creating filters). 
+ void commit() + { + oidnCommitDevice(handle); + } + + // Creates a new buffer (data allocated and owned by the device). + BufferRef newBuffer(size_t byteSize) + { + return oidnNewBuffer(handle, byteSize); + } + + // Creates a new shared buffer (data allocated and owned by the user). + BufferRef newBuffer(void* ptr, size_t byteSize) + { + return oidnNewSharedBuffer(handle, ptr, byteSize); + } + + // Creates a new filter of the specified type (e.g. "RT"). + FilterRef newFilter(const char* type) + { + return oidnNewFilter(handle, type); + } + }; + + // Gets a boolean parameter of the device. + template<> + inline bool DeviceRef::get(const char* name) + { + return oidnGetDevice1b(handle, name); + } + + // Gets an integer parameter of the device (e.g. "version"). + template<> + inline int DeviceRef::get(const char* name) + { + return oidnGetDevice1i(handle, name); + } + + // Creates a new device. + inline DeviceRef newDevice(DeviceType type = DeviceType::Default) + { + return DeviceRef(oidnNewDevice((OIDNDeviceType)type)); + } + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/version.h b/thirdparty/oidn/include/OpenImageDenoise/version.h new file mode 100644 index 000000000000..66b347c992b7 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/version.h @@ -0,0 +1,23 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#define OIDN_VERSION_MAJOR 1 +#define OIDN_VERSION_MINOR 1 +#define OIDN_VERSION_PATCH 0 +#define OIDN_VERSION 10100 +#define OIDN_VERSION_STRING "1.1.0" diff --git a/thirdparty/oidn/mkl-dnn/LICENSE b/thirdparty/oidn/mkl-dnn/LICENSE new file mode 100644 index 000000000000..d13f7b7ca06d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/LICENSE @@ -0,0 +1,214 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
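The same flow through the reference-counted C++ wrappers added above, again as an illustrative sketch; as before, the "color"/"output"/"hdr" parameter names are assumed from the RT filter rather than declared by this header:

#include <OpenImageDenoise/oidn.hpp>
#include <cstddef>
#include <iostream>

void denoiseImage(float* colorPtr, float* outputPtr, size_t width, size_t height)
{
    oidn::DeviceRef device = oidn::newDevice();  // DeviceType::Default
    device.commit();

    oidn::FilterRef filter = device.newFilter("RT");
    filter.setImage("color",  colorPtr,  oidn::Format::Float3, width, height);
    filter.setImage("output", outputPtr, oidn::Format::Float3, width, height);
    filter.set("hdr", true);   // assumed RT parameter
    filter.commit();
    filter.execute();

    const char* message;
    if (device.getError(message) != oidn::Error::None)
        std::cout << "OIDN error: " << message << std::endl;
}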
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ============================================================================ + + Intel MKL-DNN includes components with separate copyright + notices and license terms. + + XByak, 3-clause BSD license + Copyright (c) 2007 MITSUNARI Shigeo + See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT + + gtest, 3-clause BSD license + Copyright 2008, Google Inc. 
+ See full copyright notice and license text in tests/gtests/gtest/LICENSE diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.h b/thirdparty/oidn/mkl-dnn/include/mkldnn.h new file mode 100644 index 000000000000..9b649949223e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.h @@ -0,0 +1,1771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_H +#define MKLDNN_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#include "mkldnn_version.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +/** @addtogroup c_api C API + * @{ */ + +/** @addtogroup c_api_primitive Primitive operations + * @{ */ + +/** @addtogroup c_api_primitive_common Common primitive operations + * @{ */ + +/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr, + * @p engine, and optionally a hint primitive descriptor from forward + * propagation (required for backward propagation). Pass @c NULL for forward + * propagation. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create( + mkldnn_primitive_desc_iterator_t *iterator, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no + * more primitive descriptors are available. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next( + mkldnn_primitive_desc_iterator_t iterator); + +/** Fetches the current primitive descriptor. + * + * @note + * The user should delete the fetched primitive descriptor using + * mkldnn_primitive_desc_destroy() once it is no longer needed. 
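+ *
+ * A possible iteration sketch (an illustration only; assumes an already
+ * initialized @p op_desc, @p attr, and @p engine, and omits error checking):
+ * @code
+ * mkldnn_primitive_desc_iterator_t it;
+ * mkldnn_primitive_desc_iterator_create(&it, op_desc, attr, engine, NULL);
+ * do {
+ *     mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(it);
+ *     // ... inspect pd, e.g. via mkldnn_primitive_desc_query() ...
+ *     mkldnn_primitive_desc_destroy(pd);
+ * } while (mkldnn_primitive_desc_iterator_next(it) != mkldnn_iterator_ends);
+ * mkldnn_primitive_desc_iterator_destroy(it);
+ * @endcode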
*/ +mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch( + const_mkldnn_primitive_desc_iterator_t iterator); + +/** Deletes a primitive descriptor @p iterator */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy( + mkldnn_primitive_desc_iterator_t iterator); + +/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and + * optionally a hint primitive descriptor from forward propagation. The call is + * equivalent to creating a primitive descriptor iterator, immediately fetching + * a primitive descriptor, and then destroying the iterator. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Makes a copy of a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_primitive_desc_t existing_primitive_desc); + +/** Returns a constant reference to the attribute of a @p primitive_desc. + * + * @warning + * The user should not destroy the obtained @p attr. + * + * @warning + * The lifetime of an @p attr is the same as that of a @p primitive_desc, + * so it is illegal to use the @p attr once @p primitive_desc has been + * destroyed. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr( + const_mkldnn_primitive_desc_t primitive_desc, + const_mkldnn_primitive_attr_t *attr); + +/** Deletes a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy( + mkldnn_primitive_desc_t primitive_desc); + +/** Queries primitive descriptor + * + * One of the most typical use cases is to query a convolution primitive + * descriptor created with source, weights, and destination formats equal + * to #mkldnn_format_tag_any about the corresponding memory descriptors + * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and + * #mkldnn_query_dst_md respectively) to be able to prepare memory and + * create reorders if required. + * + * Another quite typical use case is to query an operation primitive + * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md). + * The returned status #mkldnn_not_required indicates that a workspace is + * not required. + * + * A few other possibilities: + * - query an operation primitive descriptor for the underlying operation + * descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d, + * #mkldnn_query_rnn_d, etc.) + * - query an operation primitive descriptor for the implementation + * information string (#mkldnn_query_impl_info_str) + * - query an operation primitive descriptor for the number of inputs and + * outputs (#mkldnn_query_num_of_inputs_s32 and + * #mkldnn_query_num_of_outputs_s32 respectively) + * + * @sa mkldnn_query_t for more options + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index, void *result); + +/** Queries primitive descriptor for memory descriptor + * + * @returns NULL in case of any error. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. 
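+ *
+ * A short sketch (illustrative only; assumes @p pd is a valid primitive
+ * descriptor created with #mkldnn_format_tag_any for its source):
+ * @code
+ * const mkldnn_memory_desc_t *src_md
+ *         = mkldnn_primitive_desc_query_md(pd, mkldnn_query_src_md, 0);
+ * if (src_md == NULL) {
+ *     // handle the error
+ * }
+ * // compare src_md with the user's memory descriptor to decide whether a
+ * // reorder is needed (see mkldnn_memory_desc_equal())
+ * @endcode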
+ */ +const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Queries primitive descriptor for signed 32bit int + * + * @returns 0 in case of any error (in particular if the queried entity is + * not of type int32_t). Note that 0 might also be the actual returned + * value. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +int MKLDNN_API mkldnn_primitive_desc_query_s32( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Creates a @p primitive using a @p primitive_desc descriptor. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_create( + mkldnn_primitive_t *primitive, + const_mkldnn_primitive_desc_t primitive_desc); + +/** Executes a @p primitive using a @p stream, and @p nargs arguments + * @p args. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_execute( + const_mkldnn_primitive_t primitive, mkldnn_stream_t stream, + int nargs, const mkldnn_exec_arg_t *args); + +/** Retrieves a reference to the @p primitive_desc descriptor of given @p + * primitive. + * + * @warning + * The returned object must not be destroyed by the user. The @c const + * qualifier of the returned object prevents such attempts. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc( + const_mkldnn_primitive_t primitive, + const_mkldnn_primitive_desc_t *primitive_desc); + +/** Deletes a @p primitive. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy( + mkldnn_primitive_t primitive); + +/** @} */ + +/** @addtogroup c_api_attributes Attributes + * An extension for controlling primitive behavior. + * @{ */ + +/** Creates an empty (default) @p attr attribute. All the parameters are set to + * default values. + * + * An empty attribute is used in primitive descriptor creation whenever it + * is not passed explicitly, e.g. in mkldnn_primitive_desc_create. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create( + mkldnn_primitive_attr_t *attr); + +/** Makes a copy of an @p existing_attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone( + mkldnn_primitive_attr_t *attr, + const_mkldnn_primitive_attr_t existing_attr); + +/** Deletes an @p attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy( + mkldnn_primitive_attr_t attr); + +/** Returns the scratchpad @p mode set in the attribute @p attr */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode( + const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode); + +/** Sets scratchpad @p mode. + * + * The possible values are: #mkldnn_scratchpad_mode_library (default) and + * #mkldnn_scratchpad_mode_user. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode( + mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode); + +/** Returns @p count, correspondence scale @p mask, and a pointer to a constant + * floating point array of output @p scales for given @p attr, previously set + * by mkldnn_primitive_attr_set_output_scales. + * + * @warning + * The @p scales array points to the internal @p attr field, so the user + * should not modify or destroy @p scales. + * + * @warning + * The lifetime of @p scales is the same as that of the @p attr to which it + * belongs, so it is illegal to use @p scales after @p attr is destroyed. 
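+ *
+ * Read-only access sketch (illustrative only; assumes @p attr already
+ * carries output scales set via mkldnn_primitive_attr_set_output_scales()):
+ * @code
+ * mkldnn_dim_t count;
+ * int mask;
+ * const float *scales;
+ * mkldnn_primitive_attr_get_output_scales(attr, &count, &mask, &scales);
+ * // scales points into attr: read it, but do not modify or free it
+ * @endcode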
+ */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales( + const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask, + const float **scales); + +/** Sets output @p scales for primitive operations. The number of elements @p + * count and correspondence scale @p mask are stored for future use. + * + * The @p mask argument defines the correspondence between the output tensor + * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a + * dedicated scaling factor for each slice of the output tensor over the i-th + * dimension. Set @p mask to 0 to use a common scaling factor for the whole + * output tensor. + * + * @note + * The dimension order is always native and does not depend on the actual + * layout used. Examples: + * - 2D dimensional data the order of dimensions is always: (n, c) + * - 4D dimensional data the order is always: (n, c, h, w) + * - 5D dimensional weights the order is always: (g, oc, ic, kh, kw) + * + * Example usage: + * @code + * int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params + * float scales[oc] = { ... }; // unique output scales per output channel + * int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... + * + * mkldnn_convolution_desc_t cd; // create & configure convolution op_desc + * + * mkldnn_primitive_attr_t attr; + * mkldnn_primitive_attr_create(&attr); // create default attributes + * mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales); + * + * mkldnn_primitive_desc_t cpd; + * mkldnn_primitive_desc_create(&cpd, &cd, attr, engine, NULL); + * @endcode + * + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is the user's + * responsibility to set proper values. The following formula must hold: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *scales); + +/** Returns @p post_ops for given @p attr. + * + * @warning + * @p post_ops points to the internal @p attr field, so the user should not + * modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the + * same as that of the @p attr it belongs to, so it is illegal to use @p + * post_ops after @p attr has been destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops( + const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops); + +/** Sets configured @p post_ops to an attribute @p attr for future use (when + * primitive descriptor is being created). + * + * @note + * At this point in time, there is no way to check whether the primitive + * descriptor does or does not support a given sequence of post operations. + * Therefore the user should handle an error that might occur at the + * mkldnn_primitive_desc_create call. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops( + mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops); + +/** @addtogroup c_api_attributes_post_ops Sequence of post operations + * An extension for performing extra operations after a base operation. + * @{ */ + +/** Creates an empty sequence of post operations @p post_ops. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops); + +/** Deletes a @p post_ops sequence.
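+ *
+ * A typical create/append/attach/destroy sequence might look as follows
+ * (a sketch only; it assumes a previously created @p attr, and the ReLU
+ * parameters and the point at which @p post_ops may safely be destroyed are
+ * assumptions of this example rather than guarantees of the API):
+ * @code
+ * mkldnn_post_ops_t ops;
+ * mkldnn_post_ops_create(&ops);
+ * // fuse a ReLU (eltwise) operation after the base operation
+ * mkldnn_post_ops_append_eltwise(ops, 1.0f, mkldnn_eltwise_relu, 0.0f, 0.0f);
+ * mkldnn_primitive_attr_set_post_ops(attr, ops);
+ * // ... create the primitive descriptor using attr ...
+ * mkldnn_post_ops_destroy(ops);
+ * @endcode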
*/ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops); + +/** Returns the @p length of post operations for given @p post_ops. */ +int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops); + +/** Returns the type of post operation with index @p index in given + * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */ +mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind( + const_mkldnn_post_ops_t post_ops, int index); + +/** Appends accumulation (sum) post operation to the @p post_ops. Prior to + * accumulating the result, the previous value would be multiplied by @p scale. + * + * The kind of this post operation is #mkldnn_sum. + * + * This feature might improve performance for cases like residual learning + * blocks, where the result of convolution is accumulated to the previously + * computed activations. The parameter @p scale might be extreme for the + * integer-based computations when the result and previous activations have + * different logical scaling factors. + * + * In the simplest case when the accumulation is the only post operation, the + * computations would be: + * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...) + * + * @note + * This post operation (as well as all the others) disregards the original + * layout of the destination; that is, the layout of the original + * destination is expected to be the same as the layout of the stored + * destination. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum( + mkldnn_post_ops_t post_ops, float scale); + +/** Gets the parameters of the accumulation (sum) post operation with index + * @p index in the sequence of @p post_ops. + * + * @note + * If index @p index would not correspond to the accumulation post + * operation, the function returns #mkldnn_invalid_arguments. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum( + const_mkldnn_post_ops_t post_ops, int index, float *scale); + +/** Appends eltwise post operation to the @p post_ops with given parameters + * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and + * mkldnn_eltwise_desc_t). + * + * The kind of this post operation is #mkldnn_eltwise. + * + * In the simplest case when the eltwise is the only post operation, the + * computations would be: + * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...) + * where eltwise_op is configured with the given parameters. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise( + mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, + float alpha, float beta); + +/** Gets the eltwise parameters of the post operation with index @p index in + * the sequence of @p post_ops. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise( + const_mkldnn_post_ops_t post_ops, int index, float *scale, + mkldnn_alg_kind_t *alg, float *alpha, float *beta); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_memory Memory + * A primitive to describe and store data. + * + * The library supports various data types and formats. Memory hierarchy + * consists of three levels of abstraction: + * 1. **Memory descriptor** -- engine agnostic logical description of data + * (number of dimensions, dimensions themselves, and data type), and + * optionally the format/layout that describes the physical representation + * of data in memory. If the format is not known yet, one can pass + * #mkldnn_format_tag_any. 
This approach is used to allow compute-intensive + * primitives to specify the most appropriate format on their own with + * users required to reorder the data if the incoming format doesn't match + * the primitive's selection. Memory descriptor can be initialized with + * mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides() + * functions, or by directly filling the mkldnn_memory_desc_t structure. + * The latter requires deep knowledge of how the physical data + * representation is mapped to the structure. + * The @ref understanding_memory_formats topic should shed some light on + * that. + * For the fully defined memory descriptors (i.e. where the format kind is + * not equal to #mkldnn_format_kind_any) a user can compute the size using the + * mkldnn_memory_desc_get_size() function. As described in + * @ref understanding_memory_formats, the size of data sometimes cannot + * be computed as the product of dimensions times the size of the data + * type. So users are encouraged to use this function for better code + * portability. + * Two memory descriptors can be compared with mkldnn_memory_desc_equal(). + * The comparison is especially useful when checking whether a primitive + * requires reorder from the user's data format to the primitive's format. + * 2. **Memory** -- an engine-specific object that handles the data and its + * description (a memory descriptor). For the CPU engine, the data handle is + * simply a pointer to @c void. The data handle can be queried using + * mkldnn_memory_get_data_handle() and set using + * mkldnn_memory_set_data_handle(). The latter function always sets the + * memory in the padding region to zero, which is the invariant maintained + * by all the primitives in Intel MKL-DNN. + * See @ref understanding_memory_formats for more details. + * A memory can be created using the mkldnn_memory_create() function. + * A memory can also be queried for the underlying memory descriptor and + * engine using mkldnn_memory_get_memory_desc() and + * mkldnn_memory_get_engine() functions. + * + * Along with ordinary memory with all dimensions being positive, Intel + * MKL-DNN supports *zero-volume* memory with one or more dimensions set to + * zero. This is to support the NumPy\* convention. + * If a *zero-volume* memory is passed to a primitive, the primitive does + * not perform any computations on this memory. For example: + * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)` + * source and `(16 output channels, 3 input channels, 3 height, 3 width)` + * weights would produce `(0 batch, 16 output channels, 11 height, 11 width)` + * destination (assuming strides are `1` and paddings are zero) and perform + * zero multiply-add operations. + * - Concatenation of three memories of shapes `(3, 4, 13, 13)`, + * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce + * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second + * input (however, if the user created a concatenation primitive descriptor + * with three inputs they should also provide all three memories to the + * concatenation primitive, including the one with zero second dimension). + * - However, Intel MKL-DNN would return an error when attempting to create a + * convolution with *zero-volume* memory passed for weights because such a + * convolution is not well-defined: + * ~~~ + * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3) + * ~~~ + * Should the values in the destination be zeroes or just not accessed at + * all?
Moreover, backward pass w.r.t. weights in such cases is also not + * well-defined. + * + * Data handle of *zero-volume* memory is never accessed and hence can be + * unset (NULL in case of CPU engine). + * + * @sa @ref understanding_memory_formats + * @{ */ + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and @p strides. + * + * The @p strides might be NULL, which means the order of physical dimensions + * is the same as the order of logical ones. + * + * @note The logical order of dimensions is defined by a primitive that + * consumes the memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, const mkldnn_dims_t strides); + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and format @p tag. + * + * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define + * the appropriate memory format. In this case, the @p format_kind would be set + * to #mkldnn_format_kind_any */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, mkldnn_format_tag_t tag); + +/** Initializes a @p memory_desc for a given @p parent_memory_desc, with + * @p dims sizes and @p offsets. May fail if the layout used does not allow + * obtaining the desired submemory. In this case, consider using the `extract` or `insert` + * primitive */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory( + mkldnn_memory_desc_t *memory_desc, + const mkldnn_memory_desc_t *parent_memory_desc, + const mkldnn_dims_t dims, const mkldnn_dims_t offsets); + +/** Compares two memory descriptors. + * @return 1 if the descriptors are the same. + * @return 0 if the descriptors are different. + * + * Use this function to identify whether a reorder is required between the + * two memories */ +int MKLDNN_API mkldnn_memory_desc_equal( + const mkldnn_memory_desc_t *lhs, + const mkldnn_memory_desc_t *rhs); + +/** Returns the size (in bytes) that is required for given @p memory_desc */ +size_t MKLDNN_API mkldnn_memory_desc_get_size( + const mkldnn_memory_desc_t *memory_desc); + +/** Creates a memory for given @p memory_desc and @p engine. Also sets handle + * to @p native_handle. + * The @p native_handle can: + * - point to the user allocated memory, i.e. valid handle. In this case the + * library doesn't own allocated memory. + * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and + * attach memory. In this case the library owns allocated memory. + * - be MKLDNN_NATIVE_HANDLE_NONE to create mkldnn_memory w/o attached memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory, + const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine, + void *native_handle); + +/** Returns a @p memory_desc associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc( + const_mkldnn_memory_t memory, + const mkldnn_memory_desc_t **memory_desc); + +/** Returns an @p engine associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine( + const_mkldnn_memory_t memory, mkldnn_engine_t *engine); + +/** For a @p memory, returns the data @p handle. + * + * For the CPU engine, the data handle is a pointer to the actual data.
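+ *
+ * A minimal CPU-side sketch (illustrative only; the #mkldnn_f32 and
+ * #mkldnn_nchw enumerator names and the use of malloc() are assumptions of
+ * this example, @p engine is presumed to exist, and error checking is
+ * omitted):
+ * @code
+ * mkldnn_dims_t dims = {2, 3, 13, 13};
+ * mkldnn_memory_desc_t md;
+ * mkldnn_memory_desc_init_by_tag(&md, 4, dims, mkldnn_f32, mkldnn_nchw);
+ * float *buf = (float *)malloc(mkldnn_memory_desc_get_size(&md));
+ * mkldnn_memory_t mem;
+ * mkldnn_memory_create(&mem, &md, engine, buf); // library does not own buf
+ * void *handle;
+ * mkldnn_memory_get_data_handle(mem, &handle);  // on CPU, handle == buf
+ * mkldnn_memory_destroy(mem);
+ * free(buf);
+ * @endcode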
*/ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle( + const_mkldnn_memory_t memory, void **handle); + +/** For a @p memory, sets the data @p handle. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle( + mkldnn_memory_t memory, void *handle); + +/** Deletes a @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory); + +/** @} */ + +/** @addtogroup c_api_reorder Reorder + * A primitive to copy data between memory formats. + * @{ */ + +/** Initializes a @p reorder_primitive_desc using the description of the source + * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md) + * memory, and an @p attr attribute. + * + * Inputs: + * - input (#mkldnn_query_src_md, 0) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create( + mkldnn_primitive_desc_t *reorder_primitive_desc, + mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md, + mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md, + const_mkldnn_primitive_attr_t attr); + +/** @} */ + +/** @addtogroup c_api_concat Concat + * A primitive to concatenate data by arbitrary dimension. + * @{ */ + +/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n + * inputs by @p concat_dimension with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - input 0 (#mkldnn_query_src_md, 0) + * - input 1 (#mkldnn_query_src_md, 1) + * - ... + * - input @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create( + mkldnn_primitive_desc_t *concat_primitive_desc, + const mkldnn_memory_desc_t *dst_md, + int n, int concat_dimension, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_sum Sum + * A primitive to sum data. + * @{ */ + +/** Creates out-of-place @p sum_primitive_desc for sum of @p n + * inputs multiplied by scale with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - src 0 (#mkldnn_query_src_md, 0) + * - src 1 (#mkldnn_query_src_md, 1) + * - ... + * - src @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create( + mkldnn_primitive_desc_t *sum_primitive_desc, + const mkldnn_memory_desc_t *dst_mds, + int n, const float *scales, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_convolution Convolution + * A primitive to compute convolution using different algorithms. 
+ * + * \f[dst[n][oc][oh][ow] = + * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw] + * \cdot weights[g][oc][ic][kh][kw] + * + bias[g][oc],\f] + * + * where size of output spatial domain is given by + * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}} + * \right\rfloor + 1\f$, + * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}} + * \right\rfloor + 1\f$, + * + * and summation is carried over input channels \f$ic\f$ in + * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and + * \f$p_l, p_r\f$ are @p padding_l and @p padding_r. + * @{ */ + +/** Initializes a convolution descriptor @p conv_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. In order to create a + * convolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equal to #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * In order to create a dilated convolution without bias, @p bias_desc + * should either be @c NULL or point to a descriptor with memory format kind + * equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. 
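+ *
+ * A possible setup sketch (illustrative only; #mkldnn_convolution_direct,
+ * #mkldnn_padding_zero, and the dilation convention where `0` means "no
+ * dilation" are assumptions of this example, and src_md, wei_md, and dst_md
+ * are presumed to be already initialized memory descriptors):
+ * @code
+ * mkldnn_dims_t strides = {1, 1};
+ * mkldnn_dims_t dilates = {1, 1};  // one skipped element between kernel taps
+ * mkldnn_dims_t pad_l = {2, 2}, pad_r = {2, 2};
+ * mkldnn_convolution_desc_t cd;
+ * // bias_desc is NULL: create the convolution without bias
+ * mkldnn_dilated_convolution_forward_desc_init(&cd,
+ *         mkldnn_forward_inference, mkldnn_convolution_direct,
+ *         &src_md, &wei_md, NULL, &dst_md,
+ *         strides, dilates, pad_l, pad_r, mkldnn_padding_zero);
+ * @endcode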
+ * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. 
+ * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API +mkldnn_dilated_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_deconvolution Deconvolution + * A primitive to compute deconvolution using different algorithms. + * + * @{ */ + + +/** Initializes a deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. In order to create a + * deconvolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. 
+ * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to + * create a dilated deconvolution without bias, @p bias_desc should either be + * @c NULL or point to a descriptor with memory format kind equal + * #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. 
+ * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to weights using @p alg_kind, memory descriptors, + * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_shuffle Shuffle + * A primitive to shuffle data along the axis. + * @{ */ + +/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind, + * memory descriptor @p data_desc, @p axis, and @p group_size. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int axis, + mkldnn_dim_t group_size); + +/** Initializes a @p shuffle_desc for backward propagation using memory + * descriptor @p diff_data_desc, @p axis, and @p group_size. 
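+ *
+ * For instance (a sketch; assumes a 4D @p diff_md whose channel axis has a
+ * size divisible by 4):
+ * @code
+ * mkldnn_shuffle_desc_t sd;
+ * // undo a forward shuffle of groups of 4 along axis 1 (channels)
+ * mkldnn_shuffle_backward_desc_init(&sd, &diff_md, 1, 4);
+ * @endcode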
+ * + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, + const mkldnn_memory_desc_t *diff_data_desc, int axis, + mkldnn_dim_t group_size); + +/** @} */ + +/** @addtogroup c_api_eltwise Eltwise + * A primitive to compute element-wise operations like parametric rectifier + * linear unit (ReLU). + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * @warning Because the original src is required for backward pass, in-place + * forward pass in general cannot be applied during training. However, for some + * kinds of element-wise operations (namely ReLU with alpha parameter equals 0), + * dst and src can be interchangeable for the backward pass, which enables + * performing in-place forward even for training. + * + * @{ */ + +/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and + * @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + float alpha, float beta); + +/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind + * algorithm, memory descriptors @p diff_data_desc and @p data_desc, and the + * @p alpha and @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, float alpha, float beta); + +/** @} */ + +/** @addtogroup c_api_softmax Softmax + * A primitive to perform softmax. + * + * \f[dst[ou][c][in] = + * \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])} + * {\sum\limits_{c}\{\exp(src[ou][c][in]) + * - \max\limits_{c}(src[ou][c][in])\}},\f] + * + * where \f$ou, in\f$ are outer and inner sizes respectively, defined + * by @p data_desc.dims and @p softmax_axis. + * @{ */ + +/** Initializes a @p softmax_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference) + * and memory descriptor @p data_desc. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** Initializes a @p softmax_desc for backward propagation using memory + * descriptors @p diff_desc and @p data_desc.
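+ *
+ * For instance (a sketch; @p diff_md and @p data_md are assumed to be
+ * already initialized memory descriptors, with axis 1 holding the softmax
+ * classes):
+ * @code
+ * mkldnn_softmax_desc_t sm_bwd;
+ * mkldnn_softmax_backward_desc_init(&sm_bwd, &diff_md, &data_md, 1);
+ * @endcode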
+ * + * Inputs: + * - dst (#mkldnn_query_dst_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, + const mkldnn_memory_desc_t *diff_desc, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** @} */ + +/** @addtogroup c_api_pooling Pooling + * A primitive to perform max or average pooling. + * + * Max pooling: + * \f[dst[n][oc][oh][ow] = + * \max\limits_{kw,kh} + * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]),\f] + * + * Average pooling: + * \f[dst[n][oc][oh][ow] = + * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw],\f] + * + * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and + * output spatial dimensions are calculated similarly to how they are done in + * convolution. + * + * During training, max pooling requires a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to + * save indices where maximum was found. The workspace layout is opaque, and + * the indices cannot be restored from it. However, one can use backward + * pooling to perform up-sampling (used in some detection topologies). + * + * @{ */ + +/** Initializes a pooling descriptor @p pool_desc for forward propagation using + * @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling + * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l, + * @p padding_r, and @p padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max and + * @p prop_kind = #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a pooling descriptor @p pool_desc for backward propagation + * using @p alg_kind, memory descriptors, and pooling parameters in the spatial + * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p + * padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_lrn LRN + * A primitive to perform local response normalization (LRN) across or within + * channels. 
+ * + * LRN across channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c+i][h][w])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * LRN within channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c][h+i][w+i])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * where \f$n_{l}\f$ is the @p local_size. + * + * During training, LRN might or might not require a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The + * behavior is implementation specific. Optimized implementations typically + * require a workspace and use it to save some intermediate results from the + * forward pass that accelerate computations on the backward pass. + * + * To check whether a workspace is required, query the LRN primitive descriptor + * for the workspace (#mkldnn_query_workspace_md). Success indicates that the + * workspace is required and its description will be returned. + * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd + * + * @{ */ + +/** Initializes an @p lrn_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind, memory descriptor @p data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + mkldnn_dim_t local_size, float alpha, float beta, float k); + +/** Initializes an @p lrn_desc for backward propagation using @p alg_kind, + * memory descriptors @p data_desc and @p diff_data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size, + float alpha, float beta, float k); + +/** @} */ + +/** @addtogroup c_api_batch_normalization Batch Normalization + * A primitive to perform batch normalization. + * + * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]} + * {\sqrt{\sigma[c] + eps}} + \beta[c],\f] + * + * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and, + * + * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$, + * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn} + * (src[n][c][h][w] - \mu[c])^2\f$, + * + * and @c eps is a constant to improve numerical stability. + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * Batch normalization supports different flavors controlled by + * mkldnn_batch_normalization_desc_t.
For example, batch normalization can + * compute the mean and variance on its own or take them as inputs. It can + * either perform scaling and shifting using gamma and beta parameters or not. + * Optionally it can also perform a fused ReLU, which in case of training would + * also require a workspace. + * + * @sa mkldnn_batch_normalization_desc_t + * @{ */ + +/** Initializes a batch normalization descriptor @p bnrm_desc for forward + * propagation using @p prop_kind (possible values are + * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor + * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit + * flags of type mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - variance (#mkldnn_query_src_md, 2), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - mean (#mkldnn_query_dst_md, 1), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * @p prop_kind = #mkldnn_forward_training + * - variance (#mkldnn_query_dst_md, 2), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * + * @note In-place operation is supported; that is, dst points to the same memory + * as src. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** Initializes a batch normalization descriptor @p bnrm_desc for backward + * propagation with respect to data and scale-shift parameters using memory + * descriptors @p data_desc and @p diff_data_desc, normalization parameter + * @p epsilon, and @p flags set using bit flags of type + * mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1) + * - variance (#mkldnn_query_src_md, 2) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_backward + * + * @note in-place operation is supported, + * i.e. diff_src points to the same memory as diff_dst. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** @} */ + +/** @addtogroup c_api_inner_product Inner product + * A primitive to compute an inner product. + * + * Inner product layer is also known as fully connected layer. 
+ * With spatial dimension: + * + * \f[dst[n][oc] = \sum\limits_{ic, kh, kw} + * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw] + * + bias[oc]\f] + * @{ */ + +/** Initializes an inner product descriptor @p ip_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference) and memory descriptors. In order to create an + * inner product without bias, @p bias_desc should be either @c NULL or a + * pointer to a descriptor with memory format kind equals + * #mkldnn_format_kind_undef. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init( + mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to data using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to weights using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** @} */ + +/** @addtogroup c_api_rnn RNN + * A primitive to compute the common recurrent layer. + * @todo add additional description for the group + * @{ */ + +/** + * Initializes a recurrent cell descriptor @p rnn_cell_desc + * using @p rnn_cell_desc, @p kind (possible values are + * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and + * #mkldnn_gru_linear_before_reset), + * @p f (possible values are #mkldnn_eltwise_relu and + * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping. + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init( + mkldnn_rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, + unsigned int flags, float alpha, float clipping); + +/** Returns the number of gates of a particular @p rnn_cell_desc. 
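+ *
+ * For example (a sketch; the exact gate count for a given cell kind is a
+ * property of the library and is queried rather than assumed here, and
+ * whether @p f is meaningful for an LSTM cell is likewise left to the
+ * library):
+ * @code
+ * mkldnn_rnn_cell_desc_t cell;
+ * mkldnn_rnn_cell_desc_init(&cell, mkldnn_vanilla_lstm, mkldnn_eltwise_tanh,
+ *         0, 0.0f, 0.0f);
+ * int n_gates = mkldnn_rnn_cell_get_gates_count(&cell);
+ * @endcode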
*/ +int MKLDNN_API mkldnn_rnn_cell_get_gates_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Returns the number of states of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_states_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Sets quantization @p scale and @p shift for RNN data tensors. + * For performance reasons, low precision configuration of RNN primitive + * expects input activations to have unsigned int8 data type. Scale and shift + * used to quantize floating point data to unsigned integer must be passed to + * RNN primitive using attributes. + * Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // activations quantization parameters + * float scale = ..., shift = ..; + * + * mkldnn_primitive_attr_t rnn_attr; + * // create default attributes + * mkldnn_primitive_attr_create(&rnn_attr); + * + * // set scale and shift for int8 quantization of activation + * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * Quantization scale and shift are common for src_layer, src_iter, + * dst_iter and dst_layer. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams( + mkldnn_primitive_attr_t attr, const float scale, const float shift); + +/** Sets quantization scales @p weights_scales for RNN weights tensors. + * Low precision configuration of RNN primitive expects input weights to have + * signed int8 data type. Scales used to quantize floating point data + * to signed integer must be passed to RNN primitive using attributes. + * The @p mask argument defines correspondence between output tensor dimensions + * and the @p weights_scales array. Set i-th bit of @p mask to 1 to use + * dedicated scaling factor for each slice of the output tensor over i-th + * dimension. Set @p mask to 0 to use common scaling factor for the whole output + * tensor. Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // unique output scales per output channel + * float weights_scales[dic * n_gates] = { ... }; + * // mask that specifies last two dimensions of ldigo format + * int mask = 0x3; + * + * mkldnn_primitive_attr_t attr; + * // create default attributes + * mkldnn_primitive_attr_create(&attr); + * + * // set output channel-wise weights scales + * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask, + * weights_scales); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * The dimension order is always native and does not depend on the actual + * layout used. For example, 5 dimensional weights always have + * (l, d, i, g, o) logical dimension ordering. + * @note + * Quantization sales are common for weights_layer and weights_iteration + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is user's responsibility + * to set proper values. 
The following formula must be held: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams ( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *weights_scales); + +/** Initializes a rnn descriptor @p rnn_desc for forward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * @note If @p prop_kind equals #mkldnn_forward_training, you must query a + * workspace memory descriptor before creating the primitive. + * + * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be + * @c NULL or point to a zero memory descriptor, which would indicate that the + * RNN primitive should not use them. + * + * @note All memory descriptors except @p src_iter_desc are allowed to be + * initialized with #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * + * Outputs: + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p prop_kind equals #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc); + +/** Initializes a rnn descriptor @p rnn_desc for backward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * + * @note All memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * @p src_iter_desc (simultaneously with @p diff_src_iter_desc), + * @p bias_desc (simultaneously with @p diff_bias_desc), and + * @p dst_iter_desc (simultaneously with @p diff_src_iter_desc) are allowed to + * either be @c NULL or point to a zero memory descriptor, which would indicate + * that the RNN primitive should not use them. 
+ * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - diff_dst_layer (#mkldnn_query_diff_dst_md, 0) + * - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0) + * + * Outputs: + * - diff_src_layer (#mkldnn_query_diff_src_md, 0) + * - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used + * - diff_weights_layer (#mkldnn_query_diff_weights_md, 0) + * - diff_weights_iter (#mkldnn_query_diff_weights_md, 1) + * - diff_bias (#mkldnn_query_diff_weights_md, 2), if used + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc, + const mkldnn_memory_desc_t *diff_src_layer_desc, + const mkldnn_memory_desc_t *diff_src_iter_desc, + const mkldnn_memory_desc_t *diff_weights_layer_desc, + const mkldnn_memory_desc_t *diff_weights_iter_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_layer, + const mkldnn_memory_desc_t *diff_dst_iter_desc); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_engine Engine operations + * @{ */ + +/** Returns the number of engines of a particular @p kind. */ +size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind); + +/** Creates an @p engine of particular @p kind and @p index. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, + mkldnn_engine_kind_t kind, size_t index); + +/** Returns the kind of an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine, + mkldnn_engine_kind_t *kind); + +/** Destroys an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_stream Execution stream operations + * @{ */ + +/** Creates an execution @p stream for @p engine and with @p flags. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, + mkldnn_engine_t engine, unsigned flags); + +/** Destroys an execution @p stream. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream); + +/** @} */ + +/** @addtogroup c_api_service Service functions + * @{ */ + +/** Sets verbosity level (print information to stdout). + * Possible levels are: + * - 0 -- no verbose output (default) + * - 1 -- primitive information at execution + * - 2 -- primitive information at creation and execution + * + * @note + * Dumping information might affect performance. + * This setting overrides the MKLDNN_VERBOSE environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level); + +/** Enables or disables dumping of JIT-generated code. + * The enable parameter can be: + * - 0 -- disable + * - any other value -- enable + * + * @note + * This setting overrides the MKLDNN_JIT_DUMP environment variable. 
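A minimal sketch of enabling both diagnostics at program start (return statuses ignored for brevity):

    #include "mkldnn.h"

    int main(void) {
        mkldnn_set_verbose(2);   // primitive info at creation and execution
        mkldnn_set_jit_dump(1);  // dump JIT-generated kernels
        // ... create and execute primitives as usual ...
        return 0;
    }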
*/ +mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable); + +/** Gets library version information. + * Version information includes: + * - major -- major version number + * - minor -- minor version number + * - patch -- patch release number + * - hash -- git commit hash */ +const mkldnn_version_t MKLDNN_API *mkldnn_version(); + +/** @} */ + +/** @addtogroup c_api_blas BLAS functions + * A subset of Basic Linear ALgebra (BLAS) functions to perform + * matrix-matrix multiplication. + * @{ */ + +/** SGEMM performs a matrix-matrix multiplication operation defined as + * + * C := alpha*op( A )*op( B ) + beta*C + * + * where + * - op( X ) is one of op( X ) = X or op( X ) = X**T, + * - alpha and beta are scalars, + * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix + * and C an m by n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different from the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_sgemm( + const char *transa, const char *transb, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, const float *A, const mkldnn_dim_t *lda, + const float *B, const mkldnn_dim_t *ldb, + const float *beta, float *C, const mkldnn_dim_t *ldc); + +/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication + * operation and add the result to a scalar-matrix product. For the final + * result, a vector is added to each row or column of the output matrix. + * The operation is defined as: + * + * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset + * + * where + * - op( X ) = X or op( X ) = X**T, + * - A_offset is an m-by-k matrix with every element equal to the value oa, + * - B_offset is an k-by-n matrix with every element equal to the value ob, + * - C_offset is an m-by-n matrix defined by the oc array, size len: + * - if offsetc = F: len must be at least 1 + * - if offsetc = C: len must be at least max(1, m) + * - if offsetc = R: len must be at least max(1, n) + * - alpha and beta are scalars, and A, B and C are matrices, with op( A ) + * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different compared with the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. 
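For instance, a small column-major s8/u8 multiply with zero offsets could be sketched as follows (return status ignored; offsetc = "F" uses a single common C offset, so co has length 1):

    mkldnn_dim_t m = 2, n = 2, k = 2, lda = 2, ldb = 2, ldc = 2;
    float alpha = 1.0f, beta = 0.0f;
    int8_t  A[4] = {1, -2, 3, -4};
    uint8_t B[4] = {1, 0, 0, 1};
    int32_t C[4] = {0, 0, 0, 0};
    int8_t ao = 0, bo = 0;
    int32_t co = 0;
    mkldnn_gemm_s8u8s32("N", "N", "F", &m, &n, &k, &alpha,
            A, &lda, &ao, B, &ldb, &bo, &beta, C, &ldc, &co);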
*/ +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); + +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); +/** @} */ + +/** @} */ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp new file mode 100644 index 000000000000..581400a0139f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp @@ -0,0 +1,2615 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_HPP +#define MKLDNN_HPP + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#include +#include +#include +#include + +#include "mkldnn.h" +#endif + +namespace mkldnn { + +/// @addtogroup cpp_api C++ API +/// @{ + +/// @addtogroup cpp_api_utils Utils +/// @{ + +/// A class that provides the destructor for an Intel(R) MKL-DNN C handle +template class handle_traits {}; + +/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base +/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and +/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class +/// can be passed by value. This class enables wrapping: +/// - Newly constructed handles. +/// @n In this case, the constructed handle uses reference counting provided +/// by @p std::shared_ptr with a proper deleter function specified through +/// the @p handle_traits class. +/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for +/// example, through mkldnn_primitive_get_primitive_desc()). +/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a +/// deleter because it is assumed that the handle wrapper for the original +/// object deletes the handle (this model is similar to @p std::weak_ptr). +template > class handle { +private: + std::shared_ptr::type> _data; + handle(const handle &&) = delete; + handle &operator=(const handle &&other) = delete; +protected: + bool operator==(const T other) const { return other == _data.get(); } + bool operator!=(const T other) const { return !(*this == other); } +public: + /// Constructs a C handle wrapper. + /// @param t The C handle to wrap. 
+ /// @param weak A flag to specify whether to construct a weak wrapper. + handle(T t = 0, bool weak = false): _data(0) { + reset(t, weak); + } + + handle(const handle &other): _data(other._data) {} + handle &operator=(const handle &other) { + _data = other._data; + return *this; + } + /// Resets the value of a C handle. + /// @param t The new value of the C handle. + /// @param weak A flag to specify whether the wrapper should be weak. + void reset(T t, bool weak = false) { + auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; + _data.reset(t, weak ? dummy_destructor : traits::destructor); + } + + /// Returns the value of the underlying C handle. + T get() const { return _data.get(); } + + bool operator==(const handle &other) const { return other._data.get() == _data.get(); } + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_memory_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; +}; +#endif + +struct memory; +struct primitive_desc; + +/// Base class for all computational primitives. +class primitive: public handle { + friend struct error; + friend struct stream; + using handle::handle; +public: + /// A proxy to C primitive kind enum + enum class kind { + undefined_primitive = mkldnn_undefined_primitive, + reorder = mkldnn_reorder, + concat = mkldnn_concat, + sum = mkldnn_sum, + convolution = mkldnn_convolution, + deconvolution = mkldnn_deconvolution, + shuffle = mkldnn_shuffle, + eltwise = mkldnn_eltwise, + softmax = mkldnn_softmax, + pooling = mkldnn_pooling, + lrn = mkldnn_lrn, + batch_normalization = mkldnn_batch_normalization, + inner_product = mkldnn_inner_product, + rnn = mkldnn_rnn, + }; + + primitive(const_mkldnn_primitive_desc_t c_pd); + primitive(const primitive_desc &pd); + + /// Returns the descriptor of the underlying C API primitive. + inline const_mkldnn_primitive_desc_t get_primitive_desc() const; + // TODO: use the C++ API wrapper structure. + + void execute(struct stream &astream, + const std::unordered_map &args) const; +}; + +inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast(akind); +} +/// Intel(R) MKL-DNN exception class. +/// +/// This class captures the status returned by the failed C API function, error +/// message, and, optionally, handle of the primitive that caused the error. +struct error: public std::exception { + mkldnn_status_t status; + const char *message; + + /// Constructs an error instance. + /// + /// @param astatus The error status returned by the C API. + /// @param amessage The error message. + error(mkldnn_status_t astatus, const char *amessage) + : status(astatus), message(amessage) {} + + /// A convenience function for wrapping calls to the C API. Checks the + /// return status and throws an #error in case of failure. + /// + /// @param status The error status returned by the C API. + /// @param message The error message. 
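For illustration, a call site might look like this (the engine variable name is chosen arbitrarily):

    mkldnn_engine_t c_engine;
    error::wrap_c_api(mkldnn_engine_create(&c_engine, mkldnn_cpu, 0),
            "could not create an engine");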
+ static void wrap_c_api(mkldnn_status_t status, const char *message) { + if (status != mkldnn_success) + throw error(status, message); + } +}; + +const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { + const_mkldnn_primitive_desc_t pd; + error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), + "could not get primitive descriptor by primitive"); + return pd; +} +/// @} + +/// @addtogroup cpp_api_enums Common data types and enumerations +/// A proxy to @ref c_api_types in @ref c_api. +/// +/// @{ + +enum scratchpad_mode { + scratchpad_mode_library = mkldnn_scratchpad_mode_library, + scratchpad_mode_user = mkldnn_scratchpad_mode_user, +}; + +inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { + return static_cast(mode); +} + +enum padding_kind { + zero = mkldnn_padding_zero +}; + +inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { + return static_cast(kind); +} + +enum prop_kind { + forward_training = mkldnn_forward_training, + forward_scoring = mkldnn_forward_scoring, + forward_inference = mkldnn_forward_inference, + forward = mkldnn_forward, + backward = mkldnn_backward, + backward_data = mkldnn_backward_data, + backward_weights = mkldnn_backward_weights, + backward_bias = mkldnn_backward_bias +}; + +inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { + return static_cast(kind); +} + +enum algorithm { + algorithm_undef = mkldnn_alg_kind_undef, + convolution_auto = mkldnn_convolution_auto, + convolution_direct = mkldnn_convolution_direct, + convolution_winograd = mkldnn_convolution_winograd, + deconvolution_direct = mkldnn_deconvolution_direct, + deconvolution_winograd = mkldnn_deconvolution_winograd, + eltwise_relu = mkldnn_eltwise_relu, + eltwise_tanh = mkldnn_eltwise_tanh, + eltwise_elu = mkldnn_eltwise_elu, + eltwise_square = mkldnn_eltwise_square, + eltwise_abs = mkldnn_eltwise_abs, + eltwise_sqrt = mkldnn_eltwise_sqrt, + eltwise_linear = mkldnn_eltwise_linear, + eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, + eltwise_soft_relu = mkldnn_eltwise_soft_relu, + eltwise_logistic = mkldnn_eltwise_logistic, + lrn_across_channels = mkldnn_lrn_across_channels, + lrn_within_channel = mkldnn_lrn_within_channel, + pooling_max = mkldnn_pooling_max, + pooling_avg = mkldnn_pooling_avg, + pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, + pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, + vanilla_rnn = mkldnn_vanilla_rnn, + vanilla_lstm = mkldnn_vanilla_lstm, + vanilla_gru = mkldnn_vanilla_gru, + gru_linear_before_reset = mkldnn_gru_linear_before_reset +}; + +inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast(aalgorithm); +} + +enum batch_normalization_flag { + use_global_stats = mkldnn_use_global_stats, + use_scale_shift = mkldnn_use_scaleshift, + fuse_bn_relu = mkldnn_fuse_bn_relu +}; + +inline mkldnn_batch_normalization_flag_t convert_to_c( + batch_normalization_flag aflag) { + return static_cast(aflag); +} + +enum rnn_direction { + unidirectional_left2right = mkldnn_unidirectional_left2right, + unidirectional_right2left = mkldnn_unidirectional_right2left, + unidirectional = mkldnn_unidirectional, + bidirectional_concat = mkldnn_bidirectional_concat, + bidirectional_sum = mkldnn_bidirectional_sum, +}; + +inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { + return static_cast(adir); +} + +enum query { + undef = mkldnn_query_undef, + + query_engine = mkldnn_query_engine, + primitive_kind = mkldnn_query_primitive_kind, + + num_of_inputs_s32 = 
mkldnn_query_num_of_inputs_s32, + num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, + + time_estimate_f64 = mkldnn_query_time_estimate_f64, + memory_consumption_s64 = mkldnn_query_memory_consumption_s64, + + query_scratchpad_engine = mkldnn_query_scratchpad_engine, + + impl_info_str = mkldnn_query_impl_info_str, + + op_d = mkldnn_query_op_d, + convolution_d = mkldnn_query_convolution_d, + deconvolution_d = mkldnn_query_deconvolution_d, + shuffle_d = mkldnn_query_shuffle_d, + eltwise_d = mkldnn_query_eltwise_d, + softmax_d = mkldnn_query_softmax_d, + pooling_d = mkldnn_query_pooling_d, + lrn_d = mkldnn_query_lrn_d, + batch_normalization_d = mkldnn_query_batch_normalization_d, + inner_product_d = mkldnn_query_inner_product_d, + rnn_d = mkldnn_query_rnn_d, + + src_md = mkldnn_query_src_md, + diff_src_md = mkldnn_query_diff_src_md, + weights_md = mkldnn_query_weights_md, + diff_weights_md = mkldnn_query_diff_weights_md, + dst_md = mkldnn_query_dst_md, + diff_dst_md = mkldnn_query_diff_dst_md, + workspace_md = mkldnn_query_workspace_md, + scratchpad_md = mkldnn_query_scratchpad_md, +}; + +inline mkldnn_query_t convert_to_c(query aquery) { + return static_cast(aquery); +} + +/// @} + +/// @addtogroup cpp_api_attr Attributes +/// An extension for controlling primitive behavior. +/// +/// @sa @ref c_api_attributes in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_post_ops_destroy; +}; +#endif + +struct post_ops: public handle { + post_ops() { + mkldnn_post_ops_t result; + error::wrap_c_api(mkldnn_post_ops_create(&result), + "could not create post operation sequence"); + reset(result); + } + + int len() const { return mkldnn_post_ops_len(get()); } + + primitive::kind kind(int index) const { + error::wrap_c_api( + index < len() ? mkldnn_success : mkldnn_invalid_arguments, + "post_ops index is out of range"); + return static_cast(mkldnn_post_ops_get_kind(get(), + index)); + } + + void append_sum(float scale = 1.) 
{ + error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), + "could not append sum"); + } + + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), + "could not get sum params"); + } + + void append_eltwise(float scale, algorithm alg, float alpha, + float beta) { + error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, + convert_to_c(alg), alpha, beta), + "could not append eltwise"); + } + + void get_params_eltwise(int index, float &scale, algorithm &alg, + float &alpha, float &beta) const { + mkldnn_alg_kind_t c_alg; + error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, + &scale, &c_alg, &alpha, &beta), + "could not get eltwise params"); + alg = static_cast(c_alg); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_attr_destroy; +}; +#endif + +struct primitive_attr: public handle { + primitive_attr() { + mkldnn_primitive_attr_t result; + error::wrap_c_api(mkldnn_primitive_attr_create(&result), + "could not create a primitive attr"); + reset(result); + } + + scratchpad_mode get_scratchpad_mode() const { + mkldnn_scratchpad_mode_t result; + error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( + get(), &result), "could not get scratchpad mode"); + return scratchpad_mode(result); + } + + void set_scratchpad_mode(scratchpad_mode mode) { + error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( + get(), mkldnn::convert_to_c(mode)), + "could not set scratchpad mode"); + } + + void get_output_scales(int &mask, std::vector &scales) const + { + mkldnn_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), + &count, &c_mask, &c_scales), + "could not get int output scales"); + scales.resize(count); + + mask = c_mask; + for (mkldnn_dim_t c = 0; c < count; ++c) + scales[c] = c_scales[c]; + } + + void set_output_scales(int mask, const std::vector &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), + (mkldnn_dim_t)scales.size(), mask, &scales[0]), + "could not set int output scales"); + } + + const post_ops get_post_ops() const { + post_ops result; + const_mkldnn_post_ops_t c_result; + error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), + "could not get post operation sequence"); + result.reset(const_cast(c_result), true); + return result; + } + + void set_post_ops(post_ops ops) { + error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post operation sequence"); + } + + void set_rnn_data_qparams(const float scale, const float shift) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), + scale, shift), "could not set rnn data int scale/shift"); + } + + void set_rnn_weights_qparams(int mask, const std::vector &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), + (int)scales.size(), mask, &scales[0]), + "could not set rnn weights int scales"); + } +}; + +/// @} + +/// @addtogroup cpp_api_engine Engine +/// Engine operations. +/// +/// @sa @ref c_api_engine in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_engine_destroy; +}; +#endif + +/// An execution engine. +struct engine: public handle { + friend class primitive; + // gcc bug??? using handle::handle; + + /// Kinds of engines. 
+ enum kind { + /// An unspecified engine + any = mkldnn_any_engine, + /// CPU engine + cpu = mkldnn_cpu, + }; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + + static size_t get_count(kind akind) { + return mkldnn_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. Must be less than the value + /// returned by #get_count() for this particular kind of engine. + + engine(kind akind, size_t index) { + mkldnn_engine_t aengine; + error::wrap_c_api( + mkldnn_engine_create(&aengine, + convert_to_c(akind), index), + "could not create an engine"); + reset(aengine); + } + + explicit engine(const mkldnn_engine_t& aengine) + : handle(aengine, true) {} + + engine(const handle &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + reset(engine_q, true); + } + + template + static engine query(const primitive_desc &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + + return engine(engine_q); + } + +private: + static mkldnn_engine_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } +}; + +/// @} + +/// @addtogroup cpp_api_stream Stream +/// Execution stream operations +/// +/// @sa @ref c_api_stream in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_stream_destroy; +}; +#endif + +struct stream: public handle { + using handle::handle; + + enum: unsigned { + default_flags = mkldnn_stream_default_flags, + }; + + /// Constructs a stream. + stream(const engine &aengine, + unsigned flags = static_cast(default_flags)) { + mkldnn_stream_t astream; + error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), + "could not create a stream"); + reset(astream); + } +}; + +/// @} + +/// @addtogroup cpp_api_memory_related Memory and memory related operations +/// @{ + +/// @addtogroup cpp_api_memory Memory +/// A primitive to describe and store data. +/// +/// For more information, refer to @ref c_api_memory in @ref c_api. +/// @{ + +/// Memory that describes the data. +struct memory: public handle { + public: + typedef mkldnn_dim_t dim; + typedef std::vector dims; + + template static void validate_dims(const std::vector &v) { + if (v.size() > MKLDNN_MAX_NDIMS) + throw error(mkldnn_invalid_arguments, "invalid dimensions"); + } + + /// Data type specification. See #mkldnn_data_type_t for a detailed + /// description. + enum data_type { + data_undef = mkldnn_data_type_undef, + f32 = mkldnn_f32, + s32 = mkldnn_s32, + s8 = mkldnn_s8, + u8 = mkldnn_u8, + }; + + /// Memory format tag specification. See #mkldnn_format_tag_t + /// for a detailed description. 
+ enum format_tag { + format_tag_undef = mkldnn_format_tag_undef, + any = mkldnn_format_tag_any, + a = mkldnn_a, + ab = mkldnn_ab, + abc = mkldnn_abc, + abcd = mkldnn_abcd, + abcde = mkldnn_abcde, + abcdef = mkldnn_abcdef, + abdec = mkldnn_abdec, + acb = mkldnn_acb, + acbde = mkldnn_acbde, + acdb = mkldnn_acdb, + acdeb = mkldnn_acdeb, + ba = mkldnn_ba, + bac = mkldnn_bac, + bacd = mkldnn_bacd, + bcda = mkldnn_bcda, + cba = mkldnn_cba, + cdba = mkldnn_cdba, + cdeba = mkldnn_cdeba, + decab = mkldnn_decab, + Abc16a = mkldnn_Abc16a, + ABc16a16b = mkldnn_ABc16a16b, + aBc16b = mkldnn_aBc16b, + ABc16b16a = mkldnn_ABc16b16a, + Abc4a = mkldnn_Abc4a, + aBc4b = mkldnn_aBc4b, + ABc4b16a4b = mkldnn_ABc4b16a4b, + ABc4b4a = mkldnn_ABc4b4a, + ABc8a16b2a = mkldnn_ABc8a16b2a, + ABc8a8b = mkldnn_ABc8a8b, + aBc8b = mkldnn_aBc8b, + ABc8b16a2b = mkldnn_ABc8b16a2b, + ABc8b8a = mkldnn_ABc8b8a, + Abcd16a = mkldnn_Abcd16a, + ABcd16a16b = mkldnn_ABcd16a16b, + aBcd16b = mkldnn_aBcd16b, + ABcd16b16a = mkldnn_ABcd16b16a, + aBCd16b16c = mkldnn_aBCd16b16c, + aBCd16c16b = mkldnn_aBCd16c16b, + Abcd4a = mkldnn_Abcd4a, + aBcd4b = mkldnn_aBcd4b, + ABcd4b16a4b = mkldnn_ABcd4b16a4b, + ABcd4b4a = mkldnn_ABcd4b4a, + aBCd4c16b4c = mkldnn_aBCd4c16b4c, + aBCd4c4b = mkldnn_aBCd4c4b, + ABcd8a16b2a = mkldnn_ABcd8a16b2a, + ABcd8a8b = mkldnn_ABcd8a8b, + aBcd8b = mkldnn_aBcd8b, + ABcd8b16a2b = mkldnn_ABcd8b16a2b, + aBCd8b16c2b = mkldnn_aBCd8b16c2b, + ABcd8b8a = mkldnn_ABcd8b8a, + aBCd8b8c = mkldnn_aBCd8b8c, + aBCd8c16b2c = mkldnn_aBCd8c16b2c, + aBCd8c8b = mkldnn_aBCd8c8b, + Abcde16a = mkldnn_Abcde16a, + ABcde16a16b = mkldnn_ABcde16a16b, + aBcde16b = mkldnn_aBcde16b, + ABcde16b16a = mkldnn_ABcde16b16a, + aBCde16b16c = mkldnn_aBCde16b16c, + aBCde16c16b = mkldnn_aBCde16c16b, + aBCde2c8b4c = mkldnn_aBCde2c8b4c, + Abcde4a = mkldnn_Abcde4a, + aBcde4b = mkldnn_aBcde4b, + ABcde4b4a = mkldnn_ABcde4b4a, + aBCde4b4c = mkldnn_aBCde4b4c, + aBCde4c16b4c = mkldnn_aBCde4c16b4c, + aBCde4c4b = mkldnn_aBCde4c4b, + Abcde8a = mkldnn_Abcde8a, + ABcde8a8b = mkldnn_ABcde8a8b, + aBcde8b = mkldnn_aBcde8b, + ABcde8b16a2b = mkldnn_ABcde8b16a2b, + aBCde8b16c2b = mkldnn_aBCde8b16c2b, + ABcde8b8a = mkldnn_ABcde8b8a, + aBCde8b8c = mkldnn_aBCde8b8c, + aBCde8c16b2c = mkldnn_aBCde8c16b2c, + aBCde8c8b = mkldnn_aBCde8c8b, + aBcdef16b = mkldnn_aBcdef16b, + aBCdef16b16c = mkldnn_aBCdef16b16c, + aBCdef16c16b = mkldnn_aBCdef16c16b, + aBcdef4b = mkldnn_aBcdef4b, + aBCdef4c4b = mkldnn_aBCdef4c4b, + aBCdef8b8c = mkldnn_aBCdef8b8c, + aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, + aBCdef8c8b = mkldnn_aBCdef8c8b, + aBdc16b = mkldnn_aBdc16b, + aBdc4b = mkldnn_aBdc4b, + aBdc8b = mkldnn_aBdc8b, + aBdec16b = mkldnn_aBdec16b, + aBdec4b = mkldnn_aBdec4b, + aBdec8b = mkldnn_aBdec8b, + aBdefc16b = mkldnn_aBdefc16b, + aBdefc4b = mkldnn_aBdefc4b, + aBdefc8b = mkldnn_aBdefc8b, + Acb16a = mkldnn_Acb16a, + Acb4a = mkldnn_Acb4a, + Acb8a = mkldnn_Acb8a, + aCBd16b16c = mkldnn_aCBd16b16c, + aCBde16b16c = mkldnn_aCBde16b16c, + Acdb16a = mkldnn_Acdb16a, + Acdb4a = mkldnn_Acdb4a, + Acdb8a = mkldnn_Acdb8a, + Acdeb16a = mkldnn_Acdeb16a, + Acdeb4a = mkldnn_Acdeb4a, + Acdeb8a = mkldnn_Acdeb8a, + BAc16a16b = mkldnn_BAc16a16b, + BAcd16a16b = mkldnn_BAcd16a16b, + format_tag_last = mkldnn_format_tag_last, + + x = mkldnn_x, + nc = mkldnn_nc, + cn = mkldnn_cn, + ncw = mkldnn_ncw, + nwc = mkldnn_nwc, + nchw = mkldnn_nchw, + nhwc = mkldnn_nhwc, + chwn = mkldnn_chwn, + ncdhw = mkldnn_ncdhw, + ndhwc = mkldnn_ndhwc, + oi = mkldnn_oi, + io = mkldnn_io, + oiw = mkldnn_oiw, + wio = mkldnn_wio, + oihw = mkldnn_oihw, + hwio = 
mkldnn_hwio, + ihwo = mkldnn_ihwo, + iohw = mkldnn_iohw, + oidhw = mkldnn_oidhw, + dhwio = mkldnn_dhwio, + goiw = mkldnn_goiw, + goihw = mkldnn_goihw, + hwigo = mkldnn_hwigo, + giohw = mkldnn_giohw, + goidhw = mkldnn_goidhw, + tnc = mkldnn_tnc, + ntc = mkldnn_ntc, + ldsnc = mkldnn_ldsnc, + ldigo = mkldnn_ldigo, + ldgoi = mkldnn_ldgoi, + ldgo = mkldnn_ldgo, + nCdhw16c = mkldnn_nCdhw16c, + nCdhw4c = mkldnn_nCdhw4c, + nCdhw8c = mkldnn_nCdhw8c, + nChw16c = mkldnn_nChw16c, + nChw4c = mkldnn_nChw4c, + nChw8c = mkldnn_nChw8c, + nCw16c = mkldnn_nCw16c, + nCw4c = mkldnn_nCw4c, + nCw8c = mkldnn_nCw8c, + IOw16o16i = mkldnn_IOw16o16i, + OIw16i16o = mkldnn_OIw16i16o, + OIw16o16i = mkldnn_OIw16o16i, + Oiw16o = mkldnn_Oiw16o, + OIw4i16o4i = mkldnn_OIw4i16o4i, + OIw4i4o = mkldnn_OIw4i4o, + Oiw4o = mkldnn_Oiw4o, + OIw8i16o2i = mkldnn_OIw8i16o2i, + OIw8i8o = mkldnn_OIw8i8o, + OIw8o16i2o = mkldnn_OIw8o16i2o, + OIw8o8i = mkldnn_OIw8o8i, + Owi16o = mkldnn_Owi16o, + Owi4o = mkldnn_Owi4o, + Owi8o = mkldnn_Owi8o, + IOhw16o16i = mkldnn_IOhw16o16i, + Ohwi16o = mkldnn_Ohwi16o, + Ohwi4o = mkldnn_Ohwi4o, + Ohwi8o = mkldnn_Ohwi8o, + OIhw16i16o = mkldnn_OIhw16i16o, + OIhw16o16i = mkldnn_OIhw16o16i, + Oihw16o = mkldnn_Oihw16o, + OIhw4i16o4i = mkldnn_OIhw4i16o4i, + OIhw4i4o = mkldnn_OIhw4i4o, + Oihw4o = mkldnn_Oihw4o, + OIhw8i16o2i = mkldnn_OIhw8i16o2i, + OIhw8i8o = mkldnn_OIhw8i8o, + OIhw8o16i2o = mkldnn_OIhw8o16i2o, + OIhw8o8i = mkldnn_OIhw8o8i, + Odhwi16o = mkldnn_Odhwi16o, + Odhwi4o = mkldnn_Odhwi4o, + Odhwi8o = mkldnn_Odhwi8o, + OIdhw16i16o = mkldnn_OIdhw16i16o, + OIdhw16o16i = mkldnn_OIdhw16o16i, + Oidhw16o = mkldnn_Oidhw16o, + OIdhw4i4o = mkldnn_OIdhw4i4o, + Oidhw4o = mkldnn_Oidhw4o, + OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, + OIdhw8i8o = mkldnn_OIdhw8i8o, + OIdhw8o8i = mkldnn_OIdhw8o8i, + gIOw16o16i = mkldnn_gIOw16o16i, + gOIw16i16o = mkldnn_gOIw16i16o, + gOIw16o16i = mkldnn_gOIw16o16i, + gOiw16o = mkldnn_gOiw16o, + gOIw4i16o4i = mkldnn_gOIw4i16o4i, + gOIw4i4o = mkldnn_gOIw4i4o, + gOiw4o = mkldnn_gOiw4o, + gOIw8i16o2i = mkldnn_gOIw8i16o2i, + gOIw8i8o = mkldnn_gOIw8i8o, + gOIw8o16i2o = mkldnn_gOIw8o16i2o, + gOIw8o8i = mkldnn_gOIw8o8i, + gOwi16o = mkldnn_gOwi16o, + gOwi4o = mkldnn_gOwi4o, + gOwi8o = mkldnn_gOwi8o, + gIOhw16o16i = mkldnn_gIOhw16o16i, + gOhwi16o = mkldnn_gOhwi16o, + gOhwi4o = mkldnn_gOhwi4o, + gOhwi8o = mkldnn_gOhwi8o, + Goihw16g = mkldnn_Goihw16g, + gOIhw16i16o = mkldnn_gOIhw16i16o, + gOIhw16o16i = mkldnn_gOIhw16o16i, + gOihw16o = mkldnn_gOihw16o, + gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, + gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, + gOIhw4i4o = mkldnn_gOIhw4i4o, + gOIhw4o4i = mkldnn_gOIhw4o4i, + gOihw4o = mkldnn_gOihw4o, + Goihw8g = mkldnn_Goihw8g, + gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, + gOIhw8i8o = mkldnn_gOIhw8i8o, + gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, + gOIhw8o8i = mkldnn_gOIhw8o8i, + gOdhwi16o = mkldnn_gOdhwi16o, + gOdhwi4o = mkldnn_gOdhwi4o, + gOdhwi8o = mkldnn_gOdhwi8o, + gOIdhw16i16o = mkldnn_gOIdhw16i16o, + gOIdhw16o16i = mkldnn_gOIdhw16o16i, + gOidhw16o = mkldnn_gOidhw16o, + gOIdhw4i4o = mkldnn_gOIdhw4i4o, + gOidhw4o = mkldnn_gOidhw4o, + gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, + gOIdhw8i8o = mkldnn_gOIdhw8i8o, + gOIdhw8o8i = mkldnn_gOIdhw8o8i, + }; + + /// A memory descriptor. + struct desc { + friend struct memory; + /// The underlying C API data structure. + mkldnn_memory_desc_t data; + + /// Constructs a zero memory descriptor + desc(): data() {} + + /// Constructs a memory descriptor. + /// + /// @param adims Data dimensions + /// @param adata_type Data precision/type. 
+ /// @param aformat Data layout format tag. + desc(const dims &adims, data_type adata_type, + format_tag aformat) { + validate_dims(adims); + error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), + adims.size() == 0 ? nullptr : &adims[0], + convert_to_c(adata_type), convert_to_c(aformat)), + "could not initialize a memory descriptor"); + } + + /// Constructs a memory descriptor from a C API data structure. + /// + /// @param adata A C API #mkldnn_memory_desc_t structure. + desc(const mkldnn_memory_desc_t &adata): data(adata) {} + + /// Constructs a sub-memory descriptor + // + /// @param adims Sizes of a sub-memory + /// @param offsets Offsets of a sub-memory + desc submemory_desc(const dims &adims, const dims &offsets) { + mkldnn_memory_desc_t sub_md; + error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, + &data, &adims[0], &offsets[0]), + "could not initialize a sub-memory"); + return desc(sub_md); + } + + /// Returns the number of bytes required to allocate the memory described + /// including the padding area. + size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } + + bool operator==(const desc &other) const { + return mkldnn_memory_desc_equal(&data, &other.data) != 0; + } + + bool operator!=(const desc &other) const { return !operator==(other); } + }; + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + /// @param ahandle Native handle. + memory(const desc &md, const engine &aengine, void *ahandle) { + mkldnn_memory_t result; + error::wrap_c_api(mkldnn_memory_create(&result, &md.data, + aengine.get(), ahandle), "could not create a memory"); + reset(result); + } + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + memory(const desc &md, const engine &aengine) + : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} + + /// Returns the descriptor of the memory. + desc get_desc() const { + const mkldnn_memory_desc_t *cdesc; + error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), + "could not get memory descriptor from a memory"); + return desc(*cdesc); + } + + /// Returns the engine of the memory. + engine get_engine() const { + mkldnn_engine_t engine_q; + error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), + "could not get engine from a memory"); + return engine(engine_q); + } + + /// Returns a handle of the data contained in the memory. + /// + /// On the CPU engine, this is a pointer to the allocated memory. 
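A short usage sketch tying the memory pieces together (dimensions chosen arbitrarily; assumes a CPU engine is available):

    using namespace mkldnn;

    engine eng(engine::cpu, 0);
    memory::desc md({2, 3, 4, 5}, memory::data_type::f32,
            memory::format_tag::nchw);
    memory mem(md, eng);  // the library allocates the underlying buffer
    float *data = static_cast<float *>(mem.get_data_handle());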
+ void *get_data_handle() const { + void *handle; + error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), + "could not get native handle"); + return handle; + } + + void set_data_handle(void *handle) const { + error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), + "could not set native handle"); + } + + // Must go away or be private: + static mkldnn_data_type_t convert_to_c(data_type adata_type) { + return static_cast(adata_type); + } + static mkldnn_format_tag_t convert_to_c(format_tag aformat) { + return static_cast(aformat); + } +}; + +inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { + return !(a == b); +} +inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { + return b == a; +} +inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { + return !(a == b); +} + +inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { + return !(a == b); +} +inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { + return b == a; +} +inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { + return !(a == b); +} + +/// @} + +/// @addtogroup cpp_api_reorder Reorder +/// A primitive to copy data between memory formats. +/// +/// @sa @ref c_api_reorder in @ref c_api +/// @{ + +struct reorder : public primitive { + struct primitive_desc : public handle { + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine scratchpad_engine() { + mkldnn_engine_t engine_q; + error::wrap_c_api( + 
mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), + "could not get scratchpad engine from reorder primitive_desc"); + + return engine(engine_q); + } + + engine get_engine() { return engine::query(*this); } + }; + + reorder(const primitive_desc &pd): primitive(pd.get()) {} + + reorder(const memory &src, const memory &dst): + primitive(primitive_desc(src, dst).get()) {} + + void execute(stream astream, memory &src, memory &dst) { + primitive::execute(astream, + {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); + } +}; + +/// @} + +/// @addtogroup cpp_api_concat Concat +/// A primitive to concatenate data by arbitrary dimension. +/// +/// @sa @ref c_api_concat in @ref c_api +/// @{ + +struct concat : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + const std::vector &srcs) { + std::vector c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, int concat_dimension, + const std::vector &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + primitive_desc(int concat_dimension, + const std::vector &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, nullptr, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + concat(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @addtogroup cpp_api_sum Sum +/// A primitive to sum data. +/// +/// @sa @ref c_api_sum in @ref c_api +/// @{ + +struct sum : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + const std::vector &srcs) { + std::vector c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, + const std::vector &scales, + const std::vector &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? 
mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + &scales[0], &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + primitive_desc(const std::vector &scales, + const std::vector &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, + nullptr, (int)c_api_srcs.size(), &scales[0], + &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + sum(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @} + +/// @addtogroup cpp_api_primitives Primitives +/// @{ + +/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors +/// @{ + +/// A base class for all primitive descriptors. +struct primitive_desc : public handle { + primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, + const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { + mkldnn_primitive_desc_iterator_t iterator = nullptr; + mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( + &iterator, desc, attr ? attr->get() : nullptr, e.get(), + hint_fwd_pd); + error::wrap_c_api(status, + "could not create a primitive descriptor iterator"); + pd_iterator.reset(iterator); + fetch_impl(); + } + + engine get_engine() { return engine::query(*this); } + + primitive_attr get_primitive_attr() const { + const_mkldnn_primitive_attr_t const_cattr; + error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), + "could not get attributes"); + mkldnn_primitive_attr_t cattr; + error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), + "could not clone attributes"); + + primitive_attr attr; + attr.reset(cattr); + return attr; + } + + /// Returns implementation name + const char *impl_info_str() const { + const char *res; + error::wrap_c_api(mkldnn_primitive_desc_query(get(), + mkldnn_query_impl_info_str, 0, &res), + "could not query implementation info string"); + return res; + } + + /// Queries the memory::dim value (same as int64_t) + memory::dim query_s64(query q) const { + memory::dim res; + mkldnn_status_t status = mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(q), 0, &res); + return status == mkldnn_success ? res : 0; + } + + /// Advances the next implementation for the given op descriptor. 
+ /// + /// Returns: + /// - @c true on success + /// - @c false if the last implementation reached, and + /// the primitive descriptor itself is kept unchanged + bool next_impl() { + mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( + pd_iterator.get()); + if (status == mkldnn_iterator_ends) return false; + error::wrap_c_api(status, "primitive descriptor iterator next failed"); + + fetch_impl(); + return true; + } + + /// Queries and returns requested memory descriptor. + memory::desc query_md(query what, int idx = 0) const { + std::vector valid_q{src_md, diff_src_md, weights_md, + diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), + [=](query q) { return what == q; })) + throw error(mkldnn_invalid_arguments, "invalid memory query"); + + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(what), idx); + if (cdesc == nullptr) return memory::desc(); + + return memory::desc(*cdesc); + } + + // register specialized queries, e.g. src_desc() +# define REG_QUERY_MD(name, what, idx) \ + memory::desc name ## _desc() const { return query_md(what ## _md, idx); } + + private: + handle pd_iterator; + void fetch_impl() { + mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( + pd_iterator.get()); + error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, + "could not fetch a primitive descriptor from the iterator"); + reset(pd); + } +}; + +/// @} + +/// @addtogroup cpp_api_convolution Convolution +/// A primitive to compute convolution using different algorithms. +/// +/// @sa @ref c_api_convolution in @ref c_api +/// @{ + +struct convolution_forward: public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, 
+ const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_data : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); 
+ } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_weights : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + 
memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} +// +/// @addtogroup cpp_api_deconvolution Deconvolution +/// A primitive to compute deconvolution using different algorithms. +/// +/// @sa @ref c_api_deconvolution in @ref c_api +/// @{ + +struct deconvolution_forward: public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), 
+ &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_data : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const 
engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_weights : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), 
+ "could not create a dilated deconvolution backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_lrn LRN +/// A primitive to perform local response normalization (LRN) across or within +/// channels. +/// +/// @sa @ref c_api_lrn in @ref c_api +/// @{ + +struct lrn_forward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, local_size, alpha, beta, k), + "could not create a lrn forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct lrn_backward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(algorithm aalgorithm, const memory::desc &data_desc, + const memory::desc &diff_data_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), &diff_data_desc.data, + &data_desc.data, local_size, alpha, beta, k), + "could not create a lrn backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_pooling Pooling +/// A primitive to perform max or average pooling. 
+/// +/// @sa @ref c_api_pooling in @ref c_api +/// @{ + +struct pooling_forward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims kernel, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, &dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a forward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct pooling_backward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::dims &strides, + const memory::dims &kernel, + const memory::dims &padding_l, + const memory::dims &padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_src_desc.data, &diff_dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a backward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_eltwise Eltwise +/// A primitive to compute element-wise operations like parametric rectifier +/// linear unit (ReLU). 
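A corresponding sketch for the pooling wrappers above, reusing the assumed `eng` and `src_md` from the LRN sketch; the destination descriptor follows the usual pooling output-size arithmetic.

~~~cpp
// Sketch only. Max pooling, 2x2 kernel, stride 2, zero padding.
// Output spatial size = (in + pad_l + pad_r - kernel) / stride + 1,
// so the 13x13 input gives (13 - 2) / 2 + 1 = 6, i.e. a 6x6 output.
memory::desc pool_dst_md({1, 16, 6, 6},
        memory::data_type::f32, memory::format_tag::nchw);

pooling_forward::desc pool_d(prop_kind::forward_training,
        algorithm::pooling_max, src_md, pool_dst_md,
        /*strides=*/{2, 2}, /*kernel=*/{2, 2},
        /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0}, padding_kind::zero);
pooling_forward::primitive_desc pool_pd(pool_d, eng);
// forward_training makes the workspace (max indices) available for
// pooling_backward; forward_inference would omit it.
~~~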
+/// +/// @sa @ref c_api_eltwise in @ref c_api +/// @{ + +struct eltwise_forward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, algorithm alg_kind, + const memory::desc &src_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + mkldnn::convert_to_c(alg_kind), &src_desc.data, + static_cast<float>(alpha), static_cast<float>(beta)), + "could not create a eltwise forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct eltwise_backward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + + template <typename T> + desc(algorithm alg_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, + mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, + &data_desc.data, static_cast<float>(alpha), + static_cast<float>(beta)), + "could not create a eltwise backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_softmax Softmax +/// A primitive to perform softmax.
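For the eltwise wrappers above, a plain-ReLU sketch. The commented-out `execute()` call and the `MKLDNN_ARG_*` indices are assumptions about the argument map used by `primitive::execute()` near the end of this header; `src_mem`/`dst_mem` are hypothetical memory objects.

~~~cpp
// Sketch only. alpha is the negative slope for eltwise_relu (0 = plain ReLU).
eltwise_forward::desc relu_d(prop_kind::forward_inference,
        algorithm::eltwise_relu, src_md, /*alpha=*/0.f);
eltwise_forward::primitive_desc relu_pd(relu_d, eng);
eltwise_forward relu(relu_pd);

// Execution goes through the argument map shown in the implementation
// section below; the exact MKLDNN_ARG_* indices live in mkldnn_types.h.
// stream s(eng);
// relu.execute(s, {{MKLDNN_ARG_SRC, src_mem}, {MKLDNN_ARG_DST, dst_mem}});
~~~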
+/// +/// @sa @ref c_api_softmax in @ref c_api +/// @{ + +struct softmax_forward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + softmax_axis), + "could not create a softmax forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct softmax_backward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(const memory::desc &diff_desc, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, + &diff_desc.data, &data_desc.data, softmax_axis), + "could not init a backward softmax descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_batch_norm Batch normalization +/// A primitive to perform batch normalization. 
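A softmax sketch along the same lines; axis 1 is the channel axis of the NCHW `src_md` assumed above.

~~~cpp
// Sketch only. Softmax over the channel axis of an NCHW tensor.
softmax_forward::desc sm_d(prop_kind::forward_inference, src_md,
        /*softmax_axis=*/1);
softmax_forward::primitive_desc sm_pd(sm_d, eng);
softmax_forward sm(sm_pd);
// softmax_backward, like the other backward primitives here, takes the
// forward primitive_desc as a hint when its own primitive_desc is built.
~~~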
+/// +/// @sa @ref c_api_batch_normalization in @ref c_api +/// @{ + +struct batch_normalization_forward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + static_cast<float>(epsilon), flags), + "could not create a batch normalization forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + + memory::desc mean_desc() const { return stat_desc(mean); } + memory::desc variance_desc() const { return stat_desc(var); } + + private: + enum { mean = 1, var = 2, }; + memory::desc stat_desc(int kind) const { + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api(mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); + } + }; + + batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct batch_normalization_backward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T epsilon, unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + &diff_data_desc.data, &data_desc.data, + static_cast<float>(epsilon), flags), + "could not create a batch normalization backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(mean, src, 1); + REG_QUERY_MD(variance, src, 2); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_inner_product Inner Product +/// A primitive to compute an inner product.
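For the batch-normalization wrappers above, a sketch of the inference path. `use_global_stats` is referenced by `stat_desc()` above; treating `use_scale_shift` as the C++ counterpart of `mkldnn_use_scaleshift` from mkldnn_types.h is an assumption.

~~~cpp
// Sketch only. Inference-mode batch normalization: statistics and
// scale/shift are supplied by the user rather than computed.
// The flag enumerator spellings are assumed, not shown in this hunk.
batch_normalization_forward::desc bn_d(prop_kind::forward_inference,
        src_md, /*epsilon=*/1e-5f,
        use_global_stats | use_scale_shift);
batch_normalization_forward::primitive_desc bn_pd(bn_d, eng);
// bn_pd.mean_desc() / bn_pd.variance_desc() describe the statistics
// tensors expected at execution time (see stat_desc() above).
~~~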
+/// +/// @sa @ref c_api_inner_product in @ref c_api +/// @{ + +struct inner_product_forward: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, &bias_desc.data, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, nullptr, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_data: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_data_desc_init(&data, + &diff_src_desc.data, &weights_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_weights: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + &diff_bias_desc.data, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + nullptr, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + }; + + struct primitive_desc : public 
mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_rnn RNN +/// A primitive to compute common recurrent layer. +/// +/// @sa @ref c_api_rnn in @ref c_api +/// @{ + +struct rnn_cell { + struct desc { + mkldnn_rnn_cell_desc_t c_rnn_cell_; + + desc(algorithm kind, algorithm activation_f) { + error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, + mkldnn::convert_to_c(kind), + mkldnn::convert_to_c(activation_f), 0U, 0, 0), + "could not init an rnn cell descriptor"); + } + desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} + + operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } + + algorithm get_cell_kind() const + { return algorithm(c_rnn_cell_.cell_kind); } + algorithm get_activation() const + { return algorithm(c_rnn_cell_.activation_kind); } + + float get_alpha() const { return c_rnn_cell_.alpha; } + void set_alpha(float alpha) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; + c_rnn_cell_.alpha = alpha; + } + + float get_clipping() const { return c_rnn_cell_.clipping; } + void set_clipping(float clipping) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; + c_rnn_cell_.clipping = clipping; + } + + int get_gates_count() const { + return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); + } + int get_state_count() const { + return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); + } + }; +}; + +struct rnn_forward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc + ) { + error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data), + "could not create an RNN forward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + rnn_forward(const primitive_desc 
&pd): primitive(pd) {} +}; + +struct rnn_backward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc) { + error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data, + &diff_src_layer_desc.data, &diff_src_iter_desc.data, + &diff_weights_layer_desc.data, + &diff_weights_iter_desc.data, &diff_bias_desc.data, + &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), + "could not create an RNN backward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src_layer, diff_src, 0); + REG_QUERY_MD(diff_src_iter, diff_src, 1); + REG_QUERY_MD(diff_weights_layer, diff_weights, 0); + REG_QUERY_MD(diff_weights_iter, diff_weights, 1); + REG_QUERY_MD(diff_bias, diff_weights, 2); + REG_QUERY_MD(diff_dst_layer, diff_dst, 0); + REG_QUERY_MD(diff_dst_iter, diff_dst, 1); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + // With last iteration (with and without input src_iter) + rnn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_shuffle Shuffle +/// A primitive to shuffle data along the axis. 
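Stepping back to the RNN wrappers above, a sketch of how the cell descriptor feeds the layer descriptors; `algorithm::vanilla_lstm`, `algorithm::vanilla_rnn`, and `algorithm::eltwise_tanh` are assumed to mirror the corresponding C kinds in mkldnn_types.h.

~~~cpp
// Sketch only: describing the recurrent cell that rnn_forward/rnn_backward
// are built around.
rnn_cell::desc lstm_cell(algorithm::vanilla_lstm);     // 4 gates, 2 states
rnn_cell::desc tanh_cell(algorithm::vanilla_rnn,
        algorithm::eltwise_tanh);                      // plain RNN cell
int gates = lstm_cell.get_gates_count();               // 4 for an LSTM cell

// rnn_forward::desc then combines the cell with the direction and the
// src/weights/bias/dst layer and iteration memory descriptors shown above;
// rnn_backward::desc additionally takes the corresponding diff_* descriptors
// and must use the forward primitive_desc as its hint.
~~~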
+/// +/// @sa @ref c_api_shuffle in @ref c_api +/// @{ + +struct shuffle_forward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + axis, group_size), + "could not create a shuffle forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct shuffle_backward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(const memory::desc &diff_data_desc, int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, + &diff_data_desc.data, axis, group_size), + "could not create a shuffle backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const shuffle_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @} Primitives + +/// @} C++ API + +#undef REG_QUERY_MD + +// implementation section +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { + mkldnn_primitive_t result; + error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), + "could not create a primitive"); + reset(result); +} + +inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} + +inline void primitive::execute(stream &astream, + const std::unordered_map &args) const { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a: args) + c_args.push_back({a.first, a.second.get()}); + + error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), + (int)c_args.size(), c_args.data()), + "primitive execution fail"); +} +#endif // DOXYGEN_SHOULD_SKIP_THIS + +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h new file mode 100644 index 000000000000..f4dc2fdfa6ee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#ifndef MKLDNN_DEBUG_H +#define MKLDNN_DEBUG_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v); +const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v); +const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v); +const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v); +const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v); +const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v); +const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v); +const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v); + +/** Forms a format string for a given memory descriptor. + * + * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'. + * Here: + * - dt -- data type + * - p -- indicates there is non-trivial padding + * - o -- indicates there is non-trivial padding offset + * - 0 -- indicates there is non-trivial offset0 + * - fmt_kind -- format kind (blocked, wino, etc...) + * - fmt -- extended format string (format_kind specific) + * - extra -- shows extra fields (underspecified) + */ +int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len, + const mkldnn_memory_desc_t *md); + +/** Forms a dimension string for a given memory descriptor. + * + * The format is defined as: 'dim0xdim1x...xdimN + */ +int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len, + const mkldnn_memory_desc_t *md); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h new file mode 100644 index 000000000000..1b6c35698205 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h @@ -0,0 +1,1415 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef MKLDNN_TYPES_H +#define MKLDNN_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#endif + +/** @addtogroup c_api C API + * @{ + * + * @addtogroup c_api_types Types + * @{ + * + * @addtogroup c_api_types_generic Generic + * @{ */ + +/** Intel(R) MKL-DNN Version type */ +typedef struct { + int major; + int minor; + int patch; + const char *hash; +} mkldnn_version_t; + +/** Status values returned by Intel(R) MKL-DNN functions. */ +typedef enum { + /** The operation was successful */ + mkldnn_success = 0, + /** The operation failed due to an out-of-memory condition */ + mkldnn_out_of_memory = 1, + /** The operation failed and should be retried */ + mkldnn_try_again = 2, + /** The operation failed because of incorrect function arguments */ + mkldnn_invalid_arguments = 3, + /** The operation failed because a primitive was not ready for execution */ + mkldnn_not_ready = 4, + /** The operation failed because requested functionality is not implemented + */ + mkldnn_unimplemented = 5, + /** Primitive iterator passed over last primitive descriptor */ + mkldnn_iterator_ends = 6, + /** Primitive or engine failed on execution */ + mkldnn_runtime_error = 7, + /** Queried element is not required for given primitive */ + mkldnn_not_required = 8, +} mkldnn_status_t; + +/** Data type specification */ +typedef enum { + /** Undefined data type, used for empty memory descriptors. */ + mkldnn_data_type_undef = 0, + /** 32-bit/single-precision floating point. */ + mkldnn_f32 = 1, + /** 32-bit signed integer. */ + mkldnn_s32 = 2, + /** 8-bit signed integer. */ + mkldnn_s8 = 3, + /** 8-bit unsigned integer. */ + mkldnn_u8 = 4, +} mkldnn_data_type_t; + +/** Memory format kind */ +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_format_kind_undef = 0, + /** Unspecified format. The primitive selects a format automatically. */ + mkldnn_format_kind_any, + /** A tensor in a generic format described by the stride and blocking + * values in each dimension. See #mkldnn_blocking_desc_t for more + * information. */ + mkldnn_blocked, + /** Weights format used in 8bit Winograd convolution */ + mkldnn_format_kind_wino, + /** Packed weights format used in RNN */ + mkldnn_format_kind_rnn_packed, +} mkldnn_format_kind_t; + +/** Memory format tag specification. + * + * Intel MKL-DNN formats describe physical data layout. The physical layout + * is described as a sequence of the dimensions as they are laid out in the + * memory (from the outer-most to the inner-most). Note that this order + * doesn't affect the logical order of the dimensions that is kept in the + * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the + * dimensions is specified by the type of tensor. 
+ * + * For example, CNN 5D tensor always has its logical dimensions in the order + * `(batch, channels, depth, height, width)`, while the physical layout might be + * #mkldnn_ncdhw or #mkldnn_ndhwc: + * + * ~~~cpp + * int batch = 2, channels = 16, depth = 13, height = 13, width = 13; + * + * int ndims = 5; // 5D tensor + * mkldnn_dims_t dims = {batch, channels, depth, height, width}; + * mkldnn_memory_desc_t data_in_ncdhw; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw); + * + * // note that in both cases dims passed are the same + * mkldnn_memory_desc_t data_in_ndhwc; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc); + * ~~~ + * + * The following notation applies to memory format names: + * - @c 'n' denotes the mini-batch dimension + * - @c 'c' denotes a channels dimension + * - When there are multiple channel dimensions (for example, in convolution + * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output + * channels + * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width + * respectively + * - Upper-case letters indicate that the data is laid out in blocks + * for a particular dimension. In such cases, the format name contains both + * upper- and lower-case letters for that dimension with a lower-case letter + * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a + * format where the outermost dimension is mini-batch, followed by the + * channel block number, followed by the spatial height and width, and + * finally followed by 8-element channel blocks. + * + * @note + * Channel designations can be different. For example, both the @c + * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D + * tensor. + * + * @sa @ref understanding_memory_formats + */ +typedef enum { + /** Undefined memory format tag */ + mkldnn_format_tag_undef = 0, + /** Undefined memory format tag. + * The primitive selects a format automatically. */ + mkldnn_format_tag_any, + + /* Semantic agnostic section */ + /* The physical order of dimensions is defined by the permutation of the + * characters, assuming that ab..z defines the natural order. 
+ */ + + /* Plain formats */ + + mkldnn_a, + mkldnn_ab, + mkldnn_abc, + mkldnn_abcd, + mkldnn_abcde, + mkldnn_abcdef, + mkldnn_abdec, + mkldnn_acb, + mkldnn_acbde, + mkldnn_acdb, + mkldnn_acdeb, + mkldnn_ba, + mkldnn_bac, + mkldnn_bacd, + mkldnn_bcda, + mkldnn_cba, + mkldnn_cdba, + mkldnn_cdeba, + mkldnn_decab, + + /* Opaque blocked formats */ + + mkldnn_Abc16a, + mkldnn_ABc16a16b, + mkldnn_aBc16b, + mkldnn_ABc16b16a, + mkldnn_Abc4a, + mkldnn_aBc4b, + mkldnn_ABc4b16a4b, + mkldnn_ABc4b4a, + mkldnn_ABc8a16b2a, + mkldnn_ABc8a8b, + mkldnn_aBc8b, + mkldnn_ABc8b16a2b, + mkldnn_ABc8b8a, + mkldnn_Abcd16a, + mkldnn_ABcd16a16b, + mkldnn_aBcd16b, + mkldnn_ABcd16b16a, + mkldnn_aBCd16b16c, + mkldnn_aBCd16c16b, + mkldnn_Abcd4a, + mkldnn_aBcd4b, + mkldnn_ABcd4b16a4b, + mkldnn_ABcd4b4a, + mkldnn_aBCd4c16b4c, + mkldnn_aBCd4c4b, + mkldnn_ABcd8a16b2a, + mkldnn_ABcd8a8b, + mkldnn_aBcd8b, + mkldnn_ABcd8b16a2b, + mkldnn_aBCd8b16c2b, + mkldnn_ABcd8b8a, + mkldnn_aBCd8b8c, + mkldnn_aBCd8c16b2c, + mkldnn_aBCd8c8b, + mkldnn_Abcde16a, + mkldnn_ABcde16a16b, + mkldnn_aBcde16b, + mkldnn_ABcde16b16a, + mkldnn_aBCde16b16c, + mkldnn_aBCde16c16b, + mkldnn_aBCde2c8b4c, + mkldnn_Abcde4a, + mkldnn_aBcde4b, + mkldnn_ABcde4b4a, + mkldnn_aBCde4b4c, + mkldnn_aBCde4c16b4c, + mkldnn_aBCde4c4b, + mkldnn_Abcde8a, + mkldnn_ABcde8a8b, + mkldnn_aBcde8b, + mkldnn_ABcde8b16a2b, + mkldnn_aBCde8b16c2b, + mkldnn_ABcde8b8a, + mkldnn_aBCde8b8c, + mkldnn_aBCde8c16b2c, + mkldnn_aBCde8c8b, + mkldnn_aBcdef16b, + mkldnn_aBCdef16b16c, + mkldnn_aBCdef16c16b, + mkldnn_aBcdef4b, + mkldnn_aBCdef4c4b, + mkldnn_aBCdef8b8c, + mkldnn_aBCdef8c16b2c, + mkldnn_aBCdef8c8b, + mkldnn_aBdc16b, + mkldnn_aBdc4b, + mkldnn_aBdc8b, + mkldnn_aBdec16b, + mkldnn_aBdec4b, + mkldnn_aBdec8b, + mkldnn_aBdefc16b, + mkldnn_aBdefc4b, + mkldnn_aBdefc8b, + mkldnn_Acb16a, + mkldnn_Acb4a, + mkldnn_Acb8a, + mkldnn_aCBd16b16c, + mkldnn_aCBde16b16c, + mkldnn_Acdb16a, + mkldnn_Acdb4a, + mkldnn_Acdb8a, + mkldnn_Acdeb16a, + mkldnn_Acdeb4a, + mkldnn_Acdeb8a, + mkldnn_BAc16a16b, + mkldnn_BAcd16a16b, + + /** Just a sentinel, not real memory format tag. Must be changed after new + * format tag is added. */ + mkldnn_format_tag_last, + + /* Aliases */ + + mkldnn_x = mkldnn_a, + mkldnn_nc = mkldnn_ab, + mkldnn_cn = mkldnn_ba, + mkldnn_ncw = mkldnn_abc, + mkldnn_nwc = mkldnn_acb, + mkldnn_nchw = mkldnn_abcd, + mkldnn_nhwc = mkldnn_acdb, + mkldnn_chwn = mkldnn_bcda, + mkldnn_ncdhw = mkldnn_abcde, + mkldnn_ndhwc = mkldnn_acdeb, + + mkldnn_oi = mkldnn_ab, + mkldnn_io = mkldnn_ba, + mkldnn_oiw = mkldnn_abc, + mkldnn_wio = mkldnn_cba, + mkldnn_oihw = mkldnn_abcd, + mkldnn_hwio = mkldnn_cdba, + mkldnn_ihwo = mkldnn_bcda, + mkldnn_iohw = mkldnn_bacd, + mkldnn_oidhw = mkldnn_abcde, + mkldnn_dhwio = mkldnn_cdeba, + mkldnn_goiw = mkldnn_abcd, + mkldnn_goihw = mkldnn_abcde, + mkldnn_hwigo = mkldnn_decab, + mkldnn_giohw = mkldnn_acbde, + mkldnn_goidhw = mkldnn_abcdef, + + /** 3D RNN data tensor in the format (seq_length, batch, input channels). */ + mkldnn_tnc = mkldnn_abc, + /** 3D RNN data tensor in the format (batch, seq_length, input channels). */ + mkldnn_ntc = mkldnn_bac, + /** 5D RNN states tensor in the format (num_layers, num_directions, + * num_states, batch, state channels). */ + mkldnn_ldsnc = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * input_channels, num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. 
*/ + mkldnn_ldigo = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * num_gates, output_channels, input_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgoi = mkldnn_abdec, + /** 4D RNN bias tensor in the format (num_layers, num_directions, + * num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgo = mkldnn_abcd, + + /* Opaque data types, are not to be used explicitly */ + + /* data */ + mkldnn_nCdhw16c = mkldnn_aBcde16b, + mkldnn_nCdhw4c = mkldnn_aBcde4b, + mkldnn_nCdhw8c = mkldnn_aBcde8b, + mkldnn_nChw16c = mkldnn_aBcd16b, + mkldnn_nChw4c = mkldnn_aBcd4b, + mkldnn_nChw8c = mkldnn_aBcd8b, + mkldnn_nCw16c = mkldnn_aBc16b, + mkldnn_nCw4c = mkldnn_aBc4b, + mkldnn_nCw8c = mkldnn_aBc8b, + + /* weights, 3D */ + mkldnn_IOw16o16i = mkldnn_BAc16a16b, + mkldnn_OIw16i16o = mkldnn_ABc16b16a, + mkldnn_OIw16o16i = mkldnn_ABc16a16b, + mkldnn_Oiw16o = mkldnn_Abc16a, + mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b, + mkldnn_OIw4i4o = mkldnn_ABc4b4a, + mkldnn_Oiw4o = mkldnn_Abc4a, + mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b, + mkldnn_OIw8i8o = mkldnn_ABc8b8a, + mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a, + mkldnn_OIw8o8i = mkldnn_ABc8a8b, + mkldnn_Owi16o = mkldnn_Acb16a, + mkldnn_Owi4o = mkldnn_Acb4a, + mkldnn_Owi8o = mkldnn_Acb8a, + + /* weights, 4D */ + mkldnn_IOhw16o16i = mkldnn_BAcd16a16b, + mkldnn_Ohwi16o = mkldnn_Acdb16a, + mkldnn_Ohwi4o = mkldnn_Acdb4a, + mkldnn_Ohwi8o = mkldnn_Acdb8a, + mkldnn_OIhw16i16o = mkldnn_ABcd16b16a, + mkldnn_OIhw16o16i = mkldnn_ABcd16a16b, + mkldnn_Oihw16o = mkldnn_Abcd16a, + mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b, + mkldnn_OIhw4i4o = mkldnn_ABcd4b4a, + mkldnn_Oihw4o = mkldnn_Abcd4a, + mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b, + mkldnn_OIhw8i8o = mkldnn_ABcd8b8a, + mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a, + mkldnn_OIhw8o8i = mkldnn_ABcd8a8b, + + /* weights, 5D */ + mkldnn_Odhwi16o = mkldnn_Acdeb16a, + mkldnn_Odhwi4o = mkldnn_Acdeb4a, + mkldnn_Odhwi8o = mkldnn_Acdeb8a, + mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a, + mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b, + mkldnn_Oidhw16o = mkldnn_Abcde16a, + mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a, + mkldnn_Oidhw4o = mkldnn_Abcde4a, + mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b, + mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a, + mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b, + + /* weights w/ groups, 3D */ + mkldnn_Goiw16g = mkldnn_Abcd16a, + mkldnn_gIOw16o16i = mkldnn_aCBd16b16c, + mkldnn_gOIw16i16o = mkldnn_aBCd16c16b, + mkldnn_gOIw16o16i = mkldnn_aBCd16b16c, + mkldnn_gOiw16o = mkldnn_aBcd16b, + mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c, + mkldnn_gOIw4i4o = mkldnn_aBCd4c4b, + mkldnn_gOiw4o = mkldnn_aBcd4b, + mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c, + mkldnn_gOIw8i8o = mkldnn_aBCd8c8b, + mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b, + mkldnn_gOIw8o8i = mkldnn_aBCd8b8c, + mkldnn_gOwi16o = mkldnn_aBdc16b, + mkldnn_gOwi4o = mkldnn_aBdc4b, + mkldnn_gOwi8o = mkldnn_aBdc8b, + + /* weights w/ groups, 4D */ + mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c, + mkldnn_gOhwi16o = mkldnn_aBdec16b, + mkldnn_gOhwi4o = mkldnn_aBdec4b, + mkldnn_gOhwi8o = mkldnn_aBdec8b, + mkldnn_Goihw16g = mkldnn_Abcde16a, + mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b, + mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c, + mkldnn_gOihw16o = mkldnn_aBcde16b, + mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c, + 
mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c, + mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b, + mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c, + mkldnn_gOihw4o = mkldnn_aBcde4b, + mkldnn_Goihw8g = mkldnn_Abcde8a, + mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c, + mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b, + mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b, + mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c, + + /* weights w/ groups, 6D */ + mkldnn_gOdhwi16o = mkldnn_aBdefc16b, + mkldnn_gOdhwi4o = mkldnn_aBdefc4b, + mkldnn_gOdhwi8o = mkldnn_aBdefc8b, + mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b, + mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c, + mkldnn_gOidhw16o = mkldnn_aBcdef16b, + mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b, + mkldnn_gOidhw4o = mkldnn_aBcdef4b, + mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c, + mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b, + mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c, +} mkldnn_format_tag_t; + +/** Kinds of padding. Define how to interpret the data in padding regions. */ +typedef enum { + /** The data in padding regions is zero. */ + mkldnn_padding_zero, +} mkldnn_padding_kind_t; + +/** Kinds of propagation. */ +typedef enum { + /* TODO: suggest renames */ + /** Undefined propagation type. */ + mkldnn_prop_kind_undef = 0, + /** Forward data propagation (training mode). In this mode primitives + * perform computations necessary for subsequent backward propagation. */ + mkldnn_forward_training = 64, + /** Forward data propagation (inference mode). In this mode primitives + * perform only computations that are necessary for inference and omit + * computations that are necessary only for backward propagation. */ + mkldnn_forward_inference = 96, + /** Forward data propagation (alias for @c mkldnn_forward_inference) */ + mkldnn_forward_scoring = mkldnn_forward_inference, + /** Forward data propagation (alias for @c mkldnn_forward_training) */ + mkldnn_forward = mkldnn_forward_training, + /** Backward propagation (with respect to all parameters */ + mkldnn_backward = 128, + /** Backward data propagation */ + mkldnn_backward_data = 160, + /** Backward weights propagation */ + mkldnn_backward_weights = 192, + /** Backward bias propagation */ + mkldnn_backward_bias = 193, +} mkldnn_prop_kind_t; + +/** Kinds of primitives. Used to implement a way to extend the library with new + * primitives without changing the ABI. */ +typedef enum { + /** Undefined primitive (XXX: why do we have it?). */ + mkldnn_undefined_primitive, + /** A reorder primitive.*/ + mkldnn_reorder, + /** A shuffle primitive.*/ + mkldnn_shuffle, + /** A (out-of-place) concat primitive. */ + mkldnn_concat, + /** A sum primitive. */ + mkldnn_sum, + /** A convolution primitive. */ + mkldnn_convolution, + /** A deconvolution primitive. */ + mkldnn_deconvolution, + /** An element-wise primitive. */ + mkldnn_eltwise, + /** A Softmax primitive. */ + mkldnn_softmax, + /** A pooling primitive. */ + mkldnn_pooling, + /** An LRN primitive. */ + mkldnn_lrn, + /** An batch normalization primitive. */ + mkldnn_batch_normalization, + /** An inner product primitive. */ + mkldnn_inner_product, + /** A rnn primitive. */ + mkldnn_rnn, +} mkldnn_primitive_kind_t; + +/** Kinds of algorithms. 
+typedef enum {
+    mkldnn_alg_kind_undef,
+    /** Direct convolution */
+    mkldnn_convolution_direct = 0x1,
+    /** Winograd convolution */
+    mkldnn_convolution_winograd = 0x2,
+    /** Convolution algorithm (either direct or Winograd) is chosen just in time */
+    mkldnn_convolution_auto = 0x3,
+    /** Direct deconvolution */
+    mkldnn_deconvolution_direct = 0xa,
+    /** Winograd deconvolution */
+    mkldnn_deconvolution_winograd = 0xb,
+    /** Eltwise: ReLU */
+    mkldnn_eltwise_relu = 0x1f,
+    /** Eltwise: hyperbolic tangent non-linearity (tanh) */
+    mkldnn_eltwise_tanh = 0x2f,
+    /** Eltwise: parametric exponential linear unit (elu) */
+    mkldnn_eltwise_elu = 0x3f,
+    /** Eltwise: square */
+    mkldnn_eltwise_square = 0x4f,
+    /** Eltwise: abs */
+    mkldnn_eltwise_abs = 0x5f,
+    /** Eltwise: square root */
+    mkldnn_eltwise_sqrt = 0x6f,
+    /** Eltwise: linear */
+    mkldnn_eltwise_linear = 0x7f,
+    /** Eltwise: bounded_relu */
+    mkldnn_eltwise_bounded_relu = 0x8f,
+    /** Eltwise: soft_relu */
+    mkldnn_eltwise_soft_relu = 0x9f,
+    /** Eltwise: logistic */
+    mkldnn_eltwise_logistic = 0xaf,
+    /** Max pooling */
+    mkldnn_pooling_max = 0x1ff,
+    /** Average pooling include padding */
+    mkldnn_pooling_avg_include_padding = 0x2ff,
+    /** Average pooling exclude padding */
+    mkldnn_pooling_avg_exclude_padding = 0x3ff,
+    mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
+    /** Local response normalization (LRN) across multiple channels */
+    mkldnn_lrn_across_channels = 0xaff,
+    /** LRN within a single channel */
+    mkldnn_lrn_within_channel = 0xbff,
+    /** RNN cell */
+    mkldnn_vanilla_rnn = 0x1fff,
+    /** LSTM cell */
+    mkldnn_vanilla_lstm = 0x2fff,
+    /** GRU cell */
+    mkldnn_vanilla_gru = 0x3fff,
+    /** GRU cell with linear before reset
+     *
+     * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
+     * in how the new memory gate is calculated:
+     * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f]
+     * Primitive expects 4 biases on input:
+     * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
+     * */
+    mkldnn_gru_linear_before_reset = 0x4fff,
+} mkldnn_alg_kind_t;
+
+/** Flags for batch-normalization primitive. */
+typedef enum {
+    /** Use global statistics
+     *
+     * If specified:
+     *  - on forward propagation use mean and variance provided by user (input)
+     *  - on backward propagation reduces the amount of computations, since
+     *    mean and variance are considered as constants
+     *
+     * If not specified:
+     *  - on forward propagation mean and variance are computed and stored in
+     *    output
+     *  - on backward propagation compute full derivative w.r.t. data
+     */
+    mkldnn_use_global_stats = 0x1U,
+    /** Use scale and shift parameters
+     *
+     * If specified:
+     *  - on forward propagation use scale and shift (aka scale and bias) for
+     *    the batch normalization results
+     *  - on backward propagation (for prop_kind == #mkldnn_backward) compute
+     *    diff w.r.t. scale and shift (hence one extra output used)
+     *
+     * If not specified:
+     *  - on backward propagation prop_kind == #mkldnn_backward_data has the
+     *    same behavior as prop_kind == #mkldnn_backward
+     */
+    mkldnn_use_scaleshift = 0x2U,
+    /** Fuse with ReLU
+     *
+     * If specified:
+     *  - on inference this option behaves the same as if the primitive were
+     *    fused with ReLU via post ops API
+     *  - on training primitive requires workspace (required to be able to
+     *    perform backward pass)
+     */
+    mkldnn_fuse_bn_relu = 0x4U,
+} mkldnn_batch_normalization_flag_t;
+
+/** @} */
+
+/** @addtogroup c_api_types_memory Memory
+ * @{ */
+
+/** Maximum number of dimensions a tensor can have. 
Only restricts the amount + * of space used for the tensor description. Individual computational + * primitives may support only tensors of certain dimensions. */ +#define MKLDNN_MAX_NDIMS 12 + +/** A type to describe tensor dimension. */ +typedef int64_t mkldnn_dim_t; + +/** A type to describe tensor dimensions. */ +typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS]; + +/** A type to describe strides within a tensor. */ +typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS]; + +/** Generic description of blocked data layout for most memory formats. + * + * @sa @ref understanding_memory_formats */ +typedef struct { + /** The strides between the outermost blocks. + * In case of plain (non-blocked) formats the strides between dimensions. */ + mkldnn_dims_t strides; + /* Innermost section + * ASSUMPTION: the innermost blocks are always dense */ + /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */ + int inner_nblks; + /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */ + mkldnn_dims_t inner_blks; + /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of + * `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim */ + mkldnn_dims_t inner_idxs; +} mkldnn_blocking_desc_t; + +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_wino_undef = 0, + /** Tensors of weights for 2x3 winograd convolutions. */ + mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, + /** Tensor of weights for 4x3 convolution. */ + mkldnn_wino_wei_OBaaIBOIio +} mkldnn_wino_memory_format_t; + +/** Description of tensor of weights for winograd 2x3 convolution. */ +typedef struct { + mkldnn_wino_memory_format_t wino_format; + int r; + int alpha; + int ic; + int oc; + int ic_block; + int oc_block; + int ic2_block; + int oc2_block; + float adj_scale; + size_t size; +} mkldnn_wino_desc_t; + +typedef enum { + mkldnn_packed_format_undef = 0, + mkldnn_ldigo_p, + mkldnn_ldgoi_p +} mkldnn_rnn_packed_memory_format_t; + +/* Maximum number of parts of RNN weights tensor that require separate + * computation. */ +#define MKLDNN_RNN_MAX_N_PARTS 4 + +/** Description of tensor of packed weights for rnn. */ +typedef struct { + mkldnn_rnn_packed_memory_format_t format; + int n_parts; + int n; + int parts[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + size_t offset_compensation; + size_t size; +} mkldnn_rnn_packed_desc_t; + +typedef enum { + mkldnn_memory_extra_flag_none = 0x0U, + /** Indicates the weights have an additional buffer, that depends on the + * @p compensation_mask. + * + * For instance, in 4D case with the compensation mask equals (1 << 0) + * the additional buffer would consist of OC values: + * O[oc : 0,OC] = + * -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } + */ + mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U, + mkldnn_memory_extra_flag_scale_adjust = 0x2U, +} mkldnn_memory_extra_flags_t; + +/** Description of extra information stored in memory */ +typedef struct { + /** The flags contain arbitrary extra information, such as compensation. + * @sa mkldnn_memory_extra_flags_t */ + uint64_t flags; + /** Compensation mask */ + int compensation_mask; + /** Scale applied to the data */ + float scale_adjust; + /** For future backwards compatibility */ + char reserved[64]; +} mkldnn_memory_extra_desc_t; + +/** Memory descriptor. 
The description is based on a number of dimensions, + * dimensions themselves, plus information about elements type and memory + * format. Additionally, contains format-specific descriptions of the data + * layout. */ +typedef struct { + /** Number of dimensions */ + int ndims; + /** Dimensions in the following order: + * - CNN data tensors: mini-batch, channel, spatial + * ({N, C, [[D,] H,] W}) + * - CNN weight tensors: group (optional), output channel, input channel, + * spatial ({[G,] O, I, [[D,] H,] W}) + * - RNN data tensors: time, mini-batch, channels ({T, N, C}) + * or layers, directions, states, mini-batch, channels ({L, D, S, N, C}) + * - RNN weight tensor: layers, directions, input channel, gates, output channels + * ({L, D, I, G, O}). + * + * @note + * The order of dimensions does not depend on the memory format, so + * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc + * the dims for 4D CN data tensor would be {N, C, H, W}. + */ + mkldnn_dims_t dims; + /** Data type of the tensor elements. */ + mkldnn_data_type_t data_type; + + /** Size of the data including padding in each dimension. */ + mkldnn_dims_t padded_dims; + /** Per-dimension offset from the padding to actual data, the top-level + * tensor with offsets applied must lie within the padding area. */ + mkldnn_dims_t padded_offsets; + + /** Offset from memory origin to the current block, non-zero only in + * a description of a memory sub-block. */ + mkldnn_dim_t offset0; + + /** Memory format kind. */ + mkldnn_format_kind_t format_kind; + union { + /** Description of the data layout for memory formats that use + * blocking. */ + mkldnn_blocking_desc_t blocking; + /** Tensor of weights for integer 8bit winograd convolution. */ + mkldnn_wino_desc_t wino_desc; + /** Tensor of packed weights for RNN. */ + mkldnn_rnn_packed_desc_t rnn_packed_desc; + /* ... other descriptions possible */ + } format_desc; + + mkldnn_memory_extra_desc_t extra; +} mkldnn_memory_desc_t; + +/** @struct mkldnn_memory + * An opaque structure to describe a memory. */ +struct mkldnn_memory; + +/** A memory handle. */ +typedef struct mkldnn_memory *mkldnn_memory_t; + +/** A constant memory handle. */ +typedef const struct mkldnn_memory *const_mkldnn_memory_t; + +#define MKLDNN_NATIVE_HANDLE_NONE (NULL) +#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1) + +/** @} */ + +/** @addtogroup c_api_types_op_descs Operation descriptors + * @{*/ + +/** A pointer to any of the operation descriptors. */ +typedef void *mkldnn_op_desc_t; +/** A pointer to any of the operation descriptors (constant variant). */ +typedef const void *const_mkldnn_op_desc_t; + +/** A descriptor of a convolution operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** The kind of the convolution algorithm. Possible values: + * #mkldnn_convolution_direct. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. 
*/
+    mkldnn_memory_desc_t diff_weights_desc;
+    /** Bias memory descriptor. */
+    mkldnn_memory_desc_t bias_desc;
+    /** Bias gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_bias_desc;
+    /** Destination memory descriptor. */
+    mkldnn_memory_desc_t dst_desc;
+    /** Destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_dst_desc;
+    /** Convolution strides in each spatial dimension. */
+    mkldnn_dims_t strides;
+    /** Convolution dilates in each spatial dimension. */
+    mkldnn_dims_t dilates;
+    /** Padding in each spatial dimension. padding[0] is a padding in the
+     * beginning (@p padding_l), padding[1] is a padding in the end (@p
+     * padding_r). */
+    mkldnn_dims_t padding[2];
+    /** The kind of padding to use. */
+    mkldnn_padding_kind_t padding_kind;
+    /** The accumulator data type. Initialized automatically. */
+    mkldnn_data_type_t accum_data_type;
+} mkldnn_convolution_desc_t;
+
+/** A descriptor of a deconvolution operation. */
+typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t;
+
+/** A descriptor of a shuffle operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_shuffle. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, and #mkldnn_backward_data. */
+    mkldnn_prop_kind_t prop_kind;
+    /** Source and destination memory descriptor,
+     * and source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** Axis for shuffling. */
+    int axis;
+    /** Number of groups in group convolution. */
+    mkldnn_dim_t group_size;
+} mkldnn_shuffle_desc_t;
+
+/** A descriptor of an element-wise operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_eltwise. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
+     */
+    mkldnn_prop_kind_t prop_kind;
+    /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
+     * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square,
+     * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear,
+     * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and
+     * #mkldnn_eltwise_logistic. */
+    mkldnn_alg_kind_t alg_kind;
+    /** Source and destination memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** Source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_data_desc;
+    /** Algorithm specific parameter.
+     * Accordance table:
+     *  - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
+     *  - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored
+     *  - #mkldnn_eltwise_square: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_abs: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift
+     *  - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
+     *  - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored
+     */
+    float alpha, beta;
+} mkldnn_eltwise_desc_t;
+
+/** A descriptor of a Softmax operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. 
Must be #mkldnn_softmax. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training and + * #mkldnn_forward_inference. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and Destination of gradient memory descriptor. */ + mkldnn_memory_desc_t diff_desc; + /** The axis along which to perform the softmax. */ + int softmax_axis; +} mkldnn_softmax_desc_t; + +/** A descriptor of a pooling operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_pooling. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and + * #mkldnn_pooling_avg. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Pooling kernel strides for spatial dimensions. */ + mkldnn_dims_t strides; + /** Pooling kernel spatial dimensions. */ + mkldnn_dims_t kernel; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_pooling_desc_t; + +/** A descriptor of a Local Response Normalization (LRN) operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_lrn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and + * #mkldnn_lrn_across_channels. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** The number of channels to sum over (for cross-channel LRN) or the side + * length of the square region to sum over (for within-channel LRN). */ + mkldnn_dim_t local_size; + /** LRN alpha parameter. */ + float lrn_alpha; + /** LRN beta parameter. */ + float lrn_beta; + /** LRN k parameter. */ + float lrn_k; +} mkldnn_lrn_desc_t; + +/** A descriptor of a Batch Normalization operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_batch_normalization. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. 
*/ + mkldnn_memory_desc_t diff_data_desc; + /** Scale and shift data and gradient memory descriptors. + * + * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st + * dimension contains gamma parameter, 2-nd dimension contains beta + * parameter. */ + mkldnn_memory_desc_t data_scaleshift_desc; + mkldnn_memory_desc_t diff_data_scaleshift_desc; + /** Mean and variance data memory descriptors. + * + * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels]. + */ + mkldnn_memory_desc_t mean_desc; + mkldnn_memory_desc_t variance_desc; + /** Batch normalization epsilon parameter. */ + float batch_norm_epsilon; + unsigned flags; +} mkldnn_batch_normalization_desc_t; + +/** A descriptor of an inner product operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_inner_product. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_inner_product_desc_t; + +/** Flags for RNN cell. */ +typedef enum { + mkldnn_rnn_cell_with_relu = 0x1U, + mkldnn_rnn_cell_with_clipping = 0x2U, +} mkldnn_rnn_cell_flags_t; + +typedef struct { + /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn, + * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, + * or #mkldnn_gru_linear_before_reset. */ + mkldnn_alg_kind_t cell_kind; + /** Activation function used. Must be either #mkldnn_eltwise_relu or + * #mkldnn_eltwise_tanh. */ + mkldnn_alg_kind_t activation_kind; + /** RNN cell flags */ + unsigned int flags; + /** @c alpha is a negative slope parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */ + float alpha; + /** clipping parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */ + float clipping; +} mkldnn_rnn_cell_desc_t; + +/** A direction of RNN primitive execution. */ +typedef enum { + /* Unidirectional execution of RNN primitive from left to right. */ + mkldnn_unidirectional_left2right, + /* Unidirectional execution of RNN primitive from right to left. */ + mkldnn_unidirectional_right2left, + /* Bidirectional execution of RNN primitive with concatenation of the + * results. */ + mkldnn_bidirectional_concat, + /* Bidirectional execution of RNN primitive with summation of the + * results. */ + mkldnn_bidirectional_sum, + mkldnn_unidirectional = mkldnn_unidirectional_left2right, +} mkldnn_rnn_direction_t; + +/** A descriptor for an RNN operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_rnn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. 
Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward. */ + mkldnn_prop_kind_t prop_kind; + /** The RNN cell desc. */ + mkldnn_rnn_cell_desc_t cell_desc; + /** The direction of RNN primitive execution. */ + mkldnn_rnn_direction_t direction; + /** Source layer memory descriptor. */ + mkldnn_memory_desc_t src_layer_desc; + /** Source iteration memory descriptor. */ + mkldnn_memory_desc_t src_iter_desc; + /** Weights layer memory descriptor. */ + mkldnn_memory_desc_t weights_layer_desc; + /** Weights iteration memory descriptor. */ + mkldnn_memory_desc_t weights_iter_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Destination layer memory descriptor. */ + mkldnn_memory_desc_t dst_layer_desc; + /** Destination iter memory descriptor. */ + mkldnn_memory_desc_t dst_iter_desc; + /** Source gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_src_layer_desc; + /** Source gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_src_iter_desc; + /** Weights gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_weights_layer_desc; + /** Weights gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_weights_iter_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_dst_layer_desc; + /** Destination gradient iteration memory descriptor. */ + mkldnn_memory_desc_t diff_dst_iter_desc; +} mkldnn_rnn_desc_t; + +/** @} */ + +/** @addtogroup c_api_engine_types Engine + * @{ */ + +/** @brief Kinds of engines. */ +typedef enum { + /** An unspecified engine. */ + mkldnn_any_engine, + /** CPU engine. */ + mkldnn_cpu, +} mkldnn_engine_kind_t; + +/** @struct mkldnn_engine + * @brief An opaque structure to describe an engine. */ +struct mkldnn_engine; +/** @brief An engine handle. */ +typedef struct mkldnn_engine *mkldnn_engine_t; +#if 0 +/* FIXME: looks like this never happens */ +/** @brief A constant engine handle. */ +typedef const struct mkldnn_engine *const_mkldnn_engine_t; +#endif + +/** @} */ + +/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators + * @{ */ + +/** @struct mkldnn_primitive_desc_iterator + * @brief An opaque structure to describe a primitive descriptor iterator. */ +struct mkldnn_primitive_desc_iterator; + +/** @brief A primitive descriptor iterator handle. */ +typedef struct mkldnn_primitive_desc_iterator + *mkldnn_primitive_desc_iterator_t; + +/** @brief A constant primitive descriptor iterator handle. */ +typedef const struct mkldnn_primitive_desc_iterator + *const_mkldnn_primitive_desc_iterator_t; + +/** @} */ + +/** @addtogroup c_api_primitive_descs Primitive descriptors + * @{ */ + +/** @struct mkldnn_primitive_desc + * @brief An opaque structure to describe a primitive descriptor. */ +struct mkldnn_primitive_desc; + +/** @brief A primitive descriptor handle. */ +typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t; + +/** @brief A constant primitive descriptor handle. 
*/ +typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t; + +/** @} */ + +/** @addtogroup c_api_primitive_attr Primitive descriptor attributes + * @{ */ + +/** Scratchpad mode */ +typedef enum { + /** The library manages scratchpad (default) */ + mkldnn_scratchpad_mode_library, + /** A user shall query and provide the scratchpad memory to primitives */ + mkldnn_scratchpad_mode_user, +} mkldnn_scratchpad_mode_t; + +/** @struct mkldnn_primitive_attr + * @brief An opaque structure for primitive descriptor attributes. + * + * Attributes may contain: + * - output scales (to scale the result prior to storing it to the memory) + */ +struct mkldnn_primitive_attr; + +/** @brief A primitive descriptor attributes handle that controls primitive + * behavior. */ +typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t; + +/** @brief A constant primitive descriptor attributes handle. */ +typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t; + +/** @struct mkldnn_post_ops + * @brief An opaque structure for a chain of post operations. + * + * mkldnn_post_ops can be used to perform some (trivial) operations like + * accumulation or eltwise after certain primitives like convolution. + * + * Post operations might be combined together, making a chain of post + * operations. For instance one can configure convolution followed by + * accumulation followed by eltwise. This might be especially beneficial + * for residual learning blocks. + * + * @warning + * Of course not all combinations are supported, so the user should handle + * errors accordingly. + * + * Supported post operations: + * - accumulation (base primitive: convolution) + * - eltwise (base primitive: convolution) + */ +struct mkldnn_post_ops; + +/** @brief A post operation chain handle. */ +typedef struct mkldnn_post_ops *mkldnn_post_ops_t; + +/** @brief A constant post operation chain handle. */ +typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t; + +/** @} */ + +/** @addtogroup c_api_types_primitive Primitive + * @{ */ + +/** @struct mkldnn_primitive + * An opaque structure to describe a primitive. */ +struct mkldnn_primitive; +/** A primitive handle. */ +typedef struct mkldnn_primitive *mkldnn_primitive_t; +/** A constant primitive handle. 
*/ +typedef const struct mkldnn_primitive *const_mkldnn_primitive_t; + +/** @addtogroup c_api_types_arguments Argument indices + * @{ */ + +#define MKLDNN_ARG_SRC_0 1 +#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0 + +#define MKLDNN_ARG_SRC_1 2 +#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1 + +#define MKLDNN_ARG_DST_0 17 +#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0 + +#define MKLDNN_ARG_DST_1 18 +#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1 + +#define MKLDNN_ARG_WEIGHTS_0 33 +#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0 + +#define MKLDNN_ARG_WEIGHTS_1 34 +#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1 + +#define MKLDNN_ARG_BIAS 41 + +#define MKLDNN_ARG_MEAN 49 +#define MKLDNN_ARG_VARIANCE 50 + +#define MKLDNN_ARG_WORKSPACE 64 +#define MKLDNN_ARG_SCRATCHPAD 80 + +#define MKLDNN_ARG_DIFF_SRC_0 129 +#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0 +#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0 + +#define MKLDNN_ARG_DIFF_SRC_1 130 +#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1 + +#define MKLDNN_ARG_DIFF_DST_0 145 +#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0 +#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0 + +#define MKLDNN_ARG_DIFF_DST_1 146 +#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1 + +#define MKLDNN_ARG_DIFF_WEIGHTS_0 161 +#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0 + +#define MKLDNN_ARG_DIFF_WEIGHTS_1 162 +#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1 + +#define MKLDNN_ARG_DIFF_BIAS 169 + +#define MKLDNN_ARG_MULTIPLE_SRC 1024 +#define MKLDNN_ARG_MULTIPLE_DST 2048 + +/** @} */ + +/** An auxiliary structure to specify primitive's inputs/outputs at execution + * + * @warning + * With this API it's impossible to preserve constness of memory, so all + * memories are passed w/o const qualifier. However only memories with + * output semantics might be changed during the execution */ +typedef struct { + int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */ + mkldnn_memory_t memory; /**< Input/output memory */ +} mkldnn_exec_arg_t; + +/** @} */ + +/** @addtogroup c_api_types_query Queries + * @{ */ + +/** Primitive descriptor query specification + * + * For generic function mkldnn_primitive_desc_query(), the type of result must + * agree with the queried argument. The correspondence table: + * Query | type of result + * -------------------------------------------------------------- + * #mkldnn_query_engine | mkldnn_engine_t * + * #mkldnn_query_scratchpad_engine | mkldnn_engine_t * + * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t * + * *_s32 | int * + * *_s64 | mkldnn_dim_t * (same as int64_t *) + * *_f64 | double * + * *_str | const char ** + * #mkldnn_query_op_d | const_mkldnn_op_desc_t * + * *_md | const mkldnn_memory_desc_t ** + * *_${op}_d | const mkldnn_${op}_desc_t ** + * *_pd | const_mkldnn_primitive_desc_t * + * + * @note + * Rule of thumb: all opaque types and structures are returned by + * reference. All numbers are returned by value. + * + * @warning + * All returned references point to constant objects and are valid only + * during the lifetime of the queried primitive descriptor. 
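+ *
+ * For illustration, a minimal usage sketch of such a query, assuming @c pd
+ * is a valid const_mkldnn_primitive_desc_t (this snippet is only a sketch
+ * following the correspondence table above, not a normative part of the
+ * reference):
+ *
+ *     const mkldnn_memory_desc_t *src_md;
+ *     mkldnn_primitive_desc_query(pd, mkldnn_query_src_md, 0, &src_md);
+ *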
Returned objects + * must not be destroyed by the user. If you need to keep the object longer + * than the lifetime of the queried primitive descriptor, use + * mkldnn_primitive_desc_clone() to make a copy. */ +typedef enum { + mkldnn_query_undef = 0, /**< no query */ + + mkldnn_query_engine, /**< execution engine */ + mkldnn_query_primitive_kind, /**< primitive kind */ + + mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */ + mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */ + + mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */ + mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra + (scratch) memory, additional to all + inputs and outputs memory (bytes) */ + + mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used + for creating scratchpad memory */ + + mkldnn_query_impl_info_str, /**< implementation name */ + + /* memory and op descriptor section */ + mkldnn_query_some_d = 64, /**< stub */ + mkldnn_query_op_d, /**< op descriptor */ + mkldnn_query_convolution_d, /**< convolution descriptor */ + mkldnn_query_deconvolution_d, /**< deconvolution descriptor */ + mkldnn_query_shuffle_d, /**< shuffle descriptor */ + mkldnn_query_eltwise_d, /**< eltwise descriptor */ + mkldnn_query_softmax_d, /**< softmax descriptor */ + mkldnn_query_pooling_d, /**< pooling descriptor */ + mkldnn_query_lrn_d, /**< lrn descriptor */ + mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */ + mkldnn_query_inner_product_d, /**< inner product descriptor */ + mkldnn_query_rnn_d, /**< rnn descriptor */ + + /* memory descriptor section */ + mkldnn_query_some_md = 128, /**< stub */ + mkldnn_query_src_md, /**< source memory desc */ + mkldnn_query_diff_src_md, /**< source gradient memory desc */ + mkldnn_query_weights_md, /**< weights memory descriptor desc */ + mkldnn_query_diff_weights_md, /**< weights grad. memory desc */ + mkldnn_query_dst_md, /**< destination memory desc */ + mkldnn_query_diff_dst_md, /**< destination grad. memory desc */ + mkldnn_query_workspace_md, /**< workspace memory desc */ + mkldnn_query_scratchpad_md, /**< scratchpad memory desc */ +} mkldnn_query_t; + +/** @} */ + +/** @addtogroup c_api_types_stream Execution stream + * @{ */ + +/** @brief Stream flags. */ +typedef enum { + /** A default stream configuration. */ + mkldnn_stream_default_flags = 0x0U, +} mkldnn_stream_flags_t; + +/** @struct mkldnn_stream + * An opaque structure to describe an execution stream. */ +struct mkldnn_stream; +/** An execution stream handle. */ +typedef struct mkldnn_stream *mkldnn_stream_t; +/** A constant execution stream handle. */ +typedef const struct mkldnn_stream *const_mkldnn_stream_t; + +/** @} */ +/** @} */ +/** @} */ + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h new file mode 100644 index 000000000000..a2713deccb3c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR 0 + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR 90 + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH 0 + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in new file mode 100644 index 000000000000..5ee0126188ba --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@ + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@ + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@ + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp new file mode 100644 index 000000000000..1a51d8562bfa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) { + bool args_ok = true + && !any_null(bnrm_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data, backward) + && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto bd = batch_normalization_desc_t(); + bd.primitive_kind = primitive_kind::batch_normalization; + bd.prop_kind = prop_kind; + + bd.data_desc = *data_desc; + bd.diff_data_desc = zero_md(); + if ( one_of(bd.prop_kind,backward_data, backward) ) + bd.diff_data_desc = *diff_data_desc; + + dims_t scaleshift_dims = { 2, data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2, + scaleshift_dims, data_type::f32, mkldnn_nc); + bd.diff_data_scaleshift_desc = zero_md(); + if (bd.prop_kind == backward) { + bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc; + } + + dims_t stats_dims = { data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims, + data_type::f32, mkldnn_x); + bd.variance_desc = bd.mean_desc; + bd.batch_norm_epsilon = epsilon; + + unsigned bnorm_flags = + mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu; + if ((~bnorm_flags & flags) != 0) return invalid_arguments; + + bd.flags = flags; + + bool consistency = true + && utils::one_of(bd.data_desc.ndims, 2, 4, 5); + if (bd.prop_kind == backward_data) + consistency = consistency + && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5) + && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims, + bd.diff_data_desc.ndims); + if (!consistency) return invalid_arguments; + + *bnrm_desc = bd; + return success; +} +} + +status_t mkldnn_batch_normalization_forward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, float epsilon, unsigned flags) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr, + epsilon, flags); +} + +status_t mkldnn_batch_normalization_backward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc, + float epsilon, unsigned flags) { + if (!one_of(prop_kind, backward, backward_data)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc, + epsilon, flags); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp new file mode 100644 index 000000000000..f61410b33c01 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp @@ -0,0 +1,240 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); 
+* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef BATCH_NORMALIZATION_PD_HPP +#define BATCH_NORMALIZATION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct batch_normalization_fwd_pd_t; + +struct batch_normalization_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::batch_normalization; + + batch_normalization_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , stat_md_(desc_.mean_desc) + , scaleshift_md_(desc_.data_scaleshift_desc) + , ws_md_() + {} + + const batch_normalization_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::batch_normalization_d: + *(const batch_normalization_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common batch_normalization aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return desc_.data_desc.ndims; } + + bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; } + bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; } + bool use_global_stats() const + { return desc_.flags & mkldnn_use_global_stats; } + bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; } + bool with_relu_post_op() const { + const auto &p = this->attr()->post_ops_; + return p.len_ == 1 && p.entry_[0].is_relu(true, true); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + bool is_bwd() const { return !this->is_fwd(); } + bool is_training() const + { return desc_.prop_kind == prop_kind::forward_training; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + batch_normalization_desc_t desc_; + const batch_normalization_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t stat_md_; + memory_desc_t scaleshift_md_; + + memory_desc_t ws_md_; + + void init_default_ws(size_t bits_per_element) { + const auto data_mdw = memory_desc_wrapper(data_md_); + + const dim_t data_nelems = data_mdw.nelems(true); + const dim_t bits_per_byte = 8; + const dims_t ws_sz = { (dim_t)utils::div_up( + data_nelems * bits_per_element, bits_per_byte) }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8, + format_tag::x); + } + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_fwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_fwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input; + if (arg == MKLDNN_ARG_DST) return arg_usage_t::output; + + if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) { + if (stats_is_src()) return arg_usage_t::input; + if (!stats_is_src() && is_training()) return arg_usage_t::output; + return arg_usage_t::unused; + } + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (!stats_is_src() && is_training() && (index == 1 || index == 2)) + return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const + { return stats_is_src() ? 
src_md(1) : dst_md(1); } + + virtual int n_inputs() const override + { return 1 + 2 * stats_is_src() + use_scaleshift(); } + virtual int n_outputs() const override + { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); } +}; + +struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_bwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_bwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN, + MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override + { return index == 0 ? &diff_scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const { return src_md(1); } + + virtual int n_inputs() const override + { return 4 + use_scaleshift() + fuse_bn_relu(); } + virtual int n_outputs() const override + { return 1 + (desc_.prop_kind == prop_kind::backward); } + +protected: + memory_desc_t diff_data_md_; + memory_desc_t diff_scaleshift_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp new file mode 100644 index 000000000000..3d43a0fbee74 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp @@ -0,0 +1,550 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef TYPE_MAPPING_HPP +#define TYPE_MAPPING_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { + +// TODO: autogenerate this + +using dim_t = mkldnn_dim_t; +using dims_t = mkldnn_dims_t; +using stride_t = mkldnn_dim_t; +using strides_t = mkldnn_strides_t; + +using status_t = mkldnn_status_t; +namespace status { + const status_t success = mkldnn_success; + const status_t out_of_memory = mkldnn_out_of_memory; + const status_t try_again = mkldnn_try_again; + const status_t invalid_arguments = mkldnn_invalid_arguments; + const status_t not_ready = mkldnn_not_ready; + const status_t unimplemented = mkldnn_unimplemented; + const status_t iterator_ends = mkldnn_iterator_ends; + const status_t runtime_error = mkldnn_runtime_error; + const status_t not_required = mkldnn_not_required; +} + +using prop_kind_t = mkldnn_prop_kind_t; +namespace prop_kind { + const prop_kind_t undef = mkldnn_prop_kind_undef; + const prop_kind_t forward_training = mkldnn_forward_training; + const prop_kind_t forward_inference = mkldnn_forward_inference; + const prop_kind_t forward_scoring = mkldnn_forward_scoring; + const prop_kind_t forward = mkldnn_forward; + const prop_kind_t backward = mkldnn_backward; + const prop_kind_t backward_data = mkldnn_backward_data; + const prop_kind_t backward_weights = mkldnn_backward_weights; + const prop_kind_t backward_bias = mkldnn_backward_bias; +} + +using alg_kind_t = mkldnn_alg_kind_t; +namespace alg_kind { + const alg_kind_t undef = mkldnn_alg_kind_undef; + const alg_kind_t convolution_auto = mkldnn_convolution_auto; + const alg_kind_t convolution_direct = mkldnn_convolution_direct; + const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; + const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct; + const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd; + const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; + const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh; + const alg_kind_t eltwise_elu = mkldnn_eltwise_elu; + const alg_kind_t eltwise_square = mkldnn_eltwise_square; + const alg_kind_t eltwise_abs = mkldnn_eltwise_abs; + const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt; + const alg_kind_t eltwise_linear = mkldnn_eltwise_linear; + const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu; + const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu; + const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic; + const alg_kind_t pooling_max = mkldnn_pooling_max; + const alg_kind_t pooling_avg = mkldnn_pooling_avg; + const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding; + const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; + const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; + const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; + const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; + const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; + const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; + const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset; +} + +using data_type_t = mkldnn_data_type_t; +namespace data_type { + const data_type_t undef = mkldnn_data_type_undef; + const data_type_t f32 = mkldnn_f32; + const data_type_t s32 = mkldnn_s32; + const data_type_t s8 = mkldnn_s8; + const data_type_t u8 = mkldnn_u8; +} + +using scratchpad_mode_t = mkldnn_scratchpad_mode_t; +namespace scratchpad_mode { + const 
scratchpad_mode_t library = mkldnn_scratchpad_mode_library; + const scratchpad_mode_t user = mkldnn_scratchpad_mode_user; +} + +using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t; +namespace rnn_packed_format { + const rnn_packed_format_t undef = mkldnn_packed_format_undef; + const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p; + const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p; +} + +using format_kind_t = mkldnn_format_kind_t; +namespace format_kind { + const format_kind_t undef = mkldnn_format_kind_undef; + const format_kind_t any = mkldnn_format_kind_any; + const format_kind_t blocked = mkldnn_blocked; + const format_kind_t wino = mkldnn_format_kind_wino; + const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed; +} + +using format_tag_t = mkldnn_format_tag_t; +namespace format_tag { + const format_tag_t undef = mkldnn_format_tag_undef; + const format_tag_t any = mkldnn_format_tag_any; + const format_tag_t a = mkldnn_a; + const format_tag_t ab = mkldnn_ab; + const format_tag_t abc = mkldnn_abc; + const format_tag_t abcd = mkldnn_abcd; + const format_tag_t abcde = mkldnn_abcde; + const format_tag_t abcdef = mkldnn_abcdef; + const format_tag_t abdec = mkldnn_abdec; + const format_tag_t acb = mkldnn_acb; + const format_tag_t acbde = mkldnn_acbde; + const format_tag_t acdb = mkldnn_acdb; + const format_tag_t acdeb = mkldnn_acdeb; + const format_tag_t ba = mkldnn_ba; + const format_tag_t bac = mkldnn_bac; + const format_tag_t bacd = mkldnn_bacd; + const format_tag_t bcda = mkldnn_bcda; + const format_tag_t cba = mkldnn_cba; + const format_tag_t cdba = mkldnn_cdba; + const format_tag_t cdeba = mkldnn_cdeba; + const format_tag_t decab = mkldnn_decab; + const format_tag_t Abc16a = mkldnn_Abc16a; + const format_tag_t ABc16a16b = mkldnn_ABc16a16b; + const format_tag_t aBc16b = mkldnn_aBc16b; + const format_tag_t ABc16b16a = mkldnn_ABc16b16a; + const format_tag_t Abc4a = mkldnn_Abc4a; + const format_tag_t aBc4b = mkldnn_aBc4b; + const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b; + const format_tag_t ABc4b4a = mkldnn_ABc4b4a; + const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a; + const format_tag_t ABc8a8b = mkldnn_ABc8a8b; + const format_tag_t aBc8b = mkldnn_aBc8b; + const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b; + const format_tag_t ABc8b8a = mkldnn_ABc8b8a; + const format_tag_t Abcd16a = mkldnn_Abcd16a; + const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b; + const format_tag_t aBcd16b = mkldnn_aBcd16b; + const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a; + const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c; + const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b; + const format_tag_t Abcd4a = mkldnn_Abcd4a; + const format_tag_t aBcd4b = mkldnn_aBcd4b; + const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b; + const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a; + const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c; + const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b; + const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a; + const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b; + const format_tag_t aBcd8b = mkldnn_aBcd8b; + const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b; + const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b; + const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a; + const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c; + const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c; + const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b; + const format_tag_t Abcde16a = mkldnn_Abcde16a; + const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b; + const format_tag_t aBcde16b = mkldnn_aBcde16b; + const format_tag_t 
ABcde16b16a = mkldnn_ABcde16b16a; + const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c; + const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b; + const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c; + const format_tag_t Abcde4a = mkldnn_Abcde4a; + const format_tag_t aBcde4b = mkldnn_aBcde4b; + const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a; + const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c; + const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c; + const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b; + const format_tag_t Abcde8a = mkldnn_Abcde8a; + const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b; + const format_tag_t aBcde8b = mkldnn_aBcde8b; + const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b; + const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b; + const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a; + const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c; + const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c; + const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b; + const format_tag_t aBcdef16b = mkldnn_aBcdef16b; + const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c; + const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b; + const format_tag_t aBcdef4b = mkldnn_aBcdef4b; + const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b; + const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c; + const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c; + const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b; + const format_tag_t aBdc16b = mkldnn_aBdc16b; + const format_tag_t aBdc4b = mkldnn_aBdc4b; + const format_tag_t aBdc8b = mkldnn_aBdc8b; + const format_tag_t aBdec16b = mkldnn_aBdec16b; + const format_tag_t aBdec4b = mkldnn_aBdec4b; + const format_tag_t aBdec8b = mkldnn_aBdec8b; + const format_tag_t aBdefc16b = mkldnn_aBdefc16b; + const format_tag_t aBdefc4b = mkldnn_aBdefc4b; + const format_tag_t aBdefc8b = mkldnn_aBdefc8b; + const format_tag_t Acb16a = mkldnn_Acb16a; + const format_tag_t Acb4a = mkldnn_Acb4a; + const format_tag_t Acb8a = mkldnn_Acb8a; + const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c; + const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c; + const format_tag_t Acdb16a = mkldnn_Acdb16a; + const format_tag_t Acdb4a = mkldnn_Acdb4a; + const format_tag_t Acdb8a = mkldnn_Acdb8a; + const format_tag_t Acdeb16a = mkldnn_Acdeb16a; + const format_tag_t Acdeb4a = mkldnn_Acdeb4a; + const format_tag_t Acdeb8a = mkldnn_Acdeb8a; + const format_tag_t BAc16a16b = mkldnn_BAc16a16b; + const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b; + const format_tag_t last = mkldnn_format_tag_last; + + const format_tag_t x = mkldnn_x; + const format_tag_t nc = mkldnn_nc; + const format_tag_t cn = mkldnn_cn; + const format_tag_t ncw = mkldnn_ncw; + const format_tag_t nwc = mkldnn_nwc; + const format_tag_t nchw = mkldnn_nchw; + const format_tag_t nhwc = mkldnn_nhwc; + const format_tag_t chwn = mkldnn_chwn; + const format_tag_t ncdhw = mkldnn_ncdhw; + const format_tag_t ndhwc = mkldnn_ndhwc; + const format_tag_t oi = mkldnn_oi; + const format_tag_t io = mkldnn_io; + const format_tag_t oiw = mkldnn_oiw; + const format_tag_t wio = mkldnn_wio; + const format_tag_t oihw = mkldnn_oihw; + const format_tag_t hwio = mkldnn_hwio; + const format_tag_t ihwo = mkldnn_ihwo; + const format_tag_t iohw = mkldnn_iohw; + const format_tag_t oidhw = mkldnn_oidhw; + const format_tag_t dhwio = mkldnn_dhwio; + const format_tag_t goiw = mkldnn_goiw; + const format_tag_t goihw = mkldnn_goihw; + const format_tag_t hwigo = mkldnn_hwigo; + const format_tag_t giohw = mkldnn_giohw; + const format_tag_t goidhw = mkldnn_goidhw; + const format_tag_t tnc = 
mkldnn_tnc; + const format_tag_t ntc = mkldnn_ntc; + const format_tag_t ldsnc = mkldnn_ldsnc; + const format_tag_t ldigo = mkldnn_ldigo; + const format_tag_t ldgoi = mkldnn_ldgoi; + const format_tag_t ldgo = mkldnn_ldgo; + const format_tag_t nCdhw16c = mkldnn_nCdhw16c; + const format_tag_t nCdhw4c = mkldnn_nCdhw4c; + const format_tag_t nCdhw8c = mkldnn_nCdhw8c; + const format_tag_t nChw16c = mkldnn_nChw16c; + const format_tag_t nChw4c = mkldnn_nChw4c; + const format_tag_t nChw8c = mkldnn_nChw8c; + const format_tag_t nCw16c = mkldnn_nCw16c; + const format_tag_t nCw4c = mkldnn_nCw4c; + const format_tag_t nCw8c = mkldnn_nCw8c; + const format_tag_t IOw16o16i = mkldnn_IOw16o16i; + const format_tag_t OIw16i16o = mkldnn_OIw16i16o; + const format_tag_t OIw16o16i = mkldnn_OIw16o16i; + const format_tag_t Oiw16o = mkldnn_Oiw16o; + const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i; + const format_tag_t OIw4i4o = mkldnn_OIw4i4o; + const format_tag_t Oiw4o = mkldnn_Oiw4o; + const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i; + const format_tag_t OIw8i8o = mkldnn_OIw8i8o; + const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o; + const format_tag_t OIw8o8i = mkldnn_OIw8o8i; + const format_tag_t Owi16o = mkldnn_Owi16o; + const format_tag_t Owi4o = mkldnn_Owi4o; + const format_tag_t Owi8o = mkldnn_Owi8o; + const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i; + const format_tag_t Ohwi16o = mkldnn_Ohwi16o; + const format_tag_t Ohwi4o = mkldnn_Ohwi4o; + const format_tag_t Ohwi8o = mkldnn_Ohwi8o; + const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o; + const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i; + const format_tag_t Oihw16o = mkldnn_Oihw16o; + const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i; + const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o; + const format_tag_t Oihw4o = mkldnn_Oihw4o; + const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i; + const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o; + const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o; + const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i; + const format_tag_t Odhwi16o = mkldnn_Odhwi16o; + const format_tag_t Odhwi4o = mkldnn_Odhwi4o; + const format_tag_t Odhwi8o = mkldnn_Odhwi8o; + const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o; + const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i; + const format_tag_t Oidhw16o = mkldnn_Oidhw16o; + const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o; + const format_tag_t Oidhw4o = mkldnn_Oidhw4o; + const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i; + const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o; + const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i; + const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i; + const format_tag_t Goiw16g = mkldnn_Goiw16g; + const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o; + const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i; + const format_tag_t gOiw16o = mkldnn_gOiw16o; + const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i; + const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o; + const format_tag_t gOiw4o = mkldnn_gOiw4o; + const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i; + const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o; + const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; + const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i; + const format_tag_t gOwi16o = mkldnn_gOwi16o; + const format_tag_t gOwi4o = mkldnn_gOwi4o; + const format_tag_t gOwi8o = mkldnn_gOwi8o; + const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i; + const format_tag_t gOhwi16o = mkldnn_gOhwi16o; + const format_tag_t gOhwi4o = mkldnn_gOhwi4o; + const format_tag_t gOhwi8o = mkldnn_gOhwi8o; + const format_tag_t Goihw16g = 
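// Aside (illustration only, not a line of this patch): the blocked tags above,
// e.g. nChw16c, split the channel dimension into blocks of 16, placing the
// block index outermost and the 16-wide remainder innermost. A hand-written
// offset computation for that layout, as a sketch:
static inline size_t nChw16c_offset(size_t n, size_t c, size_t h, size_t w,
        size_t C, size_t H, size_t W) {
    const size_t blk = 16;
    const size_t Cb = (C + blk - 1) / blk; // number of 16-channel blocks
    return (((n * Cb + c / blk) * H + h) * W + w) * blk + c % blk;
}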
mkldnn_Goihw16g; + const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o; + const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i; + const format_tag_t gOihw16o = mkldnn_gOihw16o; + const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i; + const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i; + const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o; + const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i; + const format_tag_t gOihw4o = mkldnn_gOihw4o; + const format_tag_t Goihw8g = mkldnn_Goihw8g; + const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i; + const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o; + const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o; + const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i; + const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o; + const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o; + const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o; + const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o; + const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i; + const format_tag_t gOidhw16o = mkldnn_gOidhw16o; + const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o; + const format_tag_t gOidhw4o = mkldnn_gOidhw4o; + const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i; + const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o; + const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i; +} + +using memory_extra_flags_t = mkldnn_memory_extra_flags_t; +namespace memory_extra_flags { + const memory_extra_flags_t none = mkldnn_memory_extra_flag_none; + const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8; + const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust; +} + +using padding_kind_t = mkldnn_padding_kind_t; +namespace padding_kind { + const padding_kind_t padding_zero = mkldnn_padding_zero; +} + +using engine_kind_t = mkldnn_engine_kind_t; +namespace engine_kind { + const engine_kind_t any_engine = mkldnn_any_engine; + const engine_kind_t cpu = mkldnn_cpu; +} + +using primitive_kind_t = mkldnn_primitive_kind_t; +namespace primitive_kind { + const primitive_kind_t undefined = mkldnn_undefined_primitive; + const primitive_kind_t reorder = mkldnn_reorder; + const primitive_kind_t concat = mkldnn_concat; + const primitive_kind_t sum = mkldnn_sum; + const primitive_kind_t convolution = mkldnn_convolution; + const primitive_kind_t deconvolution = mkldnn_deconvolution; + const primitive_kind_t shuffle = mkldnn_shuffle; + const primitive_kind_t eltwise = mkldnn_eltwise; + const primitive_kind_t softmax = mkldnn_softmax; + const primitive_kind_t pooling = mkldnn_pooling; + const primitive_kind_t lrn = mkldnn_lrn; + const primitive_kind_t batch_normalization = mkldnn_batch_normalization; + const primitive_kind_t inner_product = mkldnn_inner_product; + const primitive_kind_t rnn = mkldnn_rnn; +} + +using query_t = mkldnn_query_t; +namespace query { + const query_t undef = mkldnn_query_undef; + + const query_t engine = mkldnn_query_engine; + const query_t primitive_kind = mkldnn_query_primitive_kind; + + const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32; + const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32; + + const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64; + const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64; + + const query_t scratchpad_engine = mkldnn_query_scratchpad_engine; + + const query_t impl_info_str = mkldnn_query_impl_info_str; + + const query_t some_d = mkldnn_query_some_d; + const query_t op_d = mkldnn_query_op_d; + const query_t convolution_d = 
mkldnn_query_convolution_d; + const query_t deconvolution_d = mkldnn_query_deconvolution_d; + const query_t shuffle_d = mkldnn_query_shuffle_d; + const query_t eltwise_d = mkldnn_query_eltwise_d; + const query_t softmax_d = mkldnn_query_softmax_d; + const query_t pooling_d = mkldnn_query_pooling_d; + const query_t lrn_d = mkldnn_query_lrn_d; + const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; + const query_t inner_product_d = mkldnn_query_inner_product_d; + const query_t rnn_d = mkldnn_query_rnn_d; + + const query_t some_md = mkldnn_query_some_md; + const query_t src_md = mkldnn_query_src_md; + const query_t diff_src_md = mkldnn_query_diff_src_md; + const query_t weights_md = mkldnn_query_weights_md; + const query_t diff_weights_md = mkldnn_query_diff_weights_md; + const query_t dst_md = mkldnn_query_dst_md; + const query_t diff_dst_md = mkldnn_query_diff_dst_md; + + const query_t workspace_md = mkldnn_query_workspace_md; + const query_t scratchpad_md = mkldnn_query_scratchpad_md; +} + +using blocking_desc_t = mkldnn_blocking_desc_t; +using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t; +using wino_desc_t = mkldnn_wino_desc_t; +using memory_extra_desc_t = mkldnn_memory_extra_desc_t; +using memory_desc_t = mkldnn_memory_desc_t; +using convolution_desc_t = mkldnn_convolution_desc_t; +using deconvolution_desc_t = mkldnn_deconvolution_desc_t; +using shuffle_desc_t = mkldnn_shuffle_desc_t; +using pooling_desc_t = mkldnn_pooling_desc_t; +using eltwise_desc_t = mkldnn_eltwise_desc_t; +using softmax_desc_t = mkldnn_softmax_desc_t; +using lrn_desc_t = mkldnn_lrn_desc_t; +using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; +using inner_product_desc_t = mkldnn_inner_product_desc_t; + +using rnn_direction_t = mkldnn_rnn_direction_t; +using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; +using rnn_desc_t = mkldnn_rnn_desc_t; + +/* C op_desc_t, which eventually are just (void*) */ +using c_op_desc_t = mkldnn_op_desc_t; +using const_c_op_desc_t = const_mkldnn_op_desc_t; + +struct op_desc_t { + union { + primitive_kind_t kind; + convolution_desc_t convolution; + deconvolution_desc_t deconvolution; + shuffle_desc_t shuffle; + pooling_desc_t pooling; + eltwise_desc_t eltwise; + softmax_desc_t softmax; + lrn_desc_t lrn; + batch_normalization_desc_t batch_normalization; + inner_product_desc_t inner_product; + rnn_desc_t rnn; + }; + + op_desc_t(const primitive_kind_t &_): kind(_) {} + +# define DECL_CTOR_AND_CONVERTERS(c_type, name) \ + op_desc_t(const c_type &_): name(_) {} \ + static op_desc_t *convert_from_c(c_type *_) \ + { return reinterpret_cast(_); } \ + static const op_desc_t *convert_from_c(const c_type *_) \ + { return reinterpret_cast(_); } + + DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution); + DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle); + DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling); + DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise); + DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax); + DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn); + DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization); + DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product); + DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn); + +# undef DECL_CTOR_AND_CONVERTERS +}; + +using engine_t = mkldnn_engine; +using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator; +using primitive_desc_t = mkldnn_primitive_desc; +using primitive_attr_t = mkldnn_primitive_attr; +using post_ops_t = mkldnn_post_ops; +using memory_t = mkldnn_memory; +using primitive_t = 
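// Aside (illustration only): the op_desc_t union above gives every C-level
// operation descriptor a common C++ view. Each member descriptor stores its
// primitive_kind_t as its first field, so after convert_from_c() the kind can
// be read through the union regardless of the concrete descriptor type.
// A minimal usage sketch:
static bool is_convolution(const convolution_desc_t *cd) {
    const op_desc_t *od = op_desc_t::convert_from_c(cd);
    return od->kind == primitive_kind::convolution;
}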
mkldnn_primitive; + +using primitive_arg_index_t = int; + +using stream_flags_t = mkldnn_stream_flags_t; +namespace stream_flags { + const stream_flags_t default_flags = mkldnn_stream_default_flags; +} +using stream_t = mkldnn_stream; + +/* forward declaration of the internal primitive_desc types */ +struct batch_normalization_bwd_pd_t; +struct batch_normalization_fwd_pd_t; +struct batch_normalization_pd_t; +struct concat_pd_t; +struct convolution_bwd_data_pd_t; +struct convolution_bwd_weights_pd_t; +struct convolution_fwd_pd_t; +struct convolution_pd_t; +struct deconvolution_bwd_data_pd_t; +struct deconvolution_bwd_weights_pd_t; +struct deconvolution_fwd_pd_t; +struct deconvolution_pd_t; +struct eltwise_bwd_pd_t; +struct eltwise_fwd_pd_t; +struct eltwise_pd_t; +struct inner_product_bwd_data_pd_t; +struct inner_product_bwd_weights_pd_t; +struct inner_product_fwd_pd_t; +struct inner_product_pd_t; +struct lrn_bwd_pd_t; +struct lrn_fwd_pd_t; +struct lrn_pd_t; +struct pooling_bwd_pd_t; +struct pooling_fwd_pd_t; +struct pooling_pd_t; +struct reorder_pd_t; +struct rnn_bwd_pd_t; +struct rnn_fwd_pd_t; +struct rnn_pd_t; +struct shuffle_pd_t; +struct softmax_bwd_pd_t; +struct softmax_fwd_pd_t; +struct softmax_pd_t; +struct sum_pd_t; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp new file mode 100644 index 000000000000..ed4c35c6e912 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "concat_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds, + const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(concat_pd, src_mds) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + int concat_dim_sz = dims[concat_dim]; + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (d == concat_dim) continue; + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + concat_dim_sz += src_mds[i].dims[concat_dim]; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != + (d == concat_dim ? concat_dim_sz : dims[d])) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.dims[concat_dim] = concat_dim_sz; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto c_pd = reinterpret_cast(concat_pd); + + for (auto c = engine->get_concat_implementation_list(); *c; ++c) { + if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds) + == success) { + (*c_pd)->init_info(); + (*c_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp new file mode 100644 index 000000000000..29311927e2f3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp @@ -0,0 +1,211 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
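// Aside (illustration only, not part of the concat implementation): the
// validation loop in mkldnn_concat_primitive_desc_create() accepts sources
// that share the data type and every dimension except concat_dim, and the
// destination's concat_dim is the sum over the sources. Restated standalone:
static bool concat_dims_ok(const memory_desc_t *srcs, int n, int concat_dim,
        dims_t dst_dims /* out */) {
    for (int d = 0; d < srcs[0].ndims; ++d)
        dst_dims[d] = srcs[0].dims[d];
    for (int i = 1; i < n; ++i) {
        if (srcs[i].ndims != srcs[0].ndims
                || srcs[i].data_type != srcs[0].data_type)
            return false;
        for (int d = 0; d < srcs[0].ndims; ++d)
            if (d != concat_dim && srcs[i].dims[d] != srcs[0].dims[d])
                return false;
        dst_dims[concat_dim] += srcs[i].dims[concat_dim];
    }
    return true;
}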
+*******************************************************************************/ + +#ifndef CONCAT_PD_HPP +#define CONCAT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct concat_pd_t: public primitive_desc_t { + concat_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::concat) + , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md) + { + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + concat_pd_t(const concat_pd_t &rhs) = default; + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + int concat_dim() const { return concat_dim_; } + + const memory_desc_t *src_image_md(int index = 0) const + { return index < n_inputs() ? &src_image_mds_[index] : nullptr; } + +protected: + int n_, concat_dim_; + memory_desc_t dst_md_; + nstl::vector src_mds_; + + /* contains images of srcs in the dst memory (if possible) + * Lives here to simplify some implementations. An implementation might + * use this auxiliary array iff init() returned success */ + nstl::vector src_image_mds_; + +protected: + /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */ + status_t init() { + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + if (!i_d.is_blocking_desc() || i_d.is_additional_buffer()) + return status::unimplemented; + } + + const int ndims = dst_md_.ndims; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) return status; + src_image_mds_.push_back(src_img_d); + current_concat_dim_offset += dim; + } + + return status::success; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + const int ndims = dst_md_.ndims; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain or it is not possible to create a + * blocked format for the output, pick the format of the plain input + * - If this fails as well, use plain layout (abcd...) 
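// Aside (illustration only): concat_pd_t::init() above describes each source
// as a sub-memory ("image") of the destination, offset along concat_dim by
// the accumulated sizes of the preceding sources. The offset sequence it
// produces, in isolation (dim_t as used elsewhere in these headers):
static void concat_image_offsets(const memory_desc_t *srcs, int n,
        int concat_dim, dim_t *offsets /* n entries */) {
    dim_t off = 0;
    for (int i = 0; i < n; ++i) {
        offsets[i] = off;
        off += srcs[i].dims[concat_dim];
    }
}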
+ */ + status_t status = status::unimplemented; + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && !src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + if (status == status::success) break; + } + } + + if (status == status::success) { + /* check if we can create a sub-memory for the dst */ + bool desired_format_ok = true; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) { + desired_format_ok = false; + break; + } + current_concat_dim_offset += dim; + } + + if (!desired_format_ok) + status = status::unimplemented; + } + + /* if no success so far, try using the format of the first plain input */ + if (status != status::success) { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + memory_desc_wrapper(src_mds_[0]).blocking_desc()); + if (status == status::success) return status; + } + } + } + + /* the last line of defense: use plain abcd... format */ + if (status != status::success) + status = memory_desc_init_by_strides(dst_md_, nullptr); + + return status; + } +}; + +#define DECLARE_CONCAT_PD_t(impl_name, ...) \ + static status_t create(concat_pd_t **concat_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, int concat_dim, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign(*concat_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_CONCAT_PD_T(impl_name, ...) \ + DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp new file mode 100644 index 000000000000..0c5c02bcd1cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto cd = convolution_desc_t(); + cd.primitive_kind = primitive_kind::convolution; + cd.prop_kind = prop_kind; + cd.alg_kind = alg_kind; + + cd.diff_src_desc = cd.src_desc = zero_md(); + cd.diff_dst_desc = cd.dst_desc = zero_md(); + cd.diff_weights_desc = cd.weights_desc = zero_md(); + cd.diff_bias_desc = cd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc; + (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) = + *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(cd.strides, strides, sp_dims); + utils::array_copy(cd.padding[0], padding_l, sp_dims); + utils::array_copy(cd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(cd.dilates, dilates, sp_dims); + else + utils::array_set(cd.dilates, 0, sp_dims); + + cd.padding_kind = padding_kind; + cd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + const int bias_dim = prop_kind == backward_data + ? src_desc->dims[1] + : dst_desc->dims[1]; + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? 
bias_desc->dims[0] == bias_dim : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) + { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = cd.dilates[i - 2]; + int pad_l = padding_l[i - 2]; + int pad_r = padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + if (str < 1) return invalid_arguments; + consistency = consistency + && dil >= 0 + && pad_l >= 0 + && pad_r + str > 0 + && (src - ker_range + pad_l + pad_r) / str + 1 == dst; + } + if (!consistency) return invalid_arguments; + + *conv_desc = cd; + return success; +} +} +} + +status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_forward_desc_init( + convolution_desc_t *conv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return 
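// Aside (illustration only): the consistency loop above encodes the standard
// convolution shape relation per spatial axis, with kernel reach
// ker_range = 1 + (ker - 1) * (dil + 1):
//     dst == (src - ker_range + pad_l + pad_r) / str + 1
// For example src = 224, ker = 3, dil = 0, pad_l = pad_r = 1, str = 2 gives
// ker_range = 3 and dst = (224 - 3 + 2) / 2 + 1 = 112.
static int conv_out_size(int src, int ker, int dil, int pad_l, int pad_r,
        int str) {
    const int ker_range = 1 + (ker - 1) * (dil + 1);
    return (src - ker_range + pad_l + pad_r) / str + 1;
}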
mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + nullptr, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + dilates, padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp new file mode 100644 index 000000000000..9604e0acf5ac --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "convolution_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? 
&desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_src_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_wei_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_bia_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_dst_d(const_cast(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp new file mode 100644 index 000000000000..b10c36db4909 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONVOLUTION_PD_HPP +#define CONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind); + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); + +struct convolution_fwd_pd_t; + +struct convolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::convolution; + + convolution_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const convolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case 
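// Aside (illustration only): the prop-invariant helpers above return whichever
// descriptor plays the source/weights/bias/dst role for the current
// propagation kind (e.g. diff_src_desc serves as the source for
// backward_data), so shape queries can stay direction-agnostic:
static dim_t conv_minibatch(const convolution_desc_t *desc) {
    return conv_prop_invariant_src_d(desc)->dims[0]; // MB for any prop_kind
}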
pkind_traits::query_d: + *(const convolution_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common conv aux functions */ + + dim_t MB() const { return _src_md()->dims[0]; } + + dim_t IC() const { return _src_md()->dims[1]; } + dim_t OC() const { return _dst_md()->dims[1]; } + dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } + + dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } + dim_t IW() const { return _src_md()->dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } + dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } + dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } + dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const { return ndims() >= 4 ? 
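// Aside (illustration only, assumes <cstdio> for printf): how the accessors
// above read out a plain ungrouped 2D problem. With ndims() == 4 the source
// is MB x IC x IH x IW, the destination MB x OC x OH x OW, the kernel
// KH x KW; with groups the weights gain a leading G dimension, which is what
// the "ndims() + with_groups()" index offsets account for.
static void print_conv_shape(const convolution_pd_t &pd) {
    printf("mb%ld ic%ld ih%ld iw%ld -> oc%ld oh%ld ow%ld, k%ldx%ld, s%ldx%ld\n",
            (long)pd.MB(), (long)pd.IC(), (long)pd.IH(), (long)pd.IW(),
            (long)pd.OC(), (long)pd.OH(), (long)pd.OW(),
            (long)pd.KH(), (long)pd.KW(), (long)pd.KSH(), (long)pd.KSW());
}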
desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return _src_md()->ndims; } + + bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } + bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*_src_md()); + const auto d_d = memory_desc_wrapper(*_dst_md()); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + convolution_desc_t desc_; + const convolution_fwd_pd_t *hint_fwd_pd_; + + bool set_default_formats_common_template( + memory_desc_t &src_md, format_tag_t src_tag, + memory_desc_t &wei_md, format_tag_t wei_tag, + memory_desc_t &dst_md, format_tag_t dst_tag, + memory_desc_t &bia_md) { + using namespace format_tag; + +# define IS_OK(f) \ + do { if ((f) != status::success) return false; } while(0) + if (src_md.format_kind == format_kind::any + && !utils::one_of(src_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(src_md, src_tag)); + if (dst_md.format_kind == format_kind::any + && !utils::one_of(dst_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); + if (wei_md.format_kind == format_kind::any + && !utils::one_of(wei_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); + if (with_bias() && bia_md.format_kind == format_kind::any) + IS_OK(memory_desc_init_by_tag(bia_md, x)); +# undef IS_OK + + return true; + } + + bool set_default_alg_kind(alg_kind_t alg_kind) { + assert(utils::one_of(alg_kind, alg_kind::convolution_direct, + alg_kind::convolution_winograd)); + if (desc_.alg_kind == alg_kind::convolution_auto) + desc_.alg_kind = alg_kind; + return desc_.alg_kind == alg_kind; + } + + bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, + data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { + bool ok = true + && (src_dt == data_type::undef || _src_md()->data_type == src_dt) + && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) + && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) + && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); + if (with_bias() && bia_dt != data_type::undef) + ok = ok && _bia_md()->data_type == bia_dt; + return ok; + } + +private: + const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } + const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } + const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } + const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } +}; + +struct convolution_fwd_pd_t: public convolution_pd_t { + typedef convolution_fwd_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_fwd_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return 
arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t wei_tag, format_tag_t dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); + } +}; + +struct convolution_bwd_data_pd_t: public convolution_pd_t { + typedef convolution_bwd_data_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_data_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? 
&diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + + virtual bool support_bias() const { return false; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t diff_src_tag, + format_tag_t wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(diff_src_md_, diff_src_tag, + weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); + } +}; + +struct convolution_bwd_weights_pd_t: public convolution_pd_t { + typedef convolution_bwd_weights_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_weights_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, + diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp new file mode 100644 index 000000000000..98063c1c37ef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp @@ -0,0 +1,188 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t deconv_desc_init(deconvolution_desc_t *deconv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, deconvolution_direct, deconvolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) + return invalid_arguments; + + if (padding_r == nullptr) + padding_r = padding_l; + + auto dd = deconvolution_desc_t(); + dd.primitive_kind = primitive_kind::deconvolution; + dd.prop_kind = prop_kind; + dd.alg_kind = alg_kind; + + dd.diff_src_desc = dd.src_desc = zero_md(); + dd.diff_dst_desc = dd.dst_desc = zero_md(); + dd.diff_weights_desc = dd.weights_desc = zero_md(); + dd.diff_bias_desc = dd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias + = bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc; + (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc) + = *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc) + = *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(dd.strides, strides, sp_dims); + utils::array_copy(dd.padding[0], padding_l, sp_dims); + utils::array_copy(dd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(dd.dilates, dilates, sp_dims); + else + utils::array_set(dd.dilates, 0, sp_dims); + + dd.padding_kind = padding_kind; + dd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + bool consistency = true + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? 
bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = dd.dilates[i - 2]; + int pad = padding_l[i - 2] + padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + consistency + = consistency && (dst - ker_range + pad) / str + 1 == src; + } + if (!consistency) + return invalid_arguments; + + *deconv_desc = dd; + return success; +} +} + +status_t mkldnn_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, 
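// Aside (illustration only): deconvolution validates the transposed relation
// of the convolution case; the roles of src and dst swap, so with
// ker_range = 1 + (ker - 1) * (dil + 1) and pad = pad_l + pad_r the loop
// above requires (dst - ker_range + pad) / str + 1 == src. For example
// src = 112, ker = 3, dil = 0, pad = 2, str = 2 accepts dst = 224, since
// (224 - 3 + 2) / 2 + 1 == 112.
static bool deconv_shape_ok(int src, int dst, int ker, int dil, int pad,
        int str) {
    const int ker_range = 1 + (ker - 1) * (dil + 1);
    return (dst - ker_range + pad) / str + 1 == src;
}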
strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp new file mode 100644 index 000000000000..539e44bd9bb1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp @@ -0,0 +1,293 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DECONVOLUTION_PD_HPP +#define DECONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct deconvolution_fwd_pd_t; + +struct deconvolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::deconvolution; + + deconvolution_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const deconvolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits::query_d: + *(const deconvolution_desc_t **)result = desc(); + break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ + + dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; } + + dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; } + dim_t G() const + { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; } + + dim_t ID() const { + return ndims() >= 5 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? 
conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1]; + } + + dim_t OD() const { + return ndims() >= 5 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1]; + } + + dim_t KD() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 5 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1; + } + dim_t KH() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 4 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1; + } + dim_t KW() const { + const int w_ndims = ndims() + with_groups(); + return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1]; + } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + bool with_bias() const { + return + !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero(); + } + + bool with_groups() const + { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; } + + int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + deconvolution_desc_t desc_; + const deconvolution_fwd_pd_t *hint_fwd_pd_; +}; + +struct deconvolution_fwd_pd_t: public deconvolution_pd_t { + typedef deconvolution_fwd_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_fwd_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? 
&src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; +}; + +struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_data_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_data_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; +}; + +struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_weights_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_weights_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? 
&diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp new file mode 100644 index 000000000000..f1708fca5226 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float alpha, float beta) { + bool args_ok = true + && !any_null(eltwise_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data) + && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ed = eltwise_desc_t(); + ed.primitive_kind = primitive_kind::eltwise; + ed.prop_kind = prop_kind; + ed.alg_kind = alg_kind; + + ed.data_desc = *data_desc; + ed.diff_data_desc = + (ed.prop_kind == backward_data) ? 
*diff_data_desc : zero_md(); + + ed.alpha = alpha; + ed.beta = beta; + + bool consistency = true + && IMPLICATION(ed.prop_kind == backward_data, + array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims, + ed.diff_data_desc.ndims)); + if (!consistency) return invalid_arguments; + + *eltwise_desc = ed; + return success; +} +} + +status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, float alpha, float beta) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc, + nullptr, alpha, beta); +} + +status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_data_desc, + const memory_desc_t *data_desc, float alpha, float beta) { + return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc, + diff_data_desc, alpha, beta); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp new file mode 100644 index 000000000000..9fd260fceea4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ELTWISE_PD_HPP +#define ELTWISE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct eltwise_fwd_pd_t; + +struct eltwise_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::eltwise; + + eltwise_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const eltwise_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::eltwise_d: + *(const eltwise_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common eltwise aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + eltwise_desc_t desc_; + const eltwise_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct eltwise_fwd_pd_t: public eltwise_pd_t { + typedef eltwise_fwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const + { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); } +}; + +struct eltwise_bwd_pd_t: public eltwise_pd_t { + typedef eltwise_bwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_bwd_pd_t(engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const { return true; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp new file mode 100644 index 000000000000..3b3e25456d87 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include "engine.hpp" +#include "nstl.hpp" + +#include "c_types_map.hpp" +#include "../cpu/cpu_engine.hpp" + +namespace mkldnn { +namespace impl { + +engine_factory_t *engine_factories[] = { + &cpu::engine_factory, + nullptr, +}; + +static inline engine_factory_t *get_engine_factory(engine_kind_t kind) { + for (engine_factory_t **ef = engine_factories; *ef; ef++) + if ((*ef)->kind() == kind) + return *ef; + return nullptr; +} + +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +size_t mkldnn_engine_get_count(engine_kind_t kind) { + engine_factory_t *ef = get_engine_factory(kind); + return ef != nullptr ? ef->count() : 0; +} + +status_t mkldnn_engine_create(engine_t **engine, + engine_kind_t kind, size_t index) { + if (engine == nullptr) + return invalid_arguments; + + engine_factory_t *ef = get_engine_factory(kind); + if (ef == nullptr || index >= ef->count()) + return invalid_arguments; + + return ef->engine_create(engine, index); +} + +status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) { + if (engine == nullptr) + return invalid_arguments; + *kind = engine->kind(); + return success; +} + +status_t mkldnn_engine_destroy(engine_t *engine) { + /* TODO: engine->dec_ref_count(); */ + delete engine; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp new file mode 100644 index 000000000000..8ac8a29de56e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ENGINE_HPP +#define ENGINE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive.hpp" +#include "utils.hpp" + +/** \brief An abstraction of an execution unit with shared resources + * + * Responsibilities: + * - Provide engine specific memory allocation + * - Provide engine specific primitive_desc_t creators + */ +struct mkldnn_engine: public mkldnn::impl::c_compatible { + mkldnn_engine(mkldnn::impl::engine_kind_t kind) + : kind_(kind) + {} + virtual ~mkldnn_engine() {} + + /** get kind of the current engine */ + virtual mkldnn::impl::engine_kind_t kind() const { return kind_; } + + /** allocate memory */ + virtual mkldnn::impl::status_t memory_create( + mkldnn::impl::memory_t **memory, + const mkldnn::impl::memory_desc_t *md, + void *handle) = 0; + + /** implementation section (typedefs) */ + + // TODO: remove engine? + typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)( + mkldnn::impl::reorder_pd_t **reorder_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *src_engine, + const mkldnn::impl::memory_desc_t *src_md, + mkldnn::impl::engine_t *dst_engine, + const mkldnn::impl::memory_desc_t *dst_md); + + typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)( + mkldnn::impl::concat_pd_t **concat_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, int concat_dim, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)( + mkldnn::impl::sum_pd_t **sum_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, const float *scales, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*primitive_desc_create_f)( + mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *); + + /* implementation section */ + + /** return the list of reorder implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const = 0; + + /** return the list of concat implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const = 0; + + /** return the list of sum implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const = 0; + + /** return the list of implementations. 
engine guarantees to return a + * NULL-terminated list */ + virtual const primitive_desc_create_f* get_implementation_list() const = 0; + +protected: + mkldnn::impl::engine_kind_t kind_; +}; + +namespace mkldnn { +namespace impl { + +struct engine_factory_t: public c_compatible { + virtual size_t count() const = 0; + virtual engine_kind_t kind() const = 0; + virtual status_t engine_create(engine_t **engine, size_t index) const = 0; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp new file mode 100644 index 000000000000..5a9f58cb1e85 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc); + if (!args_ok) return invalid_arguments; + + auto id = inner_product_desc_t(); + id.primitive_kind = primitive_kind::inner_product; + id.prop_kind = prop_kind; + + id.diff_src_desc = id.src_desc = zero_md(); + id.diff_dst_desc = id.dst_desc = zero_md(); + id.diff_weights_desc = id.weights_desc = zero_md(); + id.diff_bias_desc = id.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + + (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc; + (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) = + *bias_desc; + + id.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && one_of(src_desc->ndims, 2, 3, 4, 5) + && dst_desc->ndims == 2 + && weights_desc->ndims == src_desc->ndims + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? 
bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && array_cmp(&src_desc->dims[1], &weights_desc->dims[1], + src_desc->ndims - 1) + && dst_desc->dims[1] == weights_desc->dims[0]; + if (!consistency) return invalid_arguments; + + *ip_desc = id; + return success; +} +} + +status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc, + prop_kind_t prop_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc, + dst_desc); +} + +status_t mkldnn_inner_product_backward_data_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc) +{ + return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc, + nullptr, diff_dst_desc); +} + +status_t mkldnn_inner_product_backward_weights_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *src_desc, + const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc) { + return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc, + diff_bias_desc, diff_dst_desc); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp new file mode 100644 index 000000000000..091cf0f5d6c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "inner_product_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? 
&desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp new file mode 100644 index 000000000000..c426de632cab --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp @@ -0,0 +1,321 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef INNER_PRODUCT_PD_HPP +#define INNER_PRODUCT_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc); + +struct inner_product_fwd_pd_t; + +struct inner_product_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::inner_product; + + inner_product_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const inner_product_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::inner_product_d: + *(const inner_product_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common inner_product aux functions */ + + dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } + dim_t IC() const { return 
ip_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } + + dim_t ID() const { + return ndims() >= 5 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return ndims() >= 3 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t OD() const { + return ndims() >= 5 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return ndims() >= 3 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t KD() const { + return ndims() >= 5 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t KH() const { + return ndims() >= 4 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t KW() const { + return ndims() >= 3 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t IC_total() const { + return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], + ndims() - 1); + } + + dim_t IC_total_padded() const { + auto src_d = desc()->prop_kind == prop_kind::backward_data + ? memory_desc_wrapper(diff_src_md()) + : memory_desc_wrapper(src_md()); + assert(src_d.is_blocking_desc()); + if (!src_d.is_blocking_desc()) return -1; + return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); + } + + int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } + + bool with_bias() const + { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + inner_product_desc_t desc_; + const inner_product_fwd_pd_t *hint_fwd_pd_; + + status_t template_set_default_params(memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t *bias_md) { + using namespace format_tag; + if (src_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, + utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); + } + if (dst_md.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md, nc)); + if (weights_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, + utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); + } + if (bias_md && bias_md->format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(*bias_md, x)); + return status::success; + } +}; + +struct inner_product_fwd_pd_t: public inner_product_pd_t { + typedef inner_product_fwd_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_fwd_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == 
MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, weights_md_, dst_md_, + &bias_md_); + } +}; + +struct inner_product_bwd_data_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_data_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_data_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(diff_src_md_, weights_md_, + diff_dst_md_, nullptr); + } +}; + +struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_weights_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_weights_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? 
&src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, diff_weights_md_, + diff_dst_md_, &diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp new file mode 100644 index 000000000000..fcf18b556f48 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t lrn_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc, + dim_t local_size, float alpha, float beta, float k) { + bool args_ok = true + && !any_null(lrn_desc, data_desc) + && one_of(alg_kind, lrn_within_channel, lrn_across_channels) + && one_of(prop_kind, forward_training, forward_inference, backward_data) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ld = lrn_desc_t(); + ld.primitive_kind = primitive_kind::lrn; + ld.prop_kind = prop_kind; + ld.alg_kind = alg_kind; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + ld.data_desc = *data_desc; + if (!is_fwd) + ld.diff_data_desc = *diff_data_desc; + else + ld.diff_data_desc = zero_md(); + ld.local_size = local_size; + ld.lrn_alpha = alpha; + ld.lrn_beta = beta; + ld.lrn_k = k; + + bool consistency = true + && ld.data_desc.ndims == 4; + if (ld.prop_kind == backward_data) + consistency = consistency + && ld.diff_data_desc.ndims == 4 + && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4); + if (!consistency) return invalid_arguments; + + *lrn_desc = ld; + return success; +} +} + +status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const 
memory_desc_t *data_desc, dim_t local_size, float alpha, + float beta, float k) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr, + local_size, alpha, beta, k); +} + +status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, dim_t local_size, float alpha, + float beta, float k) { + return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc, + diff_data_desc, local_size, alpha, beta, k); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp new file mode 100644 index 000000000000..90886e96562a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef LRN_PD_HPP +#define LRN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct lrn_fwd_pd_t; + +struct lrn_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::lrn; + + lrn_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , ws_md_() + {} + + const lrn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::lrn_d: + *(const lrn_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common lrn aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + lrn_desc_t desc_; + const lrn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t ws_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct lrn_fwd_pd_t: public lrn_pd_t { + typedef lrn_fwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_fwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct lrn_bwd_pd_t: public lrn_pd_t { + typedef lrn_bwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_bwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp new file mode 100644 index 000000000000..3fddc0bd4565 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MATH_UTILS_HPP +#define MATH_UTILS_HPP + +#include +#include + +#include "utils.hpp" +#include "nstl.hpp" +#include "mkldnn_traits.hpp" + +#if defined(MKLDNN_X86_64) +#include "immintrin.h" +#endif + +namespace mkldnn { +namespace impl { +namespace math { + +/** rounds @p f to an integer according to the mxcsr register */ +inline int mxcsr_round(float f) { +#if defined(MKLDNN_X86_64) + return _mm_cvtss_si32(_mm_load_ss(&f)); +#else + return (int)nearbyintf(f); // optimism +#endif +} + +template +inline typename utils::enable_if::value, + typename utils::remove_reference::type>::type +saturate(const acc_t &x) { + return (typename utils::remove_reference::type)x; +} + +template +inline typename utils::enable_if::value, + typename utils::remove_reference::type>::type +saturate(const acc_t &x) { + acc_t v = x; + if (v < (acc_t)nstl::numeric_limits::lowest()) + v = (acc_t)nstl::numeric_limits::lowest(); + if (v > (acc_t)nstl::numeric_limits::max()) + v = (acc_t)nstl::numeric_limits::max(); + return (typename utils::remove_reference::type)v; +} + +template +double saturate(const double &x) { + double v = x; + if (v < (double)nstl::numeric_limits::lowest()) + v = (double)nstl::numeric_limits::lowest(); + if (v > (double)nstl::numeric_limits::max()) + v = (double)nstl::numeric_limits::max(); + return v; +} + +template <> inline int8_t saturate(const uint8_t &x) { + return x <= 127u ? x : 127; +} + +template <> inline uint8_t saturate(const int8_t &x) { + return x >= 0 ? 
x : 0; +} + +template <typename out_t> +typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type +out_round(float v) { return (out_t)mxcsr_round(v); } + +template <typename out_t> +typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type +out_round(double v) { return (out_t)mxcsr_round((float)v); } + +template <typename out_t> +typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type +out_round(float v) { return v; } + +inline int gcd(int a, int b) { + a = impl::nstl::abs(a); + b = impl::nstl::abs(b); + if (a < b) { int x = a; a = b; b = x; } + + if (b == 0) return a; + + int r; + while ((r = a % b) != 0) { a = b; b = r; } + + return b; +} + +template <typename T> +inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } + +/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ +inline int ilog2q(size_t v) { + if (v == 0) + return -1; + + int p = 0; +# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) + CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); +# undef CP + return p; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U one_m_square(T x) { + return (U)(1 - x) * (1 + x); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U x_m_square(T x) { + return (U)(1 - x) * x; +} + +/* activation */ +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U relu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(s * alpha); +} +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U relu_bwd(T dd, T s, A alpha) { + return s > 0 ? dd : (U)(dd * alpha); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U tanh_fwd(T s) { + const float e = tanhf((float) s); + return (U)e; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U tanh_bwd(T dd, T s) { + const float e = tanh_fwd((float) s); + return (U)(dd * (1 - e) * (1 + e)); +} + +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U elu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); +} +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> + inline U elu_bwd(T dd, T s, A alpha) { + return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U square_fwd(T s) { + return s * s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U square_bwd(T dd, T s) { + return dd * 2 * s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U abs_fwd(T s) { + return s > 0 ? s : -s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U abs_bwd(T dd, T s) { + return s > 0 ? dd : s < 0 ? -dd : 0; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U sqrt_fwd(T s) { + return s > 0 ? (U)(::sqrtf((float)(s))) : 0; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U sqrt_bwd(T dd, T s) { + return s > 0 + ? (U)(dd / (2 * ::sqrtf((float)(s)))) + : 0; +} + +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U linear_fwd(T s, A alpha, A beta) { + return (U)(alpha * s + beta); +} + +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U linear_bwd(T dd, T s, A alpha, A beta) { + (void) s; + (void) beta; + return (U)(dd * alpha); +} + +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U bounded_relu_fwd(T s, A alpha) { + s = s > 0 ? s : 0; + return s > alpha ? (U)(alpha) : s; +} + +template <typename T, typename A, typename U = typename utils::remove_reference<T>::type> +inline U bounded_relu_bwd(T dd, T s, A alpha) { + return dd * (0 < s && s < alpha ? 1 : 0); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U soft_relu_fwd(T s) { + float max_logf = 8.872284e+01; //::logf(FLT_MAX) + return s < max_logf ? 
(U)(::log1pf(::expf((float)s))) : s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U soft_relu_bwd(T dd, T s) { + return (U)(dd / (1 + ::expf((float)(-s)))); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U logistic_fwd(T s) { + U v = (U)(::expf((float) -s)); + return 1 / (1 + v); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U logistic_bwd(T dd, T s) { + U v = logistic_fwd(s); + return dd * v * (1 - v); +} + +inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { + using namespace alg_kind; + using namespace utils; + const bool preserves_zero = true + && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); + return preserves_zero; +} + +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) +{ + if (!bias) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits<dt>
::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp new file mode 100644 index 000000000000..cea849c96e1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::data_type; + +namespace { +bool memory_desc_sanity_check(int ndims,const dims_t dims, + data_type_t data_type, format_kind_t format_kind) { + if (ndims == 0) return true; + + bool ok = true + && dims != nullptr + && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS + && one_of(data_type, f32, s32, s8, u8) + && format_kind != format_kind::undef; + if (!ok) return false; + for (int d = 0; d < ndims; ++d) + if (dims[d] < 0) return false; + + return true; +} + +bool memory_desc_sanity_check(const memory_desc_t *md) { + if (md == nullptr) return false; + return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, + format_kind::any); +} +} + +status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, + const dims_t dims, data_type_t data_type, format_tag_t tag) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0 || tag == format_tag::undef) { + *memory_desc = types::zero_md(); + return success; + } + + format_kind_t format_kind = types::format_tag_to_kind(tag); + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind; + + status_t status = success; + if (tag == format_tag::undef) { + status = invalid_arguments; + } else if (tag == format_tag::any) { + // nop + } else if (format_kind == format_kind::blocked) { + status = memory_desc_wrapper::compute_blocking(md, tag); + } else { + assert(!"unreachable"); + status = invalid_arguments; + } + + if (status == success) + *memory_desc = md; + + return status; +} + +status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, + int ndims, const dims_t dims, data_type_t data_type, + const dims_t strides) { + if (any_null(memory_desc)) return 
invalid_arguments; + if (ndims == 0) { + *memory_desc = types::zero_md(); + return success; + } + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind::blocked; + + dims_t default_strides = {0}; + if (strides == nullptr) { + default_strides[md.ndims - 1] = 1; + for (int d = md.ndims - 2; d >= 0; --d) + default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; + strides = default_strides; + } else { + /* TODO: add sanity check for the provided strides */ + } + + array_copy(md.format_desc.blocking.strides, strides, md.ndims); + + *memory_desc = md; + + return status::success; +} + +status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, + const memory_desc_t *parent_md, const dims_t dims, + const dims_t offsets) { + if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) + return invalid_arguments; + + const memory_desc_wrapper src_d(parent_md); + + for (int d = 0; d < src_d.ndims(); ++d) { + if (dims[d] < 0 || offsets[d] < 0 + || (offsets[d] + dims[d] > src_d.dims()[d])) + return invalid_arguments; + } + + if (src_d.format_kind() != format_kind::blocked) + return unimplemented; + + dims_t blocks; + src_d.compute_blocks(blocks); + + memory_desc_t dst_d = *parent_md; + auto &dst_d_blk = dst_d.format_desc.blocking; + + /* TODO: put this into memory_desc_wrapper */ + for (int d = 0; d < src_d.ndims(); ++d) { + /* very limited functionality for now */ + const bool ok = true + && offsets[d] % blocks[d] == 0 /* [r1] */ + && src_d.padded_offsets()[d] == 0 + && (false + || dims[d] % blocks[d] == 0 + || dims[d] < blocks[d]); + if (!ok) + return unimplemented; + + const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; + + dst_d.dims[d] = dims[d]; + dst_d.padded_dims[d] = is_right_border + ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; + dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; + dst_d.offset0 += /* [r1] */ + offsets[d] / blocks[d] * dst_d_blk.strides[d]; + } + + *md = dst_d; + + return success; +} + +int mkldnn_memory_desc_equal(const memory_desc_t *lhs, + const memory_desc_t *rhs) { + if (lhs == rhs) return 1; + if (any_null(lhs, rhs)) return 0; + return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); +} + +size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { + if (md == nullptr) return 0; + return memory_desc_wrapper(*md).size(); +} + +status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, + engine_t *engine, void *handle) { + if (any_null(memory, engine)) return invalid_arguments; + memory_desc_t z_md = types::zero_md(); + return engine->memory_create(memory, md ? 
md : &z_md, handle); +} + +status_t mkldnn_memory_get_memory_desc(const memory_t *memory, + const memory_desc_t **md) { + if (any_null(memory, md)) return invalid_arguments; + *md = memory->md(); + return success; +} + +status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { + if (any_null(memory, engine)) return invalid_arguments; + *engine = memory->engine(); + return success; +} + +status_t mkldnn_memory_get_data_handle(const memory_t *memory, + void **handle) { + if (any_null(handle)) + return invalid_arguments; + if (memory == nullptr) { + *handle = nullptr; + return success; + } + return memory->get_data_handle(handle); +} + +status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { + if (any_null(memory)) return invalid_arguments; + return memory->set_data_handle(handle); +} + +status_t mkldnn_memory_destroy(memory_t *memory) { + delete memory; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp new file mode 100644 index 000000000000..03dfee01ff72 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef MEMORY_HPP +#define MEMORY_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" + +struct mkldnn_memory: public mkldnn::impl::c_compatible { + mkldnn_memory(mkldnn::impl::engine_t *engine, + const mkldnn::impl::memory_desc_t *md) + : engine_(engine), md_(*md) {} + virtual ~mkldnn_memory() {} + + /** allocates/initializes memory */ + virtual mkldnn::impl::status_t init() = 0; + + /** returns memory's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + /** returns memory's description */ + const mkldnn::impl::memory_desc_t *md() const { return &md_; } + + /** returns data handle */ + virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0; + + /** sets data handle */ + virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0; + + /** zeros padding */ + virtual mkldnn::impl::status_t zero_pad() const + { return mkldnn::impl::status::success; } + +protected: + mkldnn::impl::engine_t *engine_; + const mkldnn::impl::memory_desc_t md_; + +private: + mkldnn_memory() = delete; + mkldnn_memory(const mkldnn_memory &) = delete; + mkldnn_memory(mkldnn_memory &&) = delete; + mkldnn_memory &operator=(const mkldnn_memory &) = delete; + mkldnn_memory &operator=(mkldnn_memory &&) = delete; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp new file mode 100644 index 000000000000..8a99be33f3ee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp @@ -0,0 +1,212 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t fill_blocked(memory_desc_t &md, + std::initializer_list perm, + std::initializer_list inner_blks, + std::initializer_list inner_idxs) { + const bool ok = true + && perm.size() == (size_t)md.ndims + && inner_blks.size() == inner_idxs.size(); + if (!ok) return status::invalid_arguments; + + md.offset0 = 0; + + blocking_desc_t &blk = md.format_desc.blocking; + + dim_t block_size = 1; + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + + blk.inner_nblks = (int)inner_blks.size(); + + int iblk = 0; + for (const auto &b: inner_idxs) + blk.inner_idxs[iblk++] = b; + + iblk = 0; + for (const auto &b: inner_blks) { + int dim = blk.inner_idxs[iblk]; + block_size *= b; + blocks[dim] *= b; + blk.inner_blks[iblk++] = b; + } + + utils::array_set(md.padded_offsets, 0, md.ndims); + for (int d = 0; d < md.ndims; ++d) + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + + dim_t stride = block_size; + // if only we use C++14, the initializer_list would have rbegin()/rend()... + for (int d = 0; d < md.ndims; ++d) + stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; + + for (const auto &d: perm) { + if (md.padded_dims[d] == 0) { + blk.strides[d] = 1; + continue; + } + stride /= md.padded_dims[d] / blocks[d]; + blk.strides[d] = stride; + } + + assert(stride == block_size); + + return status::success; +} + +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag) +{ + using namespace format_tag; + + if (memory_desc.ndims == 0) return status::invalid_arguments; + +# define C(tag, ... 
/* perm, inner_blks, inner_idxs */) \ + case tag: return fill_blocked(memory_desc, __VA_ARGS__) + + switch (tag) { + C(a, {0}, {}, {}); + C(ab, {0, 1}, {}, {}); + C(abc, {0, 1, 2}, {}, {}); + C(abcd, {0, 1, 2, 3}, {}, {}); + C(abcde, {0, 1, 2, 3, 4}, {}, {}); + C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); + C(abdec, {0, 1, 3, 4, 2}, {}, {}); + C(acb, {0, 2, 1}, {}, {}); + C(acbde, {0, 2, 1, 3, 4}, {}, {}); + C(acdb, {0, 2, 3, 1}, {}, {}); + C(acdeb, {0, 2, 3, 4, 1}, {}, {}); + C(ba, {1, 0}, {}, {}); + C(bac, {1, 0, 2}, {}, {}); + C(bacd, {1, 0, 2, 3}, {}, {}); + C(bcda, {1, 2, 3, 0}, {}, {}); + C(cba, {2, 1, 0}, {}, {}); + C(cdba, {2, 3, 1, 0}, {}, {}); + C(cdeba, {2, 3, 4, 1, 0}, {}, {}); + C(decab, {3, 4, 2, 0, 1}, {}, {}); + + C(Abc4a, {0, 1, 2}, {4}, {0}); + C(aBc4b, {0, 1, 2}, {4}, {1}); + C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); + C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); + C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); + C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); + C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); + C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); + C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); + C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); + C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); + C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); + C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); + C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); + C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); + C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); + C(Acb4a, {0, 2, 1}, {4}, {0}); + C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); + C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); + + C(Abc16a, {0, 1, 2}, {16}, {0}); + C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); + C(aBc16b, {0, 1, 2}, {16}, {1}); + C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); + C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); + C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); + C(aBc8b, {0, 1, 2}, {8}, {1}); + C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); + C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); + C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); + C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); + C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); + C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); + C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); + C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); + C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); + C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); + C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); + C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); + C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); + C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); + C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); + C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); + C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); + C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); + C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); + C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); + C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); + C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); + C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); + C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); + C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); + C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); + C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); + C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); + C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); + C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); + C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); + C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); + C(aBCde8b8c, {0, 1, 2, 
3, 4}, {8, 8}, {1, 2}); + C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); + C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); + C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); + C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); + C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); + C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); + C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); + C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); + C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); + C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); + C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); + C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); + C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); + C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); + C(Acb16a, {0, 2, 1}, {16}, {0}); + C(Acb8a, {0, 2, 1}, {8}, {0}); + C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); + C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); + C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); + C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); + C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); + C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); + C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); + C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); + default: break; + } + +#undef C + + return status::invalid_arguments; +} + +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp new file mode 100644 index 000000000000..1758f9078a40 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef MEMORY_DESC_WRAPPER_HPP +#define MEMORY_DESC_WRAPPER_HPP + +#include + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +/** thin wrapper class over \struct memory_desc_t which allows easy + * manipulations with underlying C structure, which is taken by reference */ +struct memory_desc_wrapper: public c_compatible { + const memory_desc_t *md_; + + /** constructor which takes a reference to a constant underlying C memory + * descriptor \param md */ + memory_desc_wrapper(const memory_desc_t *md): md_(md) {} + memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {} + + /* implementing attributes */ + int ndims() const { return md_->ndims; } + const dims_t &dims() const { return md_->dims; } + data_type_t data_type() const { return md_->data_type; } + + const dims_t &padded_dims() const { return md_->padded_dims; } + const dims_t &padded_offsets() const { return md_->padded_offsets; } + dim_t offset0() const { return md_->offset0; } + + format_kind_t format_kind() const { return md_->format_kind; } + + bool is_blocking_desc() const + { return format_kind() == format_kind::blocked; } + bool is_wino_desc() const + { return format_kind() == format_kind::wino; } + bool is_rnn_packed_desc() const + { return format_kind() == format_kind::rnn_packed; } + + const blocking_desc_t &blocking_desc() const { + assert(is_blocking_desc()); + return md_->format_desc.blocking; + } + const wino_desc_t &wino_desc() const { + assert(is_wino_desc()); + return md_->format_desc.wino_desc; + } + const rnn_packed_desc_t &rnn_packed_desc() const { + assert(is_rnn_packed_desc()); + return md_->format_desc.rnn_packed_desc; + } + + const memory_extra_desc_t &extra() const { return md_->extra; } + + /* some useful function */ + + /** returns the number of elements including padding if \param with_padding + * is true, and the number of data elements otherwise */ + dim_t nelems(bool with_padding = false) const { + if (is_zero()) return 0; + return utils::array_product( + with_padding ? 
padded_dims() : dims(), ndims()); + } + + /** returns true if memory descriptor is zero */ + bool is_zero() const { return ndims() == 0; } + + /** returns true if memory descriptor contains zero as one of its dim */ + bool has_zero_dim() const { return nelems() == 0; } + + /** return the size of data type (a shortcut) */ + size_t data_type_size() const + { return types::data_type_size(data_type()); } + + /** return the size of data type of additional buffer */ + size_t additional_buffer_data_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) + return sizeof(int32_t); + return 0; + } + + /** return true if memory format has additional buffer */ + bool is_additional_buffer() const { + return (extra().flags & memory_extra_flags::compensation_conv_s8s8); + } + + /** returns the size of additional buffer */ + size_t additional_buffer_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) { + int cmask = extra().compensation_mask; + assert(cmask == 1 || cmask == 3); + dim_t prod = 1; + for (int d = 0; d < ndims(); ++d) + if (cmask & (1<(max_size, + padded_dims()[d] / blocks[d] * bd.strides[d]); + + if (max_size == 1 && bd.inner_nblks != 0) { + max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); + } + + return max_size * data_type_size() + additional_buffer_size(); + } + } + + /** returns true if data is dense in memory */ + bool is_dense(bool with_padding = false) const { + if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + return nelems(with_padding) * data_type_size() == size(); + } + + /** returns true if memory desc is fully defined */ + bool is_defined() const { return format_kind() != format_kind::any; } + + /** returns true if the only (potentially) padded dim is \param dim */ + bool only_padded_dim(int dim) const { + for (int d = 0; d < ndims(); ++d) + if (d != dim && dims()[d] != padded_dims()[d]) + return false; + return true; + } + + /** returns true if memory desc has blocked layout and block dims are 1s */ + bool is_plain() const { + if (!is_blocking_desc()) return false; + return blocking_desc().inner_nblks == 0; + } + + /** returns overall block sizes */ + void compute_blocks(dims_t blocks) const { + if (!is_blocking_desc()) { + utils::array_set(blocks, 0, ndims()); + return; + } + + utils::array_set(blocks, 1, ndims()); + + const auto &bd = blocking_desc(); + for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) + blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; + } + + /* comparison section */ + + bool operator==(const memory_desc_wrapper &rhs) const + { return *this->md_ == *rhs.md_; } + bool operator!=(const memory_desc_wrapper &rhs) const + { return !operator==(rhs); } + bool operator==(const memory_desc_t &rhs) const + { return operator==(memory_desc_wrapper(rhs)); } + bool operator!=(const memory_desc_t &rhs) const + { return !operator==(rhs); } + + /** returns true if data (w/o padding if with_padding == false and w/ + * padding otherwise) have the same physical structure, i.e. dimensions, + * strides, and blocked structure. Depending on with_data_type flag + * data_type is taken or not taken into account. dim_start allows to check + * similarity for the logical part of data [dim_start .. ndims()]. 
+ * CAUTION: format kind any and undef are not similar to whatever, hence the + * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ + /* TODO: revise */ + bool similar_to(const memory_desc_wrapper &rhs, + bool with_padding = true, bool with_data_type = true, + int dim_start = 0) const; + + /** returns true if one memory can be reordered to another */ + bool consistent_with(const memory_desc_wrapper &rhs) const; + + /** returns true if the memory desc corresponds to the given format tag and + * strides. + * @sa memory_desc_matches_tag */ + bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { + return memory_desc_matches_tag(*md_, tag, strides); + } + + /** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ + template + format_tag_t matches_one_of_tag(Tags ...tags) const { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(*md_, tag)) + return tag; + } + return format_tag::undef; + } + + /* offset section */ + + /** returns physical offset by logical one. logical offset is represented by + * an array \param pos. if \param is_pos_padded is true \param pos + * represents the position in already padded area */ + dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + const blocking_desc_t &blk = blocking_desc(); + + dims_t pos_copy = {0}; + for (int d = 0; d < ndims(); ++d) + pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); + + dim_t phys_offset = offset0(); + + if (blk.inner_nblks > 0) { + dim_t blk_stride = 1; + for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { + const int d = blk.inner_idxs[iblk]; + const dim_t p = pos_copy[d] % blk.inner_blks[iblk]; + + phys_offset += p * blk_stride; + + pos_copy[d] /= blk.inner_blks[iblk]; + + blk_stride *= blk.inner_blks[iblk]; + } + } + + for (int d = 0; d < ndims(); ++d) { + const dim_t p = pos_copy[d]; + phys_offset += p * blk.strides[d]; + } + + return phys_offset; + } + + /** returns physical offset by logical one. logical offset is represented by + * a scalar \param l_offset. if \param is_pos_padded is true, \param + * l_offset represents logical offset in already padded area */ + dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + dims_t pos; + for (int rd = 0; rd < ndims(); ++rd) { + const int d = ndims() - 1 - rd; + const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d]; + pos[d] = l_offset % cur_dim; + l_offset /= cur_dim; + } + return off_v(pos, is_pos_padded); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) */ + template + dim_t off(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, false); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) in already + * padded area */ + template + dim_t off_padding(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, true); + } + + /** returns physical offset by logical one. Logical offset is represented by + * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a + * user responsibility to adjust the result to get offset within blocks */ + template + dim_t blk_off(Args... 
args) const { + return _blk_off(args...); + } + + template + dim_t blk_off(T xn, Args... args) const { + return skip_first + ? blk_off(args...) + : blk_off(xn, args...); + } + + /* static functions section */ + /* TODO: replace with non-static, once md_ becomes non-const ref */ + + static status_t compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag); + +private: + /* TODO: put logical_offset in utils */ + template + dim_t logical_offset(T x0) const { return x0; } + + template + dim_t logical_offset(T xn, Args... args) const { + const size_t n_args = sizeof...(args); + return xn * utils::array_product( + &dims()[ndims() - n_args]) + logical_offset(args...); + } + + template + dim_t _blk_off() const { return offset0(); } + + template + dim_t _blk_off(T xc, Args ...args) const { + assert(is_blocking_desc()); + constexpr int dc = ORIG_LEN - sizeof...(args) - 1; + return xc * blocking_desc().strides[dc] + + _blk_off(args...); + } +}; + +inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, + bool with_padding, bool with_data_type, int dim_start) const { + using namespace utils; + + if (one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + if (is_wino_desc() || is_rnn_packed_desc()) + return false; + + const int ds = dim_start; + const auto &blk = blocking_desc(); + const auto &r_blk = rhs.blocking_desc(); + + return ndims() == rhs.ndims() + && dim_start <= ndims() /* guard */ + && format_kind() == rhs.format_kind() + && IMPLICATION(with_data_type, data_type() == rhs.data_type()) + && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) + && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) + && blk.inner_nblks == r_blk.inner_nblks + && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) + && IMPLICATION(with_padding, true + && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds, + ndims() - ds) + && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds, + ndims() - ds)); +} + +inline bool memory_desc_wrapper::consistent_with( + const memory_desc_wrapper &rhs) const { + if (ndims() == rhs.ndims()) { + for (int d = 0; d < ndims(); ++d) { + if (dims()[d] != rhs.dims()[d]) return false; + } + return true; + } else { + /* TODO: revise. + * is the following possible? + * [1, a, b] <--reorder--> [a, b] + * [a, 1, b] <--reorder--> [a, b] + * not, at least for now */ + return false; + } +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp new file mode 100644 index 000000000000..ec077b308c65 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp @@ -0,0 +1,295 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef MEMORY_TRACKING_HPP +#define MEMORY_TRACKING_HPP + +#include +#include + +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace memory_tracking { + +/* Memory tracking capabilities + * + * The main purpose of this header file is to provide uniform way to register + * required memory for a scratchpad at a primitive descriptor creation time + * and then easily access it having only the base address of the scratchpad. + * + * Primitives might contain multiple disjoint parts that require temporary + * buffers (known as scratchpad) during their execution. A primitive descriptor + * should summarize all the needs into one single number -- the buffer size + * that would be requested from a user. At execution time, the corresponding + * primitive will receive a base pointer to a scratchpad. It then needs to + * provide each part of algorithm the corresponding piece of memory. Three main + * challenges here are: + * 1. Track correct offset (from the base scratchpad address) for each piece + * 2. Algorithm might require that different memory pieces to be aligned, so + * the scratchpad size is no more just a sum of size of the corresponding + * subparts. + * 3. While a primitive is responsible for its scratchpad, the implementation + * might use some other basic blocks (e.g. cpu_reducer) that also require + * scratchpad memory. So there should be a simple way of passing the + * information back and force between the main algorithm (a primitive) and + * auxiliary stuff that lives completely separately from it (e.g. reducer). + * + * To address these challenges this header file provides 3 structures: + * 1. registry_t -- the class the stores the information about requested + * memory. The information includes required size and desired + * alignment for each piece. This class is also responsible + * for computing the right offset to a given piece using the + * base pointer. + * This class is basically a ledger with all entries. + * Lives in primitive descriptors. + * + * 2. registrar_t -- the interface to a registry_t to book memory. Used at + * primitive descriptor creation time only. Contains a + * reference to the corresponding *mutable* registry. + * Always modifiable. + * Allows chaining (using prefixes). + * + * 3. grantor_t -- the interface to a registry_t to access memory. Used at + * primitive execution time only. Contains a reference to + * the corresponding *constant* registry and base pointer. + * Always constant. + * Allows chaining (using prefixes). + * + * Both registrar_t and grantor_t allow chaining with extra prefix provided. + * The feature is useful when a primitive offload a part of computations to + * some other primitives which require their own scratchpad space + * (e.g. reducer). Prefixes are used to avoid key collision in cases when + * multiple sub-primitive (e.g. multiple reducers) are used. + * + * A short example below demonstrates how to use aforementioned classes. In it + * the main primitive is convolution that uses scratchpad for keeping padded + * bias. It also needs a reducer, that needs its own space as well. + * + * ``` c++ + * struct reducer_t { + * static void init(registrar_t &scratchpad) { + * // preserve space for the reduction (one page aligned) + * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096); + * } + * + * void exec(const grantor_t &scratchpad) { + * // get the pointer to preserved space. 
scratchpad came from + * // upper primitive (convolution in this example) + * auto space = scratchpad.get(key_reducer_space); + * + * space[:] += ...; + * } + * }; + * + * struct conv_t { + * struct pd_t { + * void init() { + * registrar_t scratchpad(scratchpad_registry_); + * + * // preserve a space for padded bias (using default alignment) + * scratchpad.book(key_conv_padded_bias, 128); + * + * // create a proxy registrar for the reducer All entries made + * // by reducer would live in convolution's registry, but would + * // have their own `prefix`, so no interference with conv's + * // buffers. + * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); + * + * reducer_t::init(reducer_scratchpad); + * } + * + * registry_t scratchpad_registry_; + * } + * + * void exec() { + * // get the base pointer to a scratchpad memory from a user + * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD); + * + * // create a grantor to the scratchpad (and provide the base + * // pointer). + * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); + * + * // access the padded_bias (need only key name and the grantor) + * auto padded_bias = scratchpad.get(key_conv_padded_bias); + * + * // to give the `right` grantor to reducer we need to add the + * // corresponding prefix, so that reducer would be able to access + * // its keys. The call is very similar to the one in pd_t::init + * // with only difference in types: grantor_t vs registrar_t. + * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); + * reducer->exec(reducer_scratchpad); + * } + * }; + * ``` + */ + + +/* namespace with common keys and prefixes */ +namespace names { +enum { + key_none = 0, + key_bnorm_tmp_mean, + key_bnorm_tmp_var, + key_bnorm_tmp_diff_ss, + key_bnorm_tmp_stats, + key_bnorm_reduction, + key_concat_iptrs, + key_concat_istrides, + key_concat_nelems, + key_concat_optrs, + key_conv_adjusted_scales, + key_conv_bia_reduction, + key_conv_gemm_col, + key_conv_gemm_imtr, + key_conv_int_dat_in_acc_dt, + key_conv_padded_bias, + key_conv_rtus_space, + key_conv_tr_diff_dst, + key_conv_tr_diff_dst_bctx, + key_conv_tr_src, + key_conv_tr_src_bctx, + key_conv_wei_reduction, + key_conv_wei_bia_reduction, + key_conv_wei_bia_reduction_bctx, + key_iprod_int_dat_in_acc_dt, + key_reducer_space, + key_reducer_space_bctx, + key_reorder_wino_plain, + key_reorder_wino_transform_space, + key_reorder_rnn_weights_quantization, + key_reorder_rnn_weights_reduction, + key_rnn_space, + key_rnn_ptrs_bia, + key_rnn_ptrs_wei_layer, + key_rnn_ptrs_wei_iter, + key_softmax_reduction, + key_wino_U, + key_wino_V, + key_wino_M, + key_barrier, +}; + +enum { + prefix_none = 0, + prefix_reducer_bia, + prefix_reducer_wei, +}; +} + +// level 0: 00 00 00 xxx +// level 1: 00 00 aa xxx +// level 2: 00 aa bb xxx +// level 3: aa bb cc xxx +// max # of levels: 3 + 1 (base_level) +// here: +// xxx : [1 .. MAX_KEY) : key +// aa, bb, cc : [1 .. 
MAX_PREFIX) : prefixes for levels 1, 2, and 3 + +using key_t = uint32_t; +enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), }; + +/// generates global key based on a prefix and a local key +inline key_t make_key(key_t prefix, key_t key) { return prefix + key; } + +/// generates global prefix based on the global parent and the local ones +inline key_t make_prefix(key_t parent_prefix, key_t prefix) +{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; } + +struct registrar_t; +struct grantor_t; + +struct registry_t { + void book(const key_t &key, size_t size, size_t alignment) { + if (size == 0) return; + assert(offset_map_.count(key) == 0); + + size = utils::rnd_up(size, minimal_alignment); + alignment = nstl::max(alignment, minimal_alignment); + offset_map_[key] = entry_t{size_, size, alignment}; + + size_ += size + alignment - minimal_alignment; + } + + void *get(const key_t &key, void *base_ptr) const { + if (base_ptr == nullptr) { assert(size() == 0); return nullptr; } + if (offset_map_.count(key) != 1) return nullptr; + + const auto &e = offset_map_.at(key); + base_ptr = utils::align_ptr(base_ptr, minimal_alignment); + char *ptr = (char *)base_ptr + e.offset; + return utils::align_ptr(ptr, e.alignment); + } + + size_t size() const + { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; } + + registrar_t registrar(); + grantor_t grantor(void *base_ptr) const; + +protected: + enum { minimal_alignment = 64 }; + struct entry_t { size_t offset, size, alignment; }; + + std::unordered_map offset_map_; + size_t size_ = 0; +}; + +struct registrar_t { + enum { default_alignment = 64 }; + + registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {} + registrar_t(registrar_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) {} + + void book(const key_t &key, size_t size, + size_t alignment = default_alignment) + { registry_.book(make_key(prefix_, key), size, alignment); } + +protected: + registry_t ®istry_; + const key_t prefix_; +}; + +struct grantor_t { + grantor_t(const registry_t ®istry, void *base_ptr) + : registry_(registry), prefix_(0), base_ptr_(base_ptr) {} + grantor_t(const grantor_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) + , base_ptr_(parent.base_ptr_) {} + + template T *get(const key_t &key) const + { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); } + +protected: + const registry_t ®istry_; + const key_t prefix_; + void *base_ptr_; +}; + +inline registrar_t registry_t::registrar() { return registrar_t(*this); } +inline grantor_t registry_t::grantor(void *base_ptr) const +{ return grantor_t(*this, base_ptr); } + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp new file mode 100644 index 000000000000..2ef4a8fddca9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#define DPRINT(...) do { \ + int l = snprintf(str + written_len, str_len, __VA_ARGS__); \ + if (l < 0) return l; \ + if ((size_t)l >= str_len) return -1; \ + written_len += l; str_len -= l; \ +} while(0) + +int mkldnn_md2fmt_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1u) + return -1; + + int written_len = 0; + + if (mdesc == nullptr) { + DPRINT("%s::%s::", + mkldnn_dt2str(data_type::undef), + mkldnn_fmt_kind2str(format_kind::undef)); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + DPRINT("%s:", mkldnn_dt2str(md.data_type())); + + bool padded_dims = false, padded_offsets = false; + for (int d = 0; d < md.ndims(); ++d) { + if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true; + if (md.padded_offsets()[d] != 0) padded_offsets = true; + } + bool offset0 = md.offset0(); + DPRINT("%s%s%s:", + padded_dims ? "p" : "", + padded_offsets ? "o" : "", + offset0 ? "0" : ""); + + DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind())); + + if (!md.is_blocking_desc()) { + /* TODO: extend */ + DPRINT("%s:", ""); + } else { + const auto &blk = md.blocking_desc(); + + dims_t blocks; + md.compute_blocks(blocks); + + char dim_chars[MKLDNN_MAX_NDIMS + 1]; + + bool plain = true; + for (int d = 0; d < md.ndims(); ++d) { + dim_chars[d] = (blocks[d] == 1 ? 
'a' : 'A') + (char)d; + if (blocks[d] != 1) plain = false; + } + + dims_t strides; + utils::array_copy(strides, blk.strides, md.ndims()); + utils::simultaneous_sort(strides, dim_chars, md.ndims(), + [](dim_t a, dim_t b) { return b - a; }); + + dim_chars[md.ndims()] = '\0'; + DPRINT("%s", dim_chars); + + if (!plain) { + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + DPRINT("%d%c", (int)blk.inner_blks[iblk], + 'a' + (char)blk.inner_idxs[iblk]); + } + } + + DPRINT("%s", ":"); + } + + DPRINT("f%lx", (long)md.extra().flags); + + return written_len; +} + +int mkldnn_md2dim_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1) + return -1; + + int written_len = 0; + + if (mdesc == nullptr || mdesc->ndims == 0) { + DPRINT("%s", ""); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + for (int d = 0; d < md.ndims() - 1; ++d) + DPRINT("%" PRId64 "x", md.dims()[d]); + DPRINT("%" PRId64, md.dims()[md.ndims() - 1]); + + return written_len; +} + +#undef DPRINT diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp new file mode 100644 index 000000000000..16a8f7ea5e66 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp @@ -0,0 +1,365 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#include + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +const char *mkldnn_status2str(mkldnn_status_t v) { + if (v == mkldnn_success) return "success"; + if (v == mkldnn_out_of_memory) return "out_of_memory"; + if (v == mkldnn_try_again) return "try_again"; + if (v == mkldnn_invalid_arguments) return "invalid_arguments"; + if (v == mkldnn_not_ready) return "not_ready"; + if (v == mkldnn_unimplemented) return "unimplemented"; + if (v == mkldnn_iterator_ends) return "iterator_ends"; + if (v == mkldnn_runtime_error) return "runtime_error"; + if (v == mkldnn_not_required) return "not_required"; + assert(!"unknown status"); + return "unknown status"; +} + +const char *mkldnn_dt2str(mkldnn_data_type_t v) { + if (v == mkldnn_data_type_undef) return "undef"; + if (v == mkldnn_f32) return "f32"; + if (v == mkldnn_s32) return "s32"; + if (v == mkldnn_s8) return "s8"; + if (v == mkldnn_u8) return "u8"; + assert(!"unknown dt"); + return "unknown dt"; +} + +const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) { + if (v == mkldnn_format_kind_undef) return "undef"; + if (v == mkldnn_format_kind_any) return "any"; + if (v == mkldnn_blocked) return "blocked"; + if (v == mkldnn_format_kind_wino) return "wino"; + if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed"; + assert(!"unknown fmt_kind"); + return "unknown fmt_kind"; +} + +const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) { + if (v == mkldnn_format_tag_undef) return "undef"; + if (v == mkldnn_format_tag_any) return "format_tag_any"; + if (v == mkldnn_a) return "a"; + if (v == mkldnn_ab) return "ab"; + if (v == mkldnn_abc) return "abc"; + if (v == mkldnn_abcd) return "abcd"; + if (v == mkldnn_abcde) return "abcde"; + if (v == mkldnn_abcdef) return "abcdef"; + if (v == mkldnn_abdec) return "abdec"; + if (v == mkldnn_acb) return "acb"; + if (v == mkldnn_acbde) return "acbde"; + if (v == mkldnn_acdb) return "acdb"; + if (v == mkldnn_acdeb) return "acdeb"; + if (v == mkldnn_ba) return "ba"; + if (v == mkldnn_bac) return "bac"; + if (v == mkldnn_bacd) return "bacd"; + if (v == mkldnn_bcda) return "bcda"; + if (v == mkldnn_cba) return "cba"; + if (v == mkldnn_cdba) return "cdba"; + if (v == mkldnn_cdeba) return "cdeba"; + if (v == mkldnn_decab) return "decab"; + if (v == mkldnn_Abc16a) return "Abc16a"; + if (v == mkldnn_ABc16a16b) return "ABc16a16b"; + if (v == mkldnn_aBc16b) return "aBc16b"; + if (v == mkldnn_ABc16b16a) return "ABc16b16a"; + if (v == mkldnn_Abc4a) return "Abc4a"; + if (v == mkldnn_aBc4b) return "aBc4b"; + if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b"; + if (v == mkldnn_ABc4b4a) return "ABc4b4a"; + if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a"; + if (v == mkldnn_ABc8a8b) return "ABc8a8b"; + if (v == mkldnn_aBc8b) return "aBc8b"; + if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b"; + if (v == mkldnn_ABc8b8a) return "ABc8b8a"; + if (v == mkldnn_Abcd16a) return "Abcd16a"; + if (v == mkldnn_ABcd16a16b) return "ABcd16a16b"; + if (v == mkldnn_aBcd16b) return "aBcd16b"; + if (v == mkldnn_ABcd16b16a) return "ABcd16b16a"; + if (v == mkldnn_aBCd16b16c) return "aBCd16b16c"; + if (v == mkldnn_aBCd16c16b) return "aBCd16c16b"; + if (v == mkldnn_Abcd4a) return "Abcd4a"; + if (v == mkldnn_aBcd4b) return "aBcd4b"; + if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b"; + if (v == mkldnn_ABcd4b4a) return "ABcd4b4a"; + if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c"; + if (v == mkldnn_aBCd4c4b) return "aBCd4c4b"; 
+ if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a"; + if (v == mkldnn_ABcd8a8b) return "ABcd8a8b"; + if (v == mkldnn_aBcd8b) return "aBcd8b"; + if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b"; + if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b"; + if (v == mkldnn_ABcd8b8a) return "ABcd8b8a"; + if (v == mkldnn_aBCd8b8c) return "aBCd8b8c"; + if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c"; + if (v == mkldnn_aBCd8c8b) return "aBCd8c8b"; + if (v == mkldnn_Abcde16a) return "Abcde16a"; + if (v == mkldnn_ABcde16a16b) return "ABcde16a16b"; + if (v == mkldnn_aBcde16b) return "aBcde16b"; + if (v == mkldnn_ABcde16b16a) return "ABcde16b16a"; + if (v == mkldnn_aBCde16b16c) return "aBCde16b16c"; + if (v == mkldnn_aBCde16c16b) return "aBCde16c16b"; + if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c"; + if (v == mkldnn_Abcde4a) return "Abcde4a"; + if (v == mkldnn_aBcde4b) return "aBcde4b"; + if (v == mkldnn_ABcde4b4a) return "ABcde4b4a"; + if (v == mkldnn_aBCde4b4c) return "aBCde4b4c"; + if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c"; + if (v == mkldnn_aBCde4c4b) return "aBCde4c4b"; + if (v == mkldnn_Abcde8a) return "Abcde8a"; + if (v == mkldnn_ABcde8a8b) return "ABcde8a8b"; + if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b"; + if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b"; + if (v == mkldnn_ABcde8b8a) return "ABcde8b8a"; + if (v == mkldnn_aBCde8b8c) return "aBCde8b8c"; + if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c"; + if (v == mkldnn_aBCde8c8b) return "aBCde8c8b"; + if (v == mkldnn_aBcdef16b) return "aBcdef16b"; + if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c"; + if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b"; + if (v == mkldnn_aBcdef4b) return "aBcdef4b"; + if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b"; + if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c"; + if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c"; + if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b"; + if (v == mkldnn_aBdc16b) return "aBdc16b"; + if (v == mkldnn_aBdc4b) return "aBdc4b"; + if (v == mkldnn_aBdc8b) return "aBdc8b"; + if (v == mkldnn_aBdec16b) return "aBdec16b"; + if (v == mkldnn_aBdec4b) return "aBdec4b"; + if (v == mkldnn_aBdec8b) return "aBdec8b"; + if (v == mkldnn_aBdefc16b) return "aBdefc16b"; + if (v == mkldnn_aBdefc4b) return "aBdefc4b"; + if (v == mkldnn_aBdefc8b) return "aBdefc8b"; + if (v == mkldnn_Acb16a) return "Acb16a"; + if (v == mkldnn_Acb4a) return "Acb4a"; + if (v == mkldnn_Acb8a) return "Acb8a"; + if (v == mkldnn_aCBd16b16c) return "aCBd16b16c"; + if (v == mkldnn_aCBde16b16c) return "aCBde16b16c"; + if (v == mkldnn_Acdb16a) return "Acdb16a"; + if (v == mkldnn_Acdb4a) return "Acdb4a"; + if (v == mkldnn_Acdb8a) return "Acdb8a"; + if (v == mkldnn_Acdeb16a) return "Acdeb16a"; + if (v == mkldnn_Acdeb4a) return "Acdeb4a"; + if (v == mkldnn_Acdeb8a) return "Acdeb8a"; + if (v == mkldnn_BAc16a16b) return "BAc16a16b"; + if (v == mkldnn_BAcd16a16b) return "BAcd16a16b"; + if (v == mkldnn_format_tag_last) return "format_tag_last"; + if (v == mkldnn_x) return "x"; + if (v == mkldnn_nc) return "nc"; + if (v == mkldnn_cn) return "cn"; + if (v == mkldnn_ncw) return "ncw"; + if (v == mkldnn_nwc) return "nwc"; + if (v == mkldnn_nchw) return "nchw"; + if (v == mkldnn_nhwc) return "nhwc"; + if (v == mkldnn_chwn) return "chwn"; + if (v == mkldnn_ncdhw) return "ncdhw"; + if (v == mkldnn_ndhwc) return "ndhwc"; + if (v == mkldnn_oi) return "oi"; + if (v == mkldnn_io) return "io"; + if (v == mkldnn_oiw) return "oiw"; + if (v == mkldnn_wio) return "wio"; + if (v == mkldnn_oihw) return "oihw"; + if (v == 
mkldnn_hwio) return "hwio"; + if (v == mkldnn_ihwo) return "ihwo"; + if (v == mkldnn_iohw) return "iohw"; + if (v == mkldnn_oidhw) return "oidhw"; + if (v == mkldnn_dhwio) return "dhwio"; + if (v == mkldnn_goiw) return "goiw"; + if (v == mkldnn_goihw) return "goihw"; + if (v == mkldnn_hwigo) return "hwigo"; + if (v == mkldnn_giohw) return "giohw"; + if (v == mkldnn_goidhw) return "goidhw"; + if (v == mkldnn_tnc) return "tnc"; + if (v == mkldnn_ntc) return "ntc"; + if (v == mkldnn_ldsnc) return "ldsnc"; + if (v == mkldnn_ldigo) return "ldigo"; + if (v == mkldnn_ldgoi) return "ldgoi"; + if (v == mkldnn_ldgo) return "ldgo"; + if (v == mkldnn_nCdhw16c) return "nCdhw16c"; + if (v == mkldnn_nCdhw4c) return "nCdhw4c"; + if (v == mkldnn_nCdhw8c) return "nCdhw8c"; + if (v == mkldnn_nChw16c) return "nChw16c"; + if (v == mkldnn_nChw4c) return "nChw4c"; + if (v == mkldnn_nChw8c) return "nChw8c"; + if (v == mkldnn_nCw16c) return "nCw16c"; + if (v == mkldnn_nCw4c) return "nCw4c"; + if (v == mkldnn_nCw8c) return "nCw8c"; + if (v == mkldnn_IOw16o16i) return "IOw16o16i"; + if (v == mkldnn_OIw16i16o) return "OIw16i16o"; + if (v == mkldnn_OIw16o16i) return "OIw16o16i"; + if (v == mkldnn_Oiw16o) return "Oiw16o"; + if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i"; + if (v == mkldnn_OIw4i4o) return "OIw4i4o"; + if (v == mkldnn_Oiw4o) return "Oiw4o"; + if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i"; + if (v == mkldnn_OIw8i8o) return "OIw8i8o"; + if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o"; + if (v == mkldnn_OIw8o8i) return "OIw8o8i"; + if (v == mkldnn_Owi16o) return "Owi16o"; + if (v == mkldnn_Owi4o) return "Owi4o"; + if (v == mkldnn_Owi8o) return "Owi8o"; + if (v == mkldnn_IOhw16o16i) return "IOhw16o16i"; + if (v == mkldnn_Ohwi16o) return "Ohwi16o"; + if (v == mkldnn_Ohwi4o) return "Ohwi4o"; + if (v == mkldnn_Ohwi8o) return "Ohwi8o"; + if (v == mkldnn_OIhw16i16o) return "OIhw16i16o"; + if (v == mkldnn_OIhw16o16i) return "OIhw16o16i"; + if (v == mkldnn_Oihw16o) return "Oihw16o"; + if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i"; + if (v == mkldnn_OIhw4i4o) return "OIhw4i4o"; + if (v == mkldnn_Oihw4o) return "Oihw4o"; + if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i"; + if (v == mkldnn_OIhw8i8o) return "OIhw8i8o"; + if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o"; + if (v == mkldnn_OIhw8o8i) return "OIhw8o8i"; + if (v == mkldnn_Odhwi16o) return "Odhwi16o"; + if (v == mkldnn_Odhwi4o) return "Odhwi4o"; + if (v == mkldnn_Odhwi8o) return "Odhwi8o"; + if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o"; + if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i"; + if (v == mkldnn_Oidhw16o) return "Oidhw16o"; + if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o"; + if (v == mkldnn_Oidhw4o) return "Oidhw4o"; + if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i"; + if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o"; + if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i"; + if (v == mkldnn_Goiw16g) return "Goiw16g"; + if (v == mkldnn_gIOw16o16i) return "gIOw16o16i"; + if (v == mkldnn_gOIw16i16o) return "gOIw16i16o"; + if (v == mkldnn_gOIw16o16i) return "gOIw16o16i"; + if (v == mkldnn_gOiw16o) return "gOiw16o"; + if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i"; + if (v == mkldnn_gOIw4i4o) return "gOIw4i4o"; + if (v == mkldnn_gOiw4o) return "gOiw4o"; + if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i"; + if (v == mkldnn_gOIw8i8o) return "gOIw8i8o"; + if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o"; + if (v == mkldnn_gOIw8o8i) return "gOIw8o8i"; + if (v == mkldnn_gOwi16o) return "gOwi16o"; + if (v == mkldnn_gOwi4o) return "gOwi4o"; + if 
(v == mkldnn_gOwi8o) return "gOwi8o"; + if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i"; + if (v == mkldnn_gOhwi16o) return "gOhwi16o"; + if (v == mkldnn_gOhwi4o) return "gOhwi4o"; + if (v == mkldnn_gOhwi8o) return "gOhwi8o"; + if (v == mkldnn_Goihw16g) return "Goihw16g"; + if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o"; + if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i"; + if (v == mkldnn_gOihw16o) return "gOihw16o"; + if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i"; + if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i"; + if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o"; + if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i"; + if (v == mkldnn_gOihw4o) return "gOihw4o"; + if (v == mkldnn_Goihw8g) return "Goihw8g"; + if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i"; + if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o"; + if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o"; + if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i"; + if (v == mkldnn_gOdhwi16o) return "gOdhwi16o"; + if (v == mkldnn_gOdhwi4o) return "gOdhwi4o"; + if (v == mkldnn_gOdhwi8o) return "gOdhwi8o"; + if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o"; + if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i"; + if (v == mkldnn_gOidhw16o) return "gOidhw16o"; + if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o"; + if (v == mkldnn_gOidhw4o) return "gOidhw4o"; + if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i"; + if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o"; + if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i"; + assert(!"unknown fmt_tag"); + return "unknown fmt_tag"; +} + +const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) { + if (v == mkldnn_prop_kind_undef) return "undef"; + if (v == mkldnn_forward_training) return "forward_training"; + if (v == mkldnn_forward_inference) return "forward_inference"; + if (v == mkldnn_forward_scoring) return "forward_scoring"; + if (v == mkldnn_forward) return "forward"; + if (v == mkldnn_backward) return "backward"; + if (v == mkldnn_backward_data) return "backward_data"; + if (v == mkldnn_backward_weights) return "backward_weights"; + if (v == mkldnn_backward_bias) return "backward_bias"; + assert(!"unknown prop_kind"); + return "unknown prop_kind"; +} + +const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { + if (v == mkldnn_undefined_primitive) return "undef"; + if (v == mkldnn_reorder) return "reorder"; + if (v == mkldnn_shuffle) return "shuffle"; + if (v == mkldnn_concat) return "concat"; + if (v == mkldnn_sum) return "sum"; + if (v == mkldnn_convolution) return "convolution"; + if (v == mkldnn_deconvolution) return "deconvolution"; + if (v == mkldnn_eltwise) return "eltwise"; + if (v == mkldnn_softmax) return "softmax"; + if (v == mkldnn_pooling) return "pooling"; + if (v == mkldnn_lrn) return "lrn"; + if (v == mkldnn_batch_normalization) return "batch_normalization"; + if (v == mkldnn_inner_product) return "inner_product"; + if (v == mkldnn_rnn) return "rnn"; + assert(!"unknown prim_kind"); + return "unknown prim_kind"; +} + +const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { + if (v == mkldnn_alg_kind_undef) return "undef"; + if (v == mkldnn_convolution_direct) return "convolution_direct"; + if (v == mkldnn_convolution_winograd) return "convolution_winograd"; + if (v == mkldnn_convolution_auto) return "convolution_auto"; + if (v == mkldnn_deconvolution_direct) return "deconvolution_direct"; + if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd"; + if (v == mkldnn_eltwise_relu) return "eltwise_relu"; + if (v == mkldnn_eltwise_tanh) return "eltwise_tanh"; + 
if (v == mkldnn_eltwise_elu) return "eltwise_elu"; + if (v == mkldnn_eltwise_square) return "eltwise_square"; + if (v == mkldnn_eltwise_abs) return "eltwise_abs"; + if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt"; + if (v == mkldnn_eltwise_linear) return "eltwise_linear"; + if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu"; + if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu"; + if (v == mkldnn_eltwise_logistic) return "eltwise_logistic"; + if (v == mkldnn_pooling_max) return "pooling_max"; + if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; + if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; + if (v == mkldnn_pooling_avg) return "pooling_avg"; + if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; + if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; + if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; + if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; + if (v == mkldnn_vanilla_gru) return "vanilla_gru"; + if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset"; + assert(!"unknown alg_kind"); + return "unknown alg_kind"; +} + +const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) { + if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right"; + if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left"; + if (v == mkldnn_bidirectional_concat) return "bidirectional_concat"; + if (v == mkldnn_bidirectional_sum) return "bidirectional_sum"; + if (v == mkldnn_unidirectional) return "unidirectional"; + assert(!"unknown rnn_direction"); + return "unknown rnn_direction"; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp new file mode 100644 index 000000000000..7e5789e2c355 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_HPP +#define MKLDNN_THREAD_HPP + +#include "utils.hpp" +#include "z_magic.hpp" + +#define MKLDNN_THR_SEQ 0 +#define MKLDNN_THR_OMP 1 +#define MKLDNN_THR_TBB 2 + +/* Ideally this condition below should never happen (if the library is built + * using regular cmake). For the 3rd-party projects that build the library + * from the sources on their own try to guess the right threading... */ +#if !defined(MKLDNN_THR) +# define MKLDNN_THR MKLDNN_THR_TBB +#endif + +#if MKLDNN_THR == MKLDNN_THR_SEQ +#define MKLDNN_THR_SYNC 1 +inline int mkldnn_get_max_threads() { return 1; } +inline int mkldnn_get_num_threads() { return 1; } +inline int mkldnn_get_thread_num() { return 0; } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() {} + +#define PRAGMA_OMP(...) 
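+/* Illustrative note: each of the three MKLDNN_THR branches in this header
+ * provides the same five entry points (mkldnn_get_max_threads,
+ * mkldnn_get_num_threads, mkldnn_get_thread_num, mkldnn_in_parallel,
+ * mkldnn_thr_barrier), so the rest of the library is written against a
+ * single threading API regardless of the selected backend. With the
+ * sequential stubs above every parallel region collapses to one serial
+ * invocation; a hypothetical caller of parallel() (declared further below
+ * via mkldnn_thread_parallel_nd.hpp) would see:
+ *
+ *     parallel(0, [](int ithr, int nthr) {
+ *         // sequential build: body runs exactly once, ithr == 0, nthr == 1
+ *     });
+ */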
+ +#elif MKLDNN_THR == MKLDNN_THR_OMP +#include +#define MKLDNN_THR_SYNC 1 + +inline int mkldnn_get_max_threads() { return omp_get_max_threads(); } +inline int mkldnn_get_num_threads() { return omp_get_num_threads(); } +inline int mkldnn_get_thread_num() { return omp_get_thread_num(); } +inline int mkldnn_in_parallel() { return omp_in_parallel(); } +inline void mkldnn_thr_barrier() { +# pragma omp barrier +} + +#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) + +#elif MKLDNN_THR == MKLDNN_THR_TBB +#include "tbb/task_arena.h" +#include "tbb/parallel_for.h" +#define MKLDNN_THR_SYNC 0 + +inline int mkldnn_get_max_threads() +{ return tbb::this_task_arena::max_concurrency(); } +inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); } +inline int mkldnn_get_thread_num() +{ return tbb::this_task_arena::current_thread_index(); } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); } + +#define PRAGMA_OMP(...) + +#endif + +/* MSVC still supports omp 2.0 only */ +#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define collapse(x) +# define PRAGMA_OMP_SIMD(...) +#else +# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) +#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER) + +namespace mkldnn { +namespace impl { + +inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; } + +template +inline void balance211(T n, U team, U tid, T &n_start, T &n_end) { + T n_min = 1; + T &n_my = n_end; + if (team <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else if (n_min == 1) { + // team = T1 + T2 + // n = T1*n1 + T2*n2 (n1 - n2 = 1) + T n1 = utils::div_up(n, (T)team); + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_my = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +} // namespace impl +} // namespace mkldnn + +#include "mkldnn_thread_parallel_nd.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp new file mode 100644 index 000000000000..50f9b29622eb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP +#define MKLDNN_THREAD_PARALLEL_ND_HPP + +/* This header must be included by mkldnn_thread.hpp only */ + +/* Functions: + * - parallel(nthr, f) - executes f in parallel using at most + * nthr threads. 
If nthr equals 0 + * mkldnn_get_max_threads() threads is + * used + * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already + * created threads + * - parallel_nd(dims..., f) - creates a parallel section and then + * calls for_nd + * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then + * calls for_nd (mostly for convenience) + */ + +namespace mkldnn { +namespace impl { + +/* general parallelization */ +template +void parallel(int nthr, F f) { + if (nthr == 0) nthr = mkldnn_get_max_threads(); +#if MKLDNN_THR == MKLDNN_THR_SEQ + assert(nthr == 1); + f(0, 1); +#elif MKLDNN_THR == MKLDNN_THR_OMP + if (nthr == 1) { f(0, 1); return; } +# pragma omp parallel num_threads(nthr) + f(mkldnn_get_thread_num(), mkldnn_get_num_threads()); +#elif MKLDNN_THR == MKLDNN_THR_TBB + if (nthr == 1) { f(0, 1); return; } + tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner()); +#endif +} + +/* for_nd section */ + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start{0}, end{0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) f(d0); +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 
d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +// Skip a lambda function in the parameter pack. +template +constexpr size_t get_work_amount(const T &v) { return 1; } +template +constexpr size_t get_work_amount(const T &v, Args &&...args) +{ return (size_t)v * get_work_amount(utils::forward(args)...); } + +/* parallel_nd and parallel_nd_in_omp section */ + +#if MKLDNN_THR != MKLDNN_THR_TBB +template +void parallel_nd(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + const bool do_parallel = get_work_amount(utils::forward(args)...) > 1; +# pragma omp parallel if (do_parallel) + { + const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads(); + const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num(); + for_nd(ithr, nthr, utils::forward(args)...); + } +#endif +} +#else // MKLDNN_THR != MKLDNN_THR_TBB + +// gcc 4.8 has a bug with passing parameter pack to lambdas. +// So have to explicitly instantiate all the cases. + +template +void parallel_nd(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(T0(iwork)); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, 
d2, D2, d3, D3, d4, D4); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } + }, tbb::static_partitioner()); +} +#endif + +template +void parallel_nd_in_omp(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(), + utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_TBB + assert(!"unsupported parallel_nd_in_omp()"); +#endif +} + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp new file mode 100644 index 000000000000..aa671a0b6e1d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef MKLDNN_TRAITS_HPP +#define MKLDNN_TRAITS_HPP + +#include +#include + +#include "mkldnn.h" +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +template struct prec_traits {}; /* ::type -> float */ +template struct data_traits {}; /* ::data_type -> f32 */ +template struct typesize_traits {}; /* ::data_type_size -> f32 */ +template struct pkind_traits {}; /* ::desc_type, ::query_d */ + +template <> struct prec_traits { typedef float type; }; +template <> struct prec_traits { typedef int32_t type; }; +template <> struct prec_traits { typedef int8_t type; }; +template <> struct prec_traits { typedef uint8_t type; }; + +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::f32; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::s32; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::s8; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::u8; }; + +template <> struct typesize_traits<4> { typedef float type; }; +template <> struct typesize_traits<2> { typedef int16_t type; }; +template <> struct typesize_traits<1> { typedef uint8_t type; }; + +#define PKIND_TRAITS_INST(op) \ +template <> struct pkind_traits { \ + typedef CONCAT2(op, _desc_t) desc_type; \ + static constexpr query_t query_d = query::CONCAT2(op, _d); \ +} +PKIND_TRAITS_INST(convolution); +PKIND_TRAITS_INST(deconvolution); +PKIND_TRAITS_INST(shuffle); +PKIND_TRAITS_INST(eltwise); +PKIND_TRAITS_INST(softmax); +PKIND_TRAITS_INST(pooling); +PKIND_TRAITS_INST(lrn); +PKIND_TRAITS_INST(batch_normalization); +PKIND_TRAITS_INST(inner_product); +PKIND_TRAITS_INST(rnn); +#undef PKIND_TRAITS_INST + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp new file mode 100644 index 000000000000..f89ea999e264 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#ifndef NSTL_HPP
+#define NSTL_HPP
+
+#include <stdint.h>
+#include <limits.h>
+#include <float.h>
+
+#include <vector>
+#include <map>
+
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+void *malloc(size_t size, int alignment);
+void free(void *p);
+
+struct c_compatible {
+    enum { default_alignment = 64 };
+    static void *operator new(size_t sz) {
+        return malloc(sz, default_alignment);
+    }
+    static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
+    static void *operator new[](size_t sz) {
+        return malloc(sz, default_alignment);
+    }
+    static void operator delete(void *p) { free(p); }
+    static void operator delete[](void *p) { free(p); }
+};
+
+namespace nstl {
+
+template <typename T>
+inline const T abs(const T& a) {
+    return a >= 0 ? a : -a;
+}
+
+template <typename T>
+inline const T& max(const T& a, const T& b) {
+    return a > b ? a : b;
+}
+
+template <typename T>
+inline const T& min(const T& a, const T& b) {
+    return a < b ? a : b;
+}
+
+template <typename T> void swap(T& t1, T& t2) {
+    T tmp(t1);
+    t1 = t2;
+    t2 = tmp;
+}
+
+// Rationale: MKL-DNN needs a numeric limits implementation that does not
+// generate dependencies on C++ run-time libraries.
+
+template <typename T> struct numeric_limits;
+
+template<> struct numeric_limits<float> {
+    static constexpr float lowest() { return -FLT_MAX; }
+    static constexpr float max() { return FLT_MAX; }
+};
+
+template<> struct numeric_limits<int32_t> {
+    static constexpr int lowest() { return INT32_MIN; }
+    static constexpr int max() { return INT32_MAX; }
+};
+
+template<> struct numeric_limits<int16_t> {
+    static constexpr int16_t lowest() { return INT16_MIN; }
+    static constexpr int16_t max() { return INT16_MAX; }
+};
+
+template<> struct numeric_limits<int8_t> {
+    static constexpr int8_t lowest() { return INT8_MIN; }
+    static constexpr int8_t max() { return INT8_MAX; }
+};
+
+template<> struct numeric_limits<uint8_t> {
+    static constexpr uint8_t lowest() { return 0; }
+    static constexpr uint8_t max() { return UINT8_MAX; }
+};
+
+template <typename T> struct is_integral
+{ static constexpr bool value = false; };
+template<> struct is_integral<int32_t> { static constexpr bool value = true; };
+template<> struct is_integral<int16_t> { static constexpr bool value = true; };
+template<> struct is_integral<int8_t> { static constexpr bool value = true; };
+template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
+
+template <typename T, typename U> struct is_same
+{ static constexpr bool value = false; };
+template <typename T> struct is_same<T, T>
+{ static constexpr bool value = true; };
+
+// Rationale: MKL-DNN needs container implementations that do not generate
+// dependencies on C++ run-time libraries.
+//
+// Implementation philosophy: the caller is responsible for checking that the
+// operation is valid. The only functions that have to return a status are
+// those that depend on memory allocation or similar operations.
+//
+// This means that e.g. an operator [] does not have to check for boundaries.
+// The caller should have checked the boundaries. If it did not we crash and
+// burn: this is a bug in MKL-DNN and throwing an exception would not have been
+// recoverable.
+//
+// On the other hand, insert() or resize() or a similar operation needs to
+// return a status because the outcome depends on factors external to the
+// caller. The situation is probably not recoverable either, but MKL-DNN
+// needs to be nice and report "out of memory" to the users.
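// Illustrative sketch (not part of the upstream patch): following the
// rationale above, allocating operations report failure through a status code
// rather than by throwing. A hypothetical caller of the nstl::vector defined
// below would check that code explicitly:
//
//     nstl::vector<int> dst;
//     dst.reserve(src_count);
//     if (dst.insert(dst.end(), src_begin, src_end) != nstl::success)
//         return status::out_of_memory;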
+ +enum nstl_status_t { + success = 0, + out_of_memory +}; + +template class vector: public c_compatible { +private: + std::vector _impl; +public: + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + typedef typename std::vector::size_type size_type; + vector() {} + vector(size_type n): _impl(n) {} + vector(size_type n, const T &value): _impl(n, value) {} + template + vector(input_iterator first, input_iterator last): _impl(first, last) {} + ~vector() {} + size_type size() const { return _impl.size(); } + T& operator[] (size_type i) { return _impl[i]; } + const T& operator[] (size_type i) const { return _impl[i]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template + nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end) + { + _impl.insert(pos, begin, end); + return success; + } + void clear() { _impl.clear(); } + void push_back(const T& t) { _impl.push_back(t); } + void resize(size_type count) { _impl.resize(count); } + void reserve(size_type count) { _impl.reserve(count); } +}; + +template class map: public c_compatible { +private: + std::map _impl; +public: + typedef typename std::map::iterator iterator; + typedef typename std::map::const_iterator const_iterator; + typedef typename std::map::size_type size_type; + map() {} + ~map() {} + size_type size() const { return _impl.size(); } + T& operator[](const Key &k) { return _impl[k]; } + const T& operator[](const Key &k) const { return _impl[k]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template + void clear() { _impl.clear(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp new file mode 100644 index 000000000000..be96e654ff80 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t pooling_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l) + && one_of(alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto pd = pooling_desc_t(); + pd.primitive_kind = primitive_kind::pooling; + pd.prop_kind = prop_kind; + pd.alg_kind = alg_kind; + pd.src_desc.ndims = src_desc->ndims; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + pd.diff_src_desc = pd.src_desc = zero_md(); + pd.diff_dst_desc = pd.dst_desc = zero_md(); + + (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc; + (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(pd.strides, strides, sp_dims); + utils::array_copy(pd.kernel, kernel, sp_dims); + utils::array_copy(pd.padding[0], padding_l, sp_dims); + utils::array_copy(pd.padding[1], padding_r, sp_dims); + + pd.padding_kind = padding_kind; + if (one_of(alg_kind, pooling_max, pooling_avg_include_padding, + pooling_avg_exclude_padding)) { + pd.accum_data_type = types::default_accum_data_type( + src_desc->data_type, dst_desc->data_type); + } else { + pd.accum_data_type = dst_desc->data_type; + } + + bool consistency = true + && utils::one_of(src_desc->ndims, 4, 5) + && utils::one_of(dst_desc->ndims, 4, 5) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == dst_desc->dims[1]; + for (int i = 2; i < src_desc->ndims; ++i) + consistency = consistency && ( + (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2] + + padding_r[i - 2]) / strides[i - 2] + 1 + == dst_desc->dims[i]); + if (!consistency) return invalid_arguments; + + *pool_desc = pd; + return success; +} +} + +status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc, + dst_desc, strides, kernel, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_src_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t kernel, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind, + diff_src_desc, diff_dst_desc, strides, kernel, padding_l, + 
padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp new file mode 100644 index 000000000000..4c9c009412cc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef POOLING_PD_HPP +#define POOLING_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct pooling_fwd_pd_t; + +struct pooling_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::pooling; + + pooling_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , ws_md_() + {} + + const pooling_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::pooling_d: + *(const pooling_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common pooling aux functions */ + + dim_t MB() const { return src_desc().dims[0]; } + dim_t C() const { return src_desc().dims[1]; } + + dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } + dim_t IW() const { return src_desc().dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } + dim_t OW() const { return dst_desc().dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } + dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } + dim_t KW() const { return desc_.kernel[ndims() - 3]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? 
desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return src_desc().ndims; } + bool is_3d() const { return ndims() == 5; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(src_desc()).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + pooling_desc_t desc_; + const pooling_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t ws_md_; + + void init_default_ws() { + ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); + ws_md_.data_type = indices_data_type(); + } + + data_type_t indices_data_type() const { + /* the simplest way to express 256... */ + const int u8_max = nstl::numeric_limits< + typename prec_traits::type>::max(); + return utils::array_product(desc()->kernel, ndims()) <= u8_max + ? data_type::u8 : data_type::s32; + } + +private: + const memory_desc_t &src_desc() const + { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } + const memory_desc_t &dst_desc() const + { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } +}; + +struct pooling_fwd_pd_t: public pooling_pd_t { + typedef pooling_fwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_fwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } + +protected: + memory_desc_t src_md_; + memory_desc_t dst_md_; + + virtual status_t set_default_params() { + if (dst_md()->format_kind != format_kind::any) + return status::success; + + if (src_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(dst_md_, + src_md_.format_desc.blocking); + } +}; + +struct pooling_bwd_pd_t: public pooling_pd_t { + typedef pooling_bwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_bwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 1 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t diff_dst_md_; + + virtual status_t set_default_params() { + if (diff_src_md()->format_kind != format_kind::any) + return status::success; + + if (diff_dst_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(diff_src_md_, + diff_dst_md_.format_desc.blocking); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp new file mode 100644 index 000000000000..fdf6522f6214 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "primitive.hpp" +#include "type_helpers.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::primitive_kind; + +namespace { +// XXX: this is a huge hammer. This disables all and any msan checks on +// primitives outputs. +// +// A proper approach would be an implementation-specific unpoisoning. +void unpoison_outputs(const exec_args_t &args) { + for(const auto &arg: args) { + if (arg.second.is_const) continue; + auto *mem = arg.second.mem; + void *p; + mem->get_data_handle(&p); + size_t s = memory_desc_wrapper(*mem->md()).size(); + msan_unpoison(p, s); + } +} +} + +status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) { + if (primitive_desc) delete primitive_desc; + return success; +} + +status_t mkldnn_primitive_create(primitive_t **primitive, + const primitive_desc_t *primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return primitive_desc->create_primitive(primitive); +} + +status_t mkldnn_primitive_execute(const primitive_t *primitive, + stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) { + bool ok = true + && !utils::any_null(primitive, stream) + && primitive->engine() == stream->engine() + && IMPLICATION(nargs > 0, c_args != nullptr); + if (!ok) return invalid_arguments; + + exec_args_t args; + status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args); + if (status != status::success) return status; + + exec_ctx_t ctx(stream, std::move(args)); + + if (mkldnn_verbose()->level) { + double ms = get_msec(); + status = primitive->execute(ctx); + ms = get_msec() - ms; + printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms); + fflush(0); + } else { + status = primitive->execute(ctx); + } + + if (msan_enabled) unpoison_outputs(ctx.args()); + + return status; +} + +status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive, + const primitive_desc_t **primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return safe_ptr_assign(*primitive_desc, + primitive->pd()); +} + +status_t mkldnn_primitive_destroy(primitive_t *primitive) { + if (primitive != nullptr) + delete primitive; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp new file mode 100644 index 000000000000..3b506d6d1f62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef PRIMITIVE_HPP +#define PRIMITIVE_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "primitive_exec_types.hpp" + +/** \brief A pure virtual primitive class + * + * Primitive contains links to its inputs & outputs, though it does not track + * their readiness on execution step. + * + * @remark @b Rational. + * Dependencies are essential through-out the whole MKL-DNN library, so it + * makes sense to include them on the very low level. On the other hand, + * tracking them should be a task for corresponding essence, like scheduler, + * stream or whatever. Primitive itself should know nothing about the + * environment it is running in. + * + * @note + * To make user experience better we should provide API which allows + * achieving the best (or good enough) performance when creating primitives + * in natural order: i.e. from bottom to top for forward pass and from top to + * bottom for backward pass. Please consider restriction [1] in Level 0. + */ +struct mkldnn_primitive: public mkldnn::impl::c_compatible { + mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd) + : pd_(pd->clone()) {} + virtual ~mkldnn_primitive() { delete pd_; } + + /** returns primitive's engine */ + mkldnn::impl::engine_t *engine() const { return pd_->engine(); } + /** returns primitive's inputs */ + const mkldnn::impl::primitive_desc_t *pd() const { return pd_; } + /** returns primitive's kind */ + mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); } + + /** executes primitive with execution context @p ctx */ + virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx) + const = 0; + +protected: + const mkldnn::impl::primitive_desc_t *pd_; + +private: + mkldnn_primitive() = delete; + mkldnn_primitive(const mkldnn_primitive &) = delete; + mkldnn_primitive(mkldnn_primitive &&) = delete; + mkldnn_primitive &operator=(const mkldnn_primitive &) = delete; + mkldnn_primitive &operator=(mkldnn_primitive &&) = delete; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp new file mode 100644 index 000000000000..9fd638842cd5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +namespace mkldnn { +namespace impl { + +status_t scales_t::set(dim_t count, int mask, const float *scales) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + scales_ = scales_buf_; + utils::array_set(scales_, scales[0], scales_buf_size); + } else { + scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); + if (scales_ == nullptr) + return status::out_of_memory; + + for (dim_t c = 0; c < count_; ++c) + scales_[c] = scales[c]; + } + + return status::success; +} + +} +} + +status_t post_ops_t::append_sum(float scale) { + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::sum; + entry_[len_].sum.scale = scale; + + len_++; + + return success; +} + +status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, + float beta) { + using namespace mkldnn::impl::alg_kind; + bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); + if (!known_alg) + return invalid_arguments; + + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::eltwise; + entry_[len_].eltwise.scale = scale; + entry_[len_].eltwise.alg = alg; + entry_[len_].eltwise.alpha = alpha; + entry_[len_].eltwise.beta = beta; + + len_++; + + return success; +} + +status_t primitive_attr_t::set_scratchpad_mode( + scratchpad_mode_t scratchpad_mode) { + using namespace mkldnn::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, library, user); + if (!ok) + return invalid_arguments; + + scratchpad_mode_ = scratchpad_mode; + return success; +} + +status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { + this->post_ops_ = post_ops; + return success; +} + +/* Public C API */ + +status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { + if (attr == nullptr) + return invalid_arguments; + + return safe_ptr_assign(*attr, + new mkldnn_primitive_attr); +} + +status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, + const primitive_attr_t *existing_attr) { + if (any_null(attr, existing_attr)) + return invalid_arguments; + + return safe_ptr_assign(*attr, + existing_attr->clone()); +} + +status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { + if (attr) + delete attr; + + return success; +} + +status_t mkldnn_primitive_attr_get_scratchpad_mode( + const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { + if (any_null(attr, scratchpad_mode)) + return invalid_arguments; + + *scratchpad_mode = attr->scratchpad_mode_; + + return success; +} + +status_t mkldnn_primitive_attr_set_scratchpad_mode( + primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { + if (any_null(attr)) + return invalid_arguments; + + return attr->set_scratchpad_mode(scratchpad_mode); +} + +status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, + dim_t *count, int *mask, const float **scales) { + if (any_null(attr, count, mask, scales)) + return invalid_arguments; + + *count = attr->output_scales_.count_; + *mask = attr->output_scales_.mask_; + *scales = attr->output_scales_.scales_; + + return success; +} + +status_t 
mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, + dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_scales_.set(count, mask, scales); +} + +status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, + const post_ops_t **post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + *post_ops = &attr->post_ops_; + return success; +} + +status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, + const post_ops_t *post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + return attr->set_post_ops(*post_ops); +} + +status_t mkldnn_post_ops_create(post_ops_t **post_ops) { + if (post_ops == nullptr) + return invalid_arguments; + + return safe_ptr_assign(*post_ops, new mkldnn_post_ops); +} + +status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { + if (post_ops) + delete post_ops; + + return success; +} + +int mkldnn_post_ops_len(const post_ops_t *post_ops) { + if (post_ops) + return post_ops->len_; + + return 0; +} + +primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, + int index) { + bool ok = post_ops && 0 <= index && index < post_ops->len_; + if (!ok) + return primitive_kind::undefined; + + return post_ops->entry_[index].kind; +} + +status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_sum(scale); +} + +namespace { +bool simple_get_params_check(const post_ops_t *post_ops, int index, + primitive_kind_t kind) { + bool ok = true + && post_ops != nullptr + && 0 <= index + && index < post_ops->len_ + && post_ops->entry_[index].kind == kind; + return ok; +} +} + +status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, + float *scale) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::sum) + && !any_null(scale); + if (!ok) + return invalid_arguments; + + *scale = post_ops->entry_[index].sum.scale; + return success; +} + +status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, + alg_kind_t kind, float alpha, float beta) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_eltwise(scale, kind, alpha, beta); +} + +status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, + int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::eltwise) + && !any_null(scale, alpha, beta); + if (!ok) + return invalid_arguments; + + const auto &e = post_ops->entry_[index].eltwise; + *scale = e.scale; + *alg = e.alg; + *alpha = e.alpha; + *beta = e.beta; + + return success; +} + +status_t mkldnn_primitive_attr_set_rnn_data_qparams( + primitive_attr_t *attr, const float scale, const float shift) { + if (attr == nullptr) + return invalid_arguments; + + return attr->rnn_data_qparams_.set(scale, shift); +} + +status_t mkldnn_primitive_attr_set_rnn_weights_qparams( + primitive_attr_t *attr, dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->rnn_weights_qparams_.set(count, mask, scales); +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp new file mode 100644 index 000000000000..e2130c7ab161 --- /dev/null +++ 
b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_ATTR_HPP +#define PRIMITIVE_ATTR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_data_qparams_t : public c_compatible { + rnn_data_qparams_t() : scale_(1.), shift_(0.) {} + bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } + + status_t set(float scale, float shift) { + scale_ = scale; + shift_ = shift; + return status::success; + } + + float scale_; + float shift_; +}; + +struct scales_t: public c_compatible { + scales_t(): count_(1), mask_(0), scales_(scales_buf_) + { set(1.); } + + scales_t(const scales_t &rhs): scales_t() + { set(rhs.count_, rhs.mask_, rhs.scales_); } + + ~scales_t() { cleanup(); } + + scales_t &operator=(const scales_t &rhs) { + if (&rhs == this) + return *this; + status_t status = set(rhs.count_, rhs.mask_, rhs.scales_); + assert(status == status::success); + (void)status; + return *this; + } + + bool has_default_values() const { + for (dim_t c = 0; c < count_; ++c) { + if(scales_[c] != 1.) 
return false; + } + return true; + } + + status_t set(dim_t count, int mask, const float *scales); + status_t set(float single_scale) { return this->set(1, 0, &single_scale); } + + dim_t count_; + int mask_; + float *scales_; + +private: + enum { scales_buf_size = 16 }; + float scales_buf_[scales_buf_size]; + + void cleanup() { + if (scales_ != scales_buf_ && scales_ != nullptr) + impl::free(scales_); + + count_ = 1; + mask_ = 0; + scales_ = scales_buf_; + } +}; + +} +} + +struct mkldnn_post_ops: public mkldnn::impl::c_compatible { + struct entry_t { + struct eltwise_t { + mkldnn::impl::alg_kind_t alg; + float scale, alpha, beta; + }; + + mkldnn::impl::primitive_kind_t kind; + union { + struct { float scale; } sum; + eltwise_t eltwise; + }; + + bool is_eltwise(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::eltwise + && IMPLICATION(require_scale_one, eltwise.scale == 1.f); + } + + bool is_relu(bool require_scale_one = true, + bool require_nslope_zero = true) const { + using namespace mkldnn::impl; + return is_eltwise(require_scale_one) + && eltwise.alg == alg_kind::eltwise_relu + && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f); + } + + bool is_sum(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::sum + && IMPLICATION(require_scale_one, sum.scale == 1.f); + } + }; + + mkldnn_post_ops(): len_(0) {} + + mkldnn::impl::status_t append_sum(float scale); + mkldnn::impl::status_t append_eltwise(float scale, + mkldnn::impl::alg_kind_t alg, float alpha, float beta); + + int find(mkldnn::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len_; + stop = mkldnn::impl::nstl::min(stop, len_); + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) return idx; + return -1; + } + + bool has_default_values() const { return len_ == 0; } + + bool contain(mkldnn::impl::primitive_kind_t kind, int index) const + { return find(kind, index, index + 1) == index; } + + enum { capacity = 4 }; + + int len_; + entry_t entry_[capacity]; +}; + +struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible { + mkldnn_primitive_attr() + : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library) + {} + + mkldnn_primitive_attr *clone() const + { return new mkldnn_primitive_attr(*this); } + + /** Returns true if the attributes have default values. 
+     *
+     * @note The scratchpad_mode_ is not taken into account */
+    bool has_default_values() const {
+        return true
+            && output_scales_.has_default_values()
+            && post_ops_.has_default_values()
+            && rnn_data_qparams_.has_default_values()
+            && rnn_weights_qparams_.has_default_values();
+    }
+
+    mkldnn::impl::status_t set_scratchpad_mode(
+            mkldnn::impl::scratchpad_mode_t scratchpad_mode);
+    mkldnn::impl::status_t set_post_ops(
+            const mkldnn::impl::post_ops_t &post_ops);
+
+    mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
+    mkldnn::impl::scales_t output_scales_;
+    mkldnn::impl::post_ops_t post_ops_;
+    mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+    mkldnn::impl::scales_t rnn_weights_qparams_;
+};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
new file mode 100644
index 000000000000..723c41e05aec
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
@@ -0,0 +1,78 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
+    auto safe_ret_md = [&](const memory_desc_t *_) {
+        if (_ == nullptr) return not_required;
+        *(const memory_desc_t **)result = _;
+        return success;
+    };
+
+    switch (what) {
+    case query::engine: *(engine_t**)result = engine(); break;
+    case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
+
+    case query::scratchpad_engine:
+        *(engine_t**)result = scratchpad_engine(); break;
+
+    case query::memory_consumption_s64:
+        *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
+
+    case query::op_d:
+        if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
+        *(const_c_op_desc_t *)result
+            = static_cast<const_c_op_desc_t>(op_desc()); break;
+
+    case query::src_md: return safe_ret_md(src_md(idx));
+    case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
+    case query::dst_md: return safe_ret_md(dst_md(idx));
+    case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
+    case query::weights_md: return safe_ret_md(weights_md(idx));
+    case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
+    case query::workspace_md:
+        if (idx != 0) return status::invalid_arguments;
+        return safe_ret_md(workspace_md(idx));
+    case query::scratchpad_md:
+        if (idx != 0) return status::invalid_arguments;
+        return safe_ret_md(scratchpad_md(idx));
+
+    case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
+    case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
+
+    case query::impl_info_str: *(const char **)result = name(); break;
+
+    default: return unimplemented;
+    }
+    return success;
+}
+
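// Illustrative sketch (not part of the upstream patch): the switch above is
// the single dispatch point behind the query API. A hypothetical caller
// passes a query kind together with a pointer whose type matches that kind;
// unsupported kinds fall through to `unimplemented`:
//
//     int n_in = 0;
//     if (pd->query(query::num_of_inputs_s32, 0, &n_in) == status::success) {
//         // n_in now holds pd->n_inputs()
//     }
//
//     const memory_desc_t *src_d = nullptr;
//     if (pd->query(query::src_md, 0, &src_d) == status::success && src_d)
//         /* src_d points at the primitive's source memory descriptor */;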
+status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc, + const primitive_attr_t **attr) { + if (utils::any_null(primitive_desc, attr)) + return invalid_arguments; + + *attr = primitive_desc->attr(); + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp new file mode 100644 index 000000000000..536dcfa1d0b2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp @@ -0,0 +1,174 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_DESC_HPP +#define PRIMITIVE_DESC_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "primitive_attr.hpp" +#include "verbose.hpp" + +struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { + using md_t = mkldnn::impl::memory_desc_t; + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), kind_(kind) { info_[0] = '\0'; } + + virtual mkldnn_primitive_desc *clone() const = 0; + virtual ~mkldnn_primitive_desc() {} + + const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } + mkldnn::impl::engine_t *engine() const { return engine_; } + mkldnn::impl::primitive_kind_t kind() const { return kind_; } + + virtual void init_info() {} + const char *info() const { return info_; } + + mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() + { return scratchpad_registry_; } + const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const + { return scratchpad_registry_; } + virtual mkldnn::impl::engine_t *scratchpad_engine() const + { return engine_; } + + virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } + + enum class arg_usage_t { unused, input, output }; + virtual arg_usage_t arg_usage( + mkldnn::impl::primitive_arg_index_t arg) const { + using mkldnn::impl::types::is_zero_md; + if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) + return arg_usage_t::output; + return arg_usage_t::unused; + } + +# define DECLARE_MD_STUB(stub) \ + virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ + { return nullptr; } + + DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); + DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); + DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); + DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); + DECLARE_MD_STUB(workspace_md); +# undef DECLARE_MD_STUB + + const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { 
+ return idx == 0 ? &scratchpad_md_ : nullptr; + } + + virtual void init_scratchpad_md() { + auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); + mkldnn::impl::dims_t dims = { size }; + mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, + mkldnn::impl::data_type::u8, mkldnn_x); + } + + /** returns the scratchpad size for the given scratchpad mode. */ + mkldnn::impl::dim_t scratchpad_size( + mkldnn::impl::scratchpad_mode_t mode) const { + if (mode != attr_.scratchpad_mode_) return 0; + return scratchpad_registry().size(); + } + + virtual int n_inputs() const { return 0; } + virtual int n_outputs() const { return 0; } + + virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, + void *result) const; + + virtual mkldnn::impl::status_t create_primitive( + mkldnn::impl::primitive_t **primitive) const = 0; + + virtual const char *name() const { return "mkldnn_primitive_desc"; } + + /* static magic */ + + template + static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, + const mkldnn::impl::op_desc_t *adesc, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_desc_t *hint_fwd) { + using namespace mkldnn::impl; + using namespace mkldnn::impl::status; + using pd_op_desc_t = typename pkind_traits::desc_type; + if (adesc->kind != pd_t::base_pkind) return invalid_arguments; + assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); + auto hint = + reinterpret_cast(hint_fwd); + auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->init_info(); + _pd->init_scratchpad_md(); + *pd = _pd; + return success; + } + +protected: + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_attr_t attr_; + mkldnn::impl::primitive_kind_t kind_; + + mkldnn::impl::memory_desc_t scratchpad_md_; + + char info_[MKLDNN_VERBOSE_BUF_LEN]; + + mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; + +protected: + /** compares ws between fwd_pd and this (make sense to use for bwd_pd) + * Expectation: this already set workspace, and this workspace should + * exactly match the one from fwd_pd */ + bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { + using namespace mkldnn::impl; + if (!workspace_md()) return true; // the impl lives fine w/o workspace + return fwd_pd && fwd_pd->workspace_md() + && *fwd_pd->workspace_md() == *workspace_md(); + } +}; + +#define DECLARE_COMMON_PD_t(impl_name, ...) \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual const char *name() const override { return impl_name; } +#define DECLARE_COMMON_PD_T(impl_name, ...) 
\ + DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__) + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp new file mode 100644 index 000000000000..43e5a31ef3f7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "memory.hpp" +#include "primitive.hpp" +#include "primitive_exec_types.hpp" + +namespace mkldnn { +namespace impl { + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args) { + using namespace status; + + if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; + + int n_inputs = 0; + int n_outputs = 0; + + for (int i = 0; i < nargs; ++i) { + primitive_arg_index_t arg = c_args[i].arg; + auto *mem = c_args[i].memory; + + switch (pd->arg_usage(arg)) { + case primitive_desc_t::arg_usage_t::input: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, true}; + n_inputs++; + break; + case primitive_desc_t::arg_usage_t::output: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, false}; + n_outputs++; + break; + case primitive_desc_t::arg_usage_t::unused: + break; + } + } + + bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md()); + + if (n_inputs != pd->n_inputs()) return invalid_arguments; + if (n_outputs != pd->n_outputs() + (scratchpad_required ? 
1 : 0))
+        return invalid_arguments;
+
+    return success;
+}
+
+const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
+    if (args_.count(arg) != 1) return nullptr;
+    const auto ma = args_.at(arg);
+    assert(ma.is_const);
+    void *ptr;
+    status_t status = ma.mem->get_data_handle(&ptr);
+    assert(status == status::success); MAYBE_UNUSED(status);
+    return ptr;
+}
+
+void *exec_ctx_t::output(primitive_arg_index_t arg) const {
+    if (args_.count(arg) != 1) return nullptr;
+    const auto ma = args_.at(arg);
+    assert(!ma.is_const);
+    void *ptr;
+    status_t status = ma.mem->get_data_handle(&ptr);
+    assert(status == status::success); MAYBE_UNUSED(status);
+    return ptr;
+}
+
+const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
+    assert(args_.count(arg) == 1);
+    const auto ma = args_.at(arg);
+    assert(!ma.is_const);
+    return ma.mem;
+}
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
new file mode 100644
index 000000000000..0645891da7a9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_EXEC_TYPES_HPP
+#define PRIMITIVE_EXEC_TYPES_HPP
+
+#include <unordered_map>
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "memory.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct memory_arg_t {
+    memory_t *mem;
+    bool is_const;
+};
+
+using exec_args_t = std::unordered_map<int, memory_arg_t>;
+
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+        const mkldnn_exec_arg_t *c_args, exec_args_t &args);
+
+/** Primitive execution context (helps passing stream, memories, and events). */
+struct exec_ctx_t {
+    exec_ctx_t(const exec_ctx_t &) = default;
+    exec_ctx_t(exec_ctx_t &&) = default;
+
+    exec_ctx_t(stream_t *stream): stream_(stream) {}
+    exec_ctx_t(stream_t *stream, exec_args_t &&args)
+        : stream_(stream)
+        , args_(std::move(args)) {}
+
+    stream_t *stream() const { return stream_; }
+    const exec_args_t &args() const { return args_; }
+
+    /* tentative solution...
TODO: replace with functions return memory_t */ + const void *input(primitive_arg_index_t arg) const; + void *output(primitive_arg_index_t arg) const; + + const memory_t *memory(primitive_arg_index_t arg) const; + +private: + stream_t *stream_; + exec_args_t args_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp new file mode 100644 index 000000000000..5a1cd7d379ac --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_iterator_create( + primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc, + const primitive_attr_t *attr, engine_t *engine, + const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd); + if (it == nullptr) return out_of_memory; + + ++(*it); + if (*it == it->end()) { + delete it; + return unimplemented; + } + + *iterator = it; + return success; +} + +status_t mkldnn_primitive_desc_iterator_next( + primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return invalid_arguments; + ++(*iterator); + return *iterator == iterator->end() ? 
iterator_ends : success; +} + +primitive_desc_t *mkldnn_primitive_desc_iterator_fetch( + const primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return nullptr; + return *(*iterator); +} + +status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc, + const primitive_desc_t *existing_primitive_desc) { + if (utils::any_null(primitive_desc, existing_primitive_desc)) + return invalid_arguments; + return safe_ptr_assign(*primitive_desc, + existing_primitive_desc->clone()); +} + +status_t mkldnn_primitive_desc_iterator_destroy( + primitive_desc_iterator_t *iterator) { + if (iterator != nullptr) + delete iterator; + return success; +} + +status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc, + const_c_op_desc_t c_op_desc, const primitive_attr_t *attr, + engine_t *engine, const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd); + ++it; + if (it == it.end()) return unimplemented; + + return safe_ptr_assign(*primitive_desc, *it); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp new file mode 100644 index 000000000000..4e88ab3aa5d9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#ifndef PRIMITIVE_ITERATOR_HPP +#define PRIMITIVE_ITERATOR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { + using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, + const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) + : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) + , attr_(attr ? 
*attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) + , impl_list_(engine_->get_implementation_list()), last_idx_(0) + { + while (impl_list_[last_idx_] != nullptr) ++last_idx_; + } + ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } + + bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } + bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return !operator==(rhs); } + + mkldnn::impl::primitive_desc_iterator_t end() const + { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } + + mkldnn::impl::primitive_desc_iterator_t &operator++() { + if (pd_) { delete pd_; pd_ = nullptr; } + while (++idx_ != last_idx_) { + auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, + hint_fwd_pd_); + if (s == mkldnn::impl::status::success) break; + } + return *this; + } + + mkldnn::impl::primitive_desc_t *operator*() const { + if (*this == end() || pd_ == nullptr) return nullptr; + return pd_->clone(); + } + +protected: + int idx_; + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_desc_t *pd_; + const mkldnn::impl::op_desc_t *op_desc_; + const mkldnn::impl::primitive_attr_t attr_; + const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; + const pd_create_f *impl_list_; + int last_idx_; + +private: + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) + : idx_(last_idx), engine_(engine), pd_(nullptr) + , op_desc_(nullptr), hint_fwd_pd_(nullptr) + , impl_list_(nullptr), last_idx_(last_idx) {} +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp new file mode 100644 index 000000000000..835cd7358156 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc, + query_t what, int index, void *result) { + if (any_null(primitive_desc, result)) + return invalid_arguments; + + return primitive_desc->query(what, index, result); +} + +const memory_desc_t *mkldnn_primitive_desc_query_md( + const primitive_desc_t *primitive_desc, query_t what, int index) { + const memory_desc_t *res_md = nullptr; + bool args_ok = true + && primitive_desc != nullptr + && (what & query::some_md) == query::some_md + && what != query::some_md + && mkldnn_primitive_desc_query(primitive_desc, + what, index, &res_md) == success; + return args_ok ? 
res_md : nullptr; +} + +int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc, + query_t what, int index) { + int res_s32; + bool args_ok = primitive_desc != nullptr + && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32) + && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32) + == success; + return args_ok ? res_s32 : 0; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp new file mode 100644 index 000000000000..d11f1a0361f1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "reorder_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_reorder_primitive_desc_create( + primitive_desc_t **reorder_pd, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md, + const primitive_attr_t *attr) { + if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md)) + return invalid_arguments; + + auto s_ek = src_engine->kind(); + auto d_ek = dst_engine->kind(); + if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek))) + return invalid_arguments; + + auto r_pd = reinterpret_cast(reorder_pd); + auto s_mdw = memory_desc_wrapper(*src_md); + auto d_mdw = memory_desc_wrapper(*dst_md); + + if (!s_mdw.consistent_with(d_mdw)) + return invalid_arguments; + + auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + for (auto r = e->get_reorder_implementation_list(); *r; ++r) { + if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md) + == success) { + (*r_pd)->init_info(); + (*r_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp new file mode 100644 index 000000000000..963cb0f58a32 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REORDER_PD_HPP +#define REORDER_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct reorder_pd_t: public primitive_desc_t { + reorder_pd_t(engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) + : primitive_desc_t(engine, attr, primitive_kind::reorder) + , src_engine_(src_engine) + , dst_engine_(dst_engine) + , scratchpad_engine_(nullptr) + , src_md_(*src_md) + , dst_md_(*dst_md) + {} + + virtual const op_desc_t *op_desc() const override { return nullptr; } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_FROM) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_TO) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + float alpha() const { return attr()->output_scales_.scales_[0]; } + float beta() const { + const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); + return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; + } + virtual mkldnn::impl::engine_t *scratchpad_engine() const override + { return scratchpad_engine_; } + +protected: + engine_t *src_engine_; + engine_t *dst_engine_; + engine_t *scratchpad_engine_; + + memory_desc_t src_md_; + memory_desc_t dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp new file mode 100644 index 000000000000..36967431a621 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu/gemm/os_blas.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::utils; + +namespace { +memory_desc_t copy_maybe_null(const memory_desc_t *md) { + return md ? *md : zero_md(); +} + +rnn_desc_t zero_rnn_desc() { + auto rd = rnn_desc_t(); + rd.src_layer_desc = zero_md(); + rd.src_iter_desc = zero_md(); + rd.weights_layer_desc = zero_md(); + rd.weights_iter_desc = zero_md(); + rd.bias_desc = zero_md(); + rd.dst_layer_desc = zero_md(); + rd.dst_iter_desc = zero_md(); + rd.diff_src_layer_desc = zero_md(); + rd.diff_src_iter_desc = zero_md(); + rd.diff_weights_layer_desc = zero_md(); + rd.diff_weights_iter_desc = zero_md(); + rd.diff_bias_desc = zero_md(); + rd.diff_dst_layer_desc = zero_md(); + rd.diff_dst_iter_desc = zero_md(); + return rd; +} +} + +/* Public C Api */ + +status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, + unsigned int flags, float alpha, float clipping) { + using namespace mkldnn::impl::alg_kind; + + bool args_ok = true + && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, + gru_linear_before_reset) + && IMPLICATION(cell_kind == vanilla_rnn, + one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); + if (!args_ok) + return invalid_arguments; + + auto rcd = mkldnn_rnn_cell_desc_t(); + + rcd.cell_kind = cell_kind; + rcd.activation_kind = act_f; + rcd.flags = flags; + rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; + rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? 
clipping : 0; + + *rnn_cell_desc = rcd; + + return success; +} + +int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 3; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; + case mkldnn::impl::alg_kind::vanilla_lstm: return 4; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 1; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; + case mkldnn::impl::alg_kind::vanilla_lstm: return 2; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, + prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + using namespace data_type; + data_type_t src_layer_dt = src_layer_desc->data_type; + data_type_t dst_layer_dt = dst_layer_desc->data_type; + data_type_t weights_iter_dt = weights_iter_desc->data_type; + data_type_t weights_layer_dt = weights_layer_desc->data_type; + + bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, + weights_layer_dt) + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + +#if USE_MKL_PACKED_GEMM + bool is_u8u8u8 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == u8) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == u8) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_f32u8f32 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_inference = prop_kind == prop_kind::forward_inference; + bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; + + return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) + ? success + : unimplemented; +#else + return is_f32 ? 
success : unimplemented; +#endif +} + +status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, + rnn_direction_t direction, int L, int D, int T, int N, int S, int G, + int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok; + + // * algorithm specific + args_ok = true + && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, + DIC == SIC); + if (!args_ok) return invalid_arguments; + int extra_bias = + rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; + + // * on num layers + args_ok = true + && L == weights_layer_desc->dims[0] + && L == weights_iter_desc->dims[0] + && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) + && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); + if (!args_ok) return invalid_arguments; + + // * on num directions + args_ok = true + && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) + && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) + && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); + if (!args_ok) return invalid_arguments; + + // * on num iterations + args_ok = true + && T == src_layer_desc->dims[0] + && T == dst_layer_desc->dims[0]; + if (!args_ok) return invalid_arguments; + + // * on mb + args_ok = true + && N == src_layer_desc->dims[1] + && N == dst_layer_desc->dims[1] + && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); + if (!args_ok) return invalid_arguments; + + // * on num gates + args_ok = true + && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) + && G == weights_layer_desc->dims[3] + && G == weights_iter_desc->dims[3] + && IMPLICATION(!is_zero_md(bias_desc), + G + extra_bias == bias_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on num states + args_ok = true + && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) + && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) + && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on slc + args_ok = true + && SLC == weights_layer_desc->dims[2] + && SLC == src_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on sic + args_ok = true + && SIC == weights_iter_desc->dims[2] + && IMPLICATION(!is_zero_md(src_iter_desc), + SIC == src_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * on dlc + int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 
2 : 1; + args_ok = true + && DLC == dlc_multiplier * DIC + && DLC == dst_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on dic + args_ok = true + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), + DIC == dst_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * unrolling/fusion conditions + args_ok = true + && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) + && IMPLICATION(T > 1, SIC == DIC); + if (!args_ok) return invalid_arguments; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok = true && rnn_cell_desc != nullptr + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc); + if (!args_ok) return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc)); + + CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, + src_layer_desc, src_iter_desc, weights_layer_desc, + weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); + + // Create the descriptor + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + + *rnn_desc = rd; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, + const memory_desc_t *diff_src_layer_desc, + const memory_desc_t *diff_src_iter_desc, + const memory_desc_t *diff_weights_layer_desc, + const memory_desc_t *diff_weights_iter_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t 
*diff_dst_layer_desc, + const memory_desc_t *diff_dst_iter_desc) { + bool args_ok = true + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc, diff_src_layer_desc, + diff_weights_layer_desc, diff_weights_iter_desc, + diff_dst_layer_desc); + if (!args_ok) + return invalid_arguments; + + auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { + return is_zero_md(a_md) == is_zero_md(b_md); + }; + + args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) + && xnor_md(dst_iter_desc, diff_dst_iter_desc) + && xnor_md(src_iter_desc, diff_src_iter_desc); + if (!args_ok) + return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc); + if (st != success) return st; + + st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, + diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc); + if (st != success) return st; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); + rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); + rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); + rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); + rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); + rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); + rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); + + *rnn_desc = rd; + + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp new file mode 100644 index 000000000000..1ee2ba11143e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_PD_HPP +#define RNN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_fwd_pd_t; + +struct rnn_pd_t : public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::rnn; + + rnn_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , src_layer_md_(desc_.src_layer_desc) + , src_iter_md_(desc_.src_iter_desc) + , weights_layer_md_(desc_.weights_layer_desc) + , weights_iter_md_(desc_.weights_iter_desc) + , bias_md_(desc_.bias_desc) + , dst_layer_md_(desc_.dst_layer_desc) + , dst_iter_md_(desc_.dst_iter_desc) + , ws_md_() + {} + + const rnn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::rnn_d: *(const rnn_desc_t **)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &src_layer_md_; + if (index == 1 && with_src_iter()) return &src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_layer_md_; + if (index == 1) return &weights_iter_md_; + if (index == 2 && with_bias()) return &bias_md_; + return nullptr; + } + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &dst_layer_md_; + if (index == 1 && with_dst_iter()) return &dst_iter_md_; + return nullptr; + } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ : nullptr; } + + /* common pooling aux functions */ + + bool is_training() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::backward); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + dim_t T() const { return desc_.src_layer_desc.dims[0]; } + dim_t MB() const { return desc_.src_layer_desc.dims[1]; } + + dim_t L() const { return desc_.weights_layer_desc.dims[0]; } + dim_t D() const { return desc_.weights_layer_desc.dims[1]; } + + dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } + + dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } + dim_t G() const { return desc_.weights_layer_desc.dims[3]; } + dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; } + + dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } + + bool with_bias() const + { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); } + + bool with_src_iter() const + { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); } + + bool with_dst_iter() const + { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); } + + mkldnn::impl::alg_kind_t cell_kind() const + { return desc_.cell_desc.cell_kind; } + mkldnn::impl::alg_kind_t activation_kind() const + { return desc_.cell_desc.activation_kind; } + + bool is_lbr() const + { return cell_kind() == mkldnn_gru_linear_before_reset; } + + mkldnn_rnn_direction_t direction() const { return desc_.direction; } + +protected: + rnn_desc_t desc_; + const rnn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t src_layer_md_; + memory_desc_t src_iter_md_; + memory_desc_t weights_layer_md_; + memory_desc_t weights_iter_md_; + memory_desc_t bias_md_; + memory_desc_t dst_layer_md_; + memory_desc_t dst_iter_md_; + + memory_desc_t ws_md_; +}; + +struct rnn_fwd_pd_t: public rnn_pd_t { + typedef rnn_fwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_fwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC_LAYER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter()) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST_LAYER) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter()) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual int n_inputs() const override + { return 3 + with_bias() + with_src_iter(); } + virtual int n_outputs() const override + { return 1 + with_dst_iter() + is_training(); } +}; + +struct rnn_bwd_pd_t : public rnn_pd_t { + typedef rnn_bwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_bwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_layer_md_(desc_.diff_src_layer_desc) + , diff_src_iter_md_(desc_.diff_src_iter_desc) + , diff_weights_layer_md_(desc_.diff_weights_layer_desc) + , diff_weights_iter_md_(desc_.diff_weights_iter_desc) + , diff_bias_md_(desc_.diff_bias_desc) 
+ , diff_dst_layer_md_(desc_.diff_dst_layer_desc) + , diff_dst_iter_md_(desc_.diff_dst_iter_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER, + MKLDNN_ARG_DIFF_DST_LAYER)) + return arg_usage_t::input; + + if (with_src_iter()) { + if (arg == MKLDNN_ARG_SRC_ITER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC_ITER) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (with_bias()) { + if (arg == MKLDNN_ARG_BIAS) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_BIAS) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER) + && with_dst_iter()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_ITER)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override { + if (index == 0) return &diff_src_layer_md_; + if (index == 1 && with_src_iter()) return &diff_src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *diff_weights_md( + int index = 0) const override { + if (index == 0) return &diff_weights_layer_md_; + if (index == 1) return &diff_weights_iter_md_; + if (index == 2 && with_bias()) return &diff_bias_md_; + return nullptr; + } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override { + if (index == 0) return &diff_dst_layer_md_; + if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; + return nullptr; + } + + virtual int n_inputs() const override + { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); } + virtual int n_outputs() const override + { return 3 + with_src_iter() + with_bias(); } + +protected: + memory_desc_t diff_src_layer_md_; + memory_desc_t diff_src_iter_md_; + memory_desc_t diff_weights_layer_md_; + memory_desc_t diff_weights_iter_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_layer_md_; + memory_desc_t diff_dst_iter_md_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp new file mode 100644 index 000000000000..6bc14fc72a67 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "scratchpad.hpp" + +namespace mkldnn { +namespace impl { + +/* Allocating memory buffers on a page boundary to reduce TLB/page misses */ +const size_t page_size = 2097152; + +/* + Implementation of the scratchpad_t interface that is compatible with + a concurrent execution +*/ +struct concurent_scratchpad_t : public scratchpad_t { + concurent_scratchpad_t(size_t size) { + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + + ~concurent_scratchpad_t() { + free(scratchpad_); + } + + virtual char *get() const { + return scratchpad_; + } + +private: + char *scratchpad_; + size_t size_; +}; + +/* + Implementation of the scratchpad_t interface that uses a global + scratchpad +*/ + +struct global_scratchpad_t : public scratchpad_t { + global_scratchpad_t(size_t size) { + if (size > size_) { + if (scratchpad_ != nullptr) free(scratchpad_); + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + reference_count_++; + } + + ~global_scratchpad_t() { + reference_count_--; + if (reference_count_ == 0) { + free(scratchpad_); + scratchpad_ = nullptr; + size_ = 0; + } + } + + virtual char *get() const { + return scratchpad_; + } + +private: + /* + Using thread-local here is unnecessary and even buggy! All threads + actually share the same scratchpad, which is created and queried only + on the main thread. If the scratchpad is queried on some thread other + than the one it was created on (e.g. the application calls the API from + multiple threads), thread-local causes a segfault because the scratchpad + is uninitialized on the current thread. + */ + /*thread_local*/ static char *scratchpad_; + /*thread_local*/ static size_t size_; + /*thread_local*/ static unsigned int reference_count_; +}; + +/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr; +/*thread_local*/ size_t global_scratchpad_t::size_ = 0; +/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0; + + +/* + Scratchpad creation routine +*/ +scratchpad_t *create_scratchpad(size_t size) { +#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC + return new global_scratchpad_t(size); +#else + return new concurent_scratchpad_t(size); +#endif +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp new file mode 100644 index 000000000000..f7a246bc9926 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef COMMON_SCRATCHPAD_HPP +#define COMMON_SCRATCHPAD_HPP + +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct scratchpad_t { + virtual ~scratchpad_t() {} + virtual char *get() const = 0; +}; + +scratchpad_t *create_scratchpad(size_t size); + +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp new file mode 100644 index 000000000000..e32e7352246d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, int axis, dim_t group_size) { + bool args_ok = true + && !any_null(shuffle_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward, backward_data) + && axis >= 0 && axis < data_desc->ndims + && group_size > 0 && group_size <= data_desc->dims[axis]; + if (!args_ok) return invalid_arguments; + + auto sd = shuffle_desc_t(); + sd.primitive_kind = primitive_kind::shuffle; + sd.prop_kind = prop_kind; + sd.data_desc = *data_desc; + sd.axis = axis; + sd.group_size = group_size; + + bool consistency = true + && sd.data_desc.dims[axis] % sd.group_size == 0; + if (!consistency) return invalid_arguments; + + *shuffle_desc = sd; + return success; +} +} + +status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis, + dim_t group_size) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis, + group_size); +} + +status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc, + const memory_desc_t *diff_data_desc, int axis, dim_t group_size) { + return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis, + group_size); +} + +// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp new file mode 100644 index 000000000000..cc5553fe7fe9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed 
under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SHUFFLE_PD_HPP +#define SHUFFLE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct shuffle_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::shuffle; + + typedef shuffle_pd_t base_class; + typedef shuffle_pd_t hint_class; + + shuffle_pd_t(engine_t *engine, + const shuffle_desc_t *adesc, + const primitive_attr_t *attr, + const shuffle_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const shuffle_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::shuffle_d: + *(const shuffle_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (is_fwd()) { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + } else { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + } + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + /* shuffle aux functions */ + + dim_t MB() const { return data_md()->dims[0]; } + dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } + dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? 
data_md()->dims[ndims() - 1] : 1; } + + int ndims() const { return data_md()->ndims; } + + int axis() const { return desc_.axis; } + dim_t group_size() const { return desc_.group_size; } + dim_t axis_size() const { return data_md()->dims[axis()]; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + const memory_desc_t *data_md() const { return &data_md_; } + +protected: + shuffle_desc_t desc_; + const shuffle_pd_t *hint_fwd_pd_; + memory_desc_t data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp new file mode 100644 index 000000000000..82848e3d1f03 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) { + bool args_ok = true + && !any_null(softmax_desc, data_desc) + && 0 <= softmax_axis + && softmax_axis < data_desc->ndims; + if (!args_ok) return invalid_arguments; + + auto sd = softmax_desc_t(); + sd.primitive_kind = primitive_kind::softmax; + sd.prop_kind = prop_kind; + + bool is_bwd = (sd.prop_kind == backward_data); + sd.data_desc = *data_desc; + sd.diff_desc = is_bwd ? 
*diff_desc : zero_md(); + sd.softmax_axis = softmax_axis; + + *softmax_desc = sd; + return success; +} +} + +status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + int softmax_axis) { + if (!one_of(prop_kind, forward_inference, forward_training)) + return invalid_arguments; + return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis); +} + +status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc, + const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, + int softmax_axis) { + return softmax_desc_init(softmax_desc, prop_kind::backward_data, + data_desc, diff_desc, softmax_axis); +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp new file mode 100644 index 000000000000..8a16ce901ca4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SOFTMAX_PD_HPP +#define SOFTMAX_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct softmax_fwd_pd_t; + +struct softmax_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::softmax; + + softmax_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const softmax_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::softmax_d: + *(const softmax_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common softmax aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + softmax_desc_t desc_; + const softmax_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct softmax_fwd_pd_t: public softmax_pd_t { + typedef softmax_fwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_fwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct softmax_bwd_pd_t: public softmax_pd_t { + typedef softmax_bwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_bwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp new file mode 100644 index 000000000000..00af8935c027 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
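The forward and backward softmax primitive descriptors above differ mainly in how they classify runtime arguments and in the extra diff memory descriptor. A minimal sketch of the classification the forward descriptor performs, using hypothetical stand-in enums rather than the real MKLDNN_ARG_* constants:

// Sketch with toy enums; the real code dispatches on MKLDNN_ARG_* values.
enum class toy_arg { src, dst, workspace };
enum class toy_usage { input, output, unused };

static toy_usage toy_softmax_fwd_arg_usage(toy_arg arg, bool has_workspace) {
    switch (arg) {
    case toy_arg::src: return toy_usage::input;   // exactly one input
    case toy_arg::dst: return toy_usage::output;  // exactly one output
    case toy_arg::workspace:                      // optional extra output
        return has_workspace ? toy_usage::output : toy_usage::unused;
    }
    return toy_usage::unused;
}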
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +/* API */ + +status_t mkldnn_stream_create(stream_t **stream, engine_t *engine, + unsigned flags) { + bool args_ok = true + && !utils::any_null(stream, engine) + && flags == stream_flags::default_flags; + if (!args_ok) + return invalid_arguments; + + return safe_ptr_assign(*stream, new stream_t(engine, flags)); +} + +status_t mkldnn_stream_destroy(stream_t *stream) { + delete stream; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp new file mode 100644 index 000000000000..f010e5f6ed8f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef STREAM_HPP +#define STREAM_HPP + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" + +struct mkldnn_stream: public mkldnn::impl::c_compatible { + mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags) + : engine_(engine), flags_(flags) {} + virtual ~mkldnn_stream() {} + + /** returns stream's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + + /** returns stream's kind */ + unsigned flags() const { return flags_; } + +protected: + mkldnn::impl::engine_t *engine_; + unsigned flags_; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp new file mode 100644 index 000000000000..365663c0f86e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
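mkldnn_stream_create above only accepts the default flag set and non-null pointers, then allocates the stream through safe_ptr_assign so an allocation failure surfaces as out_of_memory. A standalone sketch of that control flow, with simplified hypothetical types:

#include <new>

struct toy_engine {};
struct toy_stream { toy_engine *engine; unsigned flags; };
enum toy_status { toy_success, toy_invalid_arguments, toy_out_of_memory };

static toy_status toy_stream_create(toy_stream **out, toy_engine *engine,
        unsigned flags, unsigned default_flags) {
    if (out == nullptr || engine == nullptr || flags != default_flags)
        return toy_invalid_arguments;              // mirrors the args_ok check
    toy_stream *s = new (std::nothrow) toy_stream{engine, flags};
    if (s == nullptr) return toy_out_of_memory;    // what safe_ptr_assign reports
    *out = s;
    return toy_success;
}

static toy_status toy_stream_destroy(toy_stream *s) {
    delete s;   // destroying a null stream is a no-op, as in mkldnn_stream_destroy
    return toy_success;
}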
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "sum_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds, const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != dims[d]) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto s_pd = reinterpret_cast(sum_pd); + + for (auto s = engine->get_sum_implementation_list(); *s; ++s) { + if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) { + (*s_pd)->init_info(); + (*s_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp new file mode 100644 index 000000000000..80254667df64 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
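Before dispatching to an engine-specific sum implementation, mkldnn_sum_primitive_desc_create above insists that every source has the same rank, shape and data type, and synthesizes a dst descriptor with format_kind::any when none is supplied. A reduced, self-contained sketch of the consistency check only, using a hypothetical simplified descriptor type:

#include <vector>

struct toy_md { int ndims; long long dims[12]; int data_type; };

// Returns false on the same conditions that yield invalid_arguments above.
static bool toy_sum_srcs_consistent(const std::vector<toy_md> &srcs) {
    if (srcs.empty()) return false;
    const toy_md &ref = srcs[0];
    for (size_t i = 1; i < srcs.size(); ++i) {
        if (srcs[i].ndims != ref.ndims || srcs[i].data_type != ref.data_type)
            return false;
        for (int d = 0; d < ref.ndims; ++d)
            if (srcs[i].dims[d] != ref.dims[d]) return false;
    }
    return true;
}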
+*******************************************************************************/ + +#ifndef SUM_PD_HPP +#define SUM_PD_HPP + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct sum_pd_t: public primitive_desc_t { + sum_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::sum) + , n_(n), dst_md_(*dst_md) + { + scales_.reserve(n_); + for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]); + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + const float *scales() const { return &scales_[0]; } + +protected: + int n_; + nstl::vector scales_; + memory_desc_t dst_md_; + nstl::vector src_mds_; + +protected: + /* inits dst_md_ in simple cases. The call may fail. */ + status_t init() { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(&src_mds_[i]); + if (!src_d.is_blocking_desc() || src_d.is_additional_buffer()) + return status::unimplemented; + } + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain, pick the format of the first input + */ + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (!src_d.is_plain() && src_d.is_blocking_desc()) { + return memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + } + } + + if (src_mds_[0].format_kind != format_kind::blocked) + return status::unimplemented; + + dst_md_ = src_mds_[0]; + + return status::success; + } +}; + +#define DECLARE_SUM_PD_t(impl_name, ...) 
\ + static status_t create(sum_pd_t **sum_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, const float *scales, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign(*sum_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_SUM_PD_T(impl_name, ...) \ + DECLARE_SUM_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp new file mode 100644 index 000000000000..a408f45980b3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TAG_TRAITS_HPP +#define TAG_TRAITS_HPP + +#include + +#include "c_types_map.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +enum class block_dim_t { + _, + _A, _B, + _AB, _BC, +}; + +enum class inner_blk_t { + _, + _4a, _4b, + _8a, _8b, + _16a, _16b, + + _4b4a, _4b4c, _4c4b, + _8a8b, _8b8a, _8b8c, _8c8b, + _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b, + + _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c, +}; + +/** returns the offset within the block for weights blocked over oc and ic */ +template +constexpr int AB_or_BC_blk_off(int x0, int x1) { + using ib = inner_blk_t; + static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b, + ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b, + ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c), + "unexpected inner_blk format"); + return false ? 0 + : (f == ib::_4b4c) ? 4 * x0 + x1 + : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 + : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 + : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 + : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 + : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 + : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 + : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 + : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? 
(x1 / 4) * 64 + x0 * 4 + x1 % 4 + : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 + : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 + : INT_MIN; +} + +template struct inner_blk_traits { + using ib = inner_blk_t; +}; + +template struct tag_traits { + // block_dim_t block_dims; + // inner_blk_t inner_blks; + // int ndims; +}; + +#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ +template <> struct tag_traits { \ + static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ + static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ + static constexpr int ndims = _ndims; \ +} + +DECL_TRAITS(a, _, _, 1); +DECL_TRAITS(ab, _, _, 2); +DECL_TRAITS(abc, _, _, 3); +DECL_TRAITS(abcd, _, _, 4); +DECL_TRAITS(abcde, _, _, 5); +DECL_TRAITS(abcdef, _, _, 6); +DECL_TRAITS(abdec, _, _, 5); +DECL_TRAITS(acb, _, _, 3); +DECL_TRAITS(acbde, _, _, 5); +DECL_TRAITS(acdb, _, _, 4); +DECL_TRAITS(acdeb, _, _, 5); +DECL_TRAITS(ba, _, _, 2); +DECL_TRAITS(bac, _, _, 3); +DECL_TRAITS(bacd, _, _, 4); +DECL_TRAITS(bcda, _, _, 4); +DECL_TRAITS(cba, _, _, 3); +DECL_TRAITS(cdba, _, _, 4); +DECL_TRAITS(cdeba, _, _, 5); +DECL_TRAITS(decab, _, _, 5); + +DECL_TRAITS(Abc4a, _A, _4a, 3); +DECL_TRAITS(aBc4b, _B, _4b, 3); +DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3); +DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3); +DECL_TRAITS(Abcd4a, _A, _4a, 4); +DECL_TRAITS(aBcd4b, _B, _4b, 4); +DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4); +DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4); +DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4); +DECL_TRAITS(Abcde4a, _A, _4a, 5); +DECL_TRAITS(aBcde4b, _B, _4b, 5); +DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5); +DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5); +DECL_TRAITS(aBcdef4b, _B, _4b, 6); +DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6); +DECL_TRAITS(aBdc4b, _B, _4b, 4); +DECL_TRAITS(aBdec4b, _B, _4b, 5); +DECL_TRAITS(aBdefc4b, _B, _4b, 6); +DECL_TRAITS(Acb4a, _A, _4a, 3); +DECL_TRAITS(Acdb4a, _A, _4a, 4); +DECL_TRAITS(Acdeb4a, _A, _4a, 5); + +DECL_TRAITS(Abc16a, _A, _16a, 3); +DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(aBc16b, _B, _16b, 3); +DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3); +DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3); +DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3); +DECL_TRAITS(aBc8b, _B, _8b, 3); +DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3); +DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3); +DECL_TRAITS(Abcd16a, _A, _16a, 4); +DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4); +DECL_TRAITS(aBcd16b, _B, _16b, 4); +DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4); +DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4); +DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4); +DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); +DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); +DECL_TRAITS(aBcd8b, _B, _8b, 4); +DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); +DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4); +DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4); +DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4); +DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4); +DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4); +DECL_TRAITS(Abcde16a, _A, _16a, 5); +DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5); +DECL_TRAITS(aBcde16b, _B, _16b, 5); +DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5); +DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5); +DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5); +DECL_TRAITS(Abcde8a, _A, _8a, 5); +DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5); +DECL_TRAITS(aBcde8b, _B, _8b, 5); +DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5); +DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5); +DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5); +DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 
5); +DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5); +DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5); +DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5); +DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5); +DECL_TRAITS(aBcdef16b, _B, _16b, 6); +DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6); +DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6); +DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6); +DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6); +DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6); +DECL_TRAITS(aBdc16b, _B, _16b, 4); +DECL_TRAITS(aBdc8b, _B, _8b, 4); +DECL_TRAITS(aBdec16b, _B, _16b, 5); +DECL_TRAITS(aBdec8b, _B, _8b, 5); +DECL_TRAITS(aBdefc16b, _B, _16b, 6); +DECL_TRAITS(aBdefc8b, _B, _8b, 6); +DECL_TRAITS(Acb16a, _A, _16a, 3); +DECL_TRAITS(Acb8a, _A, _8a, 3); +DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(Acdb16a, _A, _16a, 4); +DECL_TRAITS(Acdb8a, _A, _8a, 4); +DECL_TRAITS(Acdeb16a, _A, _16a, 5); +DECL_TRAITS(Acdeb8a, _A, _8a, 5); +DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp new file mode 100644 index 000000000000..4f06368738cb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
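The DECL_TRAITS table above only records which dimensions are blocked, the inner block layout and the rank for each format tag; the actual element placement inside a block comes from the closed-form expressions in AB_or_BC_blk_off. Two hypothetical standalone mirrors of those formulas, checked at compile time:

// a = outer index within the block, b = inner index within the block.
constexpr int blk_off_16a16b(int a, int b) { return 16 * a + b; }
constexpr int blk_off_8a16b2a(int a, int b) { return (a / 2) * 32 + b * 2 + a % 2; }

static_assert(blk_off_16a16b(1, 0) == 16, "plain row-major 16x16 tile");
static_assert(blk_off_8a16b2a(1, 0) == 1, "pairs of 'a' values stay adjacent");
static_assert(blk_off_8a16b2a(2, 3) == 38, "the next pair group starts 32 elements later");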
+*******************************************************************************/ + +#ifndef TYPE_HELPERS_HPP +#define TYPE_HELPERS_HPP + +#include +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "mkldnn_traits.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "math_utils.hpp" + +namespace mkldnn { +namespace impl { + +template +status_t safe_ptr_assign(T * &lhs, T* rhs) { + if (rhs == nullptr) return status::out_of_memory; + lhs = rhs; + return status::success; +} + +template struct is_subset +{ static constexpr bool value = false; }; +template struct is_subset +{ static constexpr bool value = true; }; +template struct is_subset::value, float>::type> +{ static constexpr bool value = true; }; +#define ISSPEC(t1, t2) template <> \ + struct is_subset { static constexpr bool value = true; } +ISSPEC(int16_t, int32_t); +ISSPEC(int8_t, int32_t); +ISSPEC(uint8_t, int32_t); +ISSPEC(int8_t, int16_t); +ISSPEC(uint8_t, int16_t); +#undef ISSPEC + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); + +namespace types { + +inline size_t data_type_size(data_type_t data_type) { + using namespace data_type; + switch (data_type) { + case f32: return sizeof(prec_traits::type); + case s32: return sizeof(prec_traits::type); + case s8: return sizeof(prec_traits::type); + case u8: return sizeof(prec_traits::type); + case data_type::undef: + default: assert(!"unknown data_type"); + } + return 0; /* not supposed to be reachable */ +} + +inline format_kind_t format_tag_to_kind(format_tag_t tag) { + switch (tag) { + case format_tag::undef: return format_kind::undef; + case format_tag::any: return format_kind::any; + case format_tag::last: return format_kind::undef; + default: return format_kind::blocked; + } + + assert(!"unreachable"); + return format_kind::undef; +} + +inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs, + const memory_extra_desc_t &rhs) { + return true + && lhs.flags == rhs.flags + && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8, + lhs.compensation_mask == rhs.compensation_mask) + && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust, + lhs.scale_adjust == rhs.scale_adjust); +} + +inline bool blocking_desc_is_equal(const blocking_desc_t &lhs, + const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) { + using mkldnn::impl::utils::array_cmp; + return true + && lhs.inner_nblks == rhs.inner_nblks + && array_cmp(lhs.strides, rhs.strides, ndims) + && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) + && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); +} + +inline bool wino_desc_is_equal(const wino_desc_t &lhs, + const wino_desc_t &rhs) { + return lhs.wino_format == rhs.wino_format + && lhs.alpha == rhs.alpha + && lhs.ic == rhs.ic + && lhs.oc == rhs.oc + && lhs.ic_block == rhs.ic_block + && lhs.oc_block == rhs.oc_block + && lhs.ic2_block == rhs.ic2_block + && lhs.oc2_block == rhs.oc2_block + && lhs.r == rhs.r; +} + +inline bool rnn_packed_desc_is_equal( + const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { + bool ok = true + && lhs.format == rhs.format + && lhs.n_parts == rhs.n_parts + && lhs.offset_compensation == rhs.offset_compensation + && lhs.size == rhs.size + && lhs.n == rhs.n; + if (!ok) + return false; + + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.parts[i] == rhs.parts[i]; + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; + return ok; +} + +inline memory_desc_t zero_md() { + auto zero = 
memory_desc_t(); + return zero; +} + +inline bool is_zero_md(const memory_desc_t *md) { + return md == nullptr || *md == zero_md(); +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t dst_dt) { + using namespace utils; + using namespace data_type; + + if (one_of(f32, src_dt, dst_dt)) return f32; + if (one_of(s32, src_dt, dst_dt)) return s32; + + if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32; + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) { + using namespace utils; + using namespace data_type; + using namespace prop_kind; + + /* prop_kind doesn't matter */ + if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32; + + if (one_of(prop_kind, forward_training, forward_inference)) { + if ((src_dt == u8 || src_dt == s8) + && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8)) + return s32; + } else if (prop_kind == backward_data) { + if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && + one_of(dst_dt, s8, u8)) + return s32; + } + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +} + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) { + using namespace mkldnn::impl::utils; + bool base_equal = true + && lhs.ndims == rhs.ndims + && array_cmp(lhs.dims, rhs.dims, lhs.ndims) + && lhs.data_type == rhs.data_type + && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims) + && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims) + && lhs.offset0 == rhs.offset0 + && lhs.format_kind == rhs.format_kind; + if (!base_equal) return false; + if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false; + if (lhs.format_kind == format_kind::blocked) + return types::blocking_desc_is_equal(lhs.format_desc.blocking, + rhs.format_desc.blocking, lhs.ndims); + else if (lhs.format_kind == format_kind::wino) + return types::wino_desc_is_equal(lhs.format_desc.wino_desc, + rhs.format_desc.wino_desc); + else if (lhs.format_kind == format_kind::rnn_packed) + return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc, + rhs.format_desc.rnn_packed_desc); + return true; +} + +inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { + return !operator==(lhs, rhs); +} + +inline status_t memory_desc_init_by_strides(memory_desc_t &md, + const dims_t strides) { + return mkldnn_memory_desc_init_by_strides( + &md, md.ndims, md.dims, md.data_type, strides); +} + +inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + status_t status = mkldnn_memory_desc_init_by_tag( + &md, md.ndims, md.dims, md.data_type, tag); + if (status != status::success || strides == nullptr) + return status; + + /* TODO: add consistency check */ + + for (int d = 0; d < md.ndims; ++d) + md.format_desc.blocking.strides[d] = strides[d]; + + return status::success; +} + +/** inits memory descriptor based on logical dimensions kept in @p md, and the + * blocking structure @p blk. 
+ * + * @note blk.strides represent the order only (from smaller to bigger) + * + * TODO: move md related functions to one single place + */ +inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md, + const blocking_desc_t &blk) { + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + dim_t block_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk]; + block_size *= blk.inner_blks[iblk]; + } + + for (int d = 0; d < md.ndims; ++d) { + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + md.padded_offsets[d] = 0; + } + md.offset0 = 0; + + md.format_kind = format_kind::blocked; + auto &mblk = md.format_desc.blocking; + mblk = blk; + + const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy + utils::array_copy(mblk.strides, blk.strides, ndims); + + int perm[MKLDNN_MAX_NDIMS]; + for (int d = 0; d < ndims; ++d) perm[d] = d; + + utils::simultaneous_sort(mblk.strides, perm, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + dim_t stride = block_size; + for (int _d = ndims - 1; _d >= 0; --_d) { + const int d = perm[_d]; + md.format_desc.blocking.strides[d] = stride; + stride *= md.padded_dims[d] / blocks[d]; + } + + md.extra = utils::zero(); + + return status::success; +} + +/** returns true if memory desc @p md corresponds to the given format tag and + * strides. + * If strides are not passed (or passed as nullptr) the dense structure is + * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns). + * Strides might contain `0` value, indicating the stride must match the one + * that mkldnn_memory_desc_init_by_tag() returns. + * Strides might contain `-1` values, that would be ignored during the + * comparison. For instance, this can be used if a stride along minibatch + * doesn't matter. */ +inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + if (md.format_kind != types::format_tag_to_kind(tag)) + return false; + + memory_desc_t md_gold; + status_t status = mkldnn_memory_desc_init_by_tag( + &md_gold, md.ndims, md.dims, md.data_type, tag); + if (status != status::success) return false; + + if (md.format_kind != format_kind::blocked) + return false; // unimplemented yet + + const auto &blk = md.format_desc.blocking; + const auto &blk_gold = md_gold.format_desc.blocking; + + using utils::array_cmp; + bool same_blocks = true + && blk.inner_nblks == blk_gold.inner_nblks + && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); + + if (!same_blocks) + return false; + + if (strides == nullptr) + return array_cmp(blk.strides, blk_gold.strides, md.ndims); + + for (int d = 0; d < md.ndims; ++d) { + dim_t stride = strides[d]; + if (stride == -1) continue; + if (stride == 0) stride = blk_gold.strides[d]; + if (blk.strides[d] != stride) return false; + } + + return true; +} + +/** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! 
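memory_desc_matches_tag above compares the blocking structure against a freshly initialized reference descriptor and then applies a small per-dimension stride rule: -1 ignores the dimension and 0 falls back to the dense stride derived from the tag. That rule, pulled out into a standalone sketch (a hypothetical helper, not part of the library):

#include <cstdint>

static bool toy_strides_match(const int64_t *actual, const int64_t *requested,
        const int64_t *dense_default, int ndims) {
    for (int d = 0; d < ndims; ++d) {
        int64_t want = requested[d];
        if (want == -1) continue;                // wildcard: stride does not matter
        if (want == 0) want = dense_default[d];  // must equal the tag-derived stride
        if (actual[d] != want) return false;
    }
    return true;
}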
*/ +template +format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md, + Tags ...tags) { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(md, tag)) + return tag; + } + return format_tag::undef; +} + +} +} + +#include "memory_desc_wrapper.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp new file mode 100644 index 000000000000..d23f4682dc32 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#ifdef _WIN32 +#include +#include +#endif +#include +#include +#include + +#include "mkldnn.h" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +int getenv(const char *name, char *buffer, int buffer_size) { + if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0)) + return INT_MIN; + + int result = 0; + int term_zero_idx = 0; + size_t value_length = 0; + +#ifdef _WIN32 + value_length = GetEnvironmentVariable(name, buffer, buffer_size); +#else + const char *value = ::getenv(name); + value_length = value == NULL ? 0 : strlen(value); +#endif + + if (value_length > INT_MAX) + result = INT_MIN; + else { + int int_value_length = (int)value_length; + if (int_value_length >= buffer_size) { + result = -int_value_length; + } else { + term_zero_idx = int_value_length; + result = int_value_length; +#ifndef _WIN32 + strncpy(buffer, value, value_length); +#endif + } + } + + if (buffer != NULL) + buffer[term_zero_idx] = '\0'; + return result; +} + +int getenv_int(const char *name, int default_value) +{ + int value = default_value; + // # of digits in the longest 32-bit signed int + sign + terminating null + const int len = 12; + char value_str[len]; + if (getenv(name, value_str, len) > 0) + value = atoi(value_str); + return value; +} + +FILE *fopen(const char *filename, const char *mode) { +#ifdef _WIN32 + FILE *fp = NULL; + return ::fopen_s(&fp, filename, mode) ? NULL : fp; +#else + return ::fopen(filename, mode); +#endif +} + +void *malloc(size_t size, int alignment) { + void *ptr; + +#ifdef _WIN32 + ptr = _aligned_malloc(size, alignment); + int rc = ptr ? 0 : -1; +#else + int rc = ::posix_memalign(&ptr, alignment, size); +#endif + + return (rc == 0) ? 
ptr : 0; +} + +void free(void *p) { +#ifdef _WIN32 + _aligned_free(p); +#else + ::free(p); +#endif +} + +// Atomic operations +int32_t fetch_and_add(int32_t *dst, int32_t val) { +#ifdef _WIN32 + return InterlockedExchangeAdd(reinterpret_cast(dst), val); +#else + return __sync_fetch_and_add(dst, val); +#endif +} + +static int jit_dump_flag = 0; +static bool jit_dump_flag_initialized = false; +bool jit_dump_enabled() { + if (!jit_dump_flag_initialized) { + jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP"); + jit_dump_flag_initialized = true; + } + return jit_dump_flag != 0; +} + +} +} + +mkldnn_status_t mkldnn_set_jit_dump(int enabled) { + using namespace mkldnn::impl::status; + mkldnn::impl::jit_dump_flag = enabled; + mkldnn::impl::jit_dump_flag_initialized = true; + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp new file mode 100644 index 000000000000..d5a8ec5139fc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp @@ -0,0 +1,370 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef UTILS_HPP +#define UTILS_HPP + +#include +#include +#include +#include +#include + +#if defined(__x86_64__) || defined(_M_X64) +#define MKLDNN_X86_64 +#endif + +#define MSAN_ENABLED 0 +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) +#undef MSAN_ENABLED +#define MSAN_ENABLED 1 +#include +#endif +#endif + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +// Sanity check for 64 bits +static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only"); + +#define CHECK(f) do { \ + status_t status = f; \ + if (status != status::success) \ + return status; \ +} while (0) + +#define IMPLICATION(cause, effect) (!(cause) || !!(effect)) + +namespace utils { + +/* a bunch of std:: analogues to be compliant with any msvs version + * + * Rationale: msvs c++ (and even some c) headers contain special pragma that + * injects msvs-version check into object files in order to abi-mismatches + * during the static linking. This makes sense if e.g. std:: objects are passed + * through between application and library, which is not the case for mkl-dnn + * (since there is no any c++-rt dependent stuff, ideally...). 
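The malloc/free wrappers defined in utils.cpp above pair _aligned_malloc with _aligned_free on Windows and posix_memalign with the plain free elsewhere, so an allocation must always be released through the matching wrapper. A hedged usage sketch; the forward declarations simply repeat the signatures defined above and are assumptions about where they are declared:

#include <cstddef>
#include <cstdint>

namespace mkldnn { namespace impl {
void *malloc(std::size_t size, int alignment);  // assumed to match the definitions above
void free(void *p);
}}

static bool demo_aligned_alloc() {
    void *buf = mkldnn::impl::malloc(1024, 64);  // 1 KiB, 64-byte aligned
    bool ok = buf && reinterpret_cast<std::uintptr_t>(buf) % 64 == 0;
    mkldnn::impl::free(buf);                     // never plain ::free()
    return ok;
}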
*/ + +/* SFINAE helper -- analogue to std::enable_if */ +template struct enable_if {}; +template struct enable_if { typedef T type; }; + +/* analogue std::conditional */ +template struct conditional {}; +template struct conditional +{ typedef T type; }; +template struct conditional +{ typedef F type; }; + +template struct conditional3 {}; +template +struct conditional3 { typedef T type; }; +template +struct conditional3 { typedef FT type; }; +template +struct conditional3 { typedef FF type; }; + +template struct conditional_v {}; +template struct conditional_v +{ static constexpr U value = t; }; +template struct conditional_v +{ static constexpr U value = f; }; + +template struct remove_reference { typedef T type; }; +template struct remove_reference { typedef T type; }; +template struct remove_reference { typedef T type; }; + +template +inline T&& forward(typename utils::remove_reference::type &t) +{ return static_cast(t); } +template +inline T&& forward(typename utils::remove_reference::type &&t) +{ return static_cast(t); } + +template +inline typename remove_reference::type zero() +{ auto zero = typename remove_reference::type(); return zero; } + +template +inline bool everyone_is(T val, P item) { return val == item; } +template +inline bool everyone_is(T val, P item, Args... item_others) { + return val == item && everyone_is(val, item_others...); +} + +template +constexpr bool one_of(T val, P item) { return val == item; } +template +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template +inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } + +template +inline void array_copy(T *dst, const T *src, size_t size) { + for (size_t i = 0; i < size; ++i) dst[i] = src[i]; +} +template +inline bool array_cmp(const T *a1, const T *a2, size_t size) { + for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false; + return true; +} +template +inline void array_set(T *arr, const U& val, size_t size) { + for (size_t i = 0; i < size; ++i) arr[i] = static_cast(val); +} + +namespace product_impl { +template struct int2type{}; + +template +constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; } + +template +inline T product_impl(const T *arr, int2type) { + return arr[0]*product_impl(arr+1, int2type()); } +} + +template +inline T array_product(const T *arr) { + return product_impl::product_impl(arr, product_impl::int2type()); +} + +template +inline R array_product(const T *arr, size_t size) { + R prod = 1; + for (size_t i = 0; i < size; ++i) prod *= arr[i]; + return prod; +} + +/** sorts an array of values using @p comparator. While sorting the array + * of value, the function permutes an array of @p keys accordingly. + * + * @note The arrays of @p keys can be omitted. In this case the function + * sorts the array of @vals only. 
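The variadic helpers above (everyone_is, one_of, any_null and the array_* routines) are small fold-style replacements for std:: algorithms, kept local so the library stays free of C++ runtime dependencies, as the comment at the top of the namespace explains. A hedged usage sketch, assuming this utils.hpp is on the include path:

#include "utils.hpp"   // assumed include path for the header above

using namespace mkldnn::impl::utils;

static_assert(one_of(3, 1, 2, 3), "one_of() is a constexpr membership test");
static_assert(!one_of(4, 1, 2, 3), "and fails when the value is absent");

inline bool demo_array_helpers() {
    int a[3] = {0, 0, 0}, b[3] = {0, 0, 0};
    array_set(a, 7, 3);   // a = {7, 7, 7}
    array_copy(b, a, 3);  // b = a
    return array_cmp(a, b, 3) && !any_null(a, b);
}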
+ */ +template +inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) { + if (size == 0) return; + + for (size_t i = 0; i < size - 1; ++i) { + bool swapped = false; + + for (size_t j = 0; j < size - i - 1; j++) { + if (comparator(vals[j], vals[j + 1]) > 0) { + nstl::swap(vals[j], vals[j + 1]); + if (keys) nstl::swap(keys[j], keys[j + 1]); + swapped = true; + } + } + + if (swapped == false) break; + } +} + +template +inline typename remove_reference::type div_up(const T a, const U b) { + assert(b); + return (a + b - 1) / b; +} + +template +inline typename remove_reference::type rnd_up(const T a, const U b) { + return div_up(a, b) * b; +} + +template +inline typename remove_reference::type rnd_dn(const T a, const U b) { + return (a / b) * b; +} + +template T *align_ptr(T *ptr, uintptr_t alignment) +{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); } + +template +inline U this_block_size(const T offset, const U max, const V block_size) { + assert(offset < max); + // TODO (Roma): can't use nstl::max() due to circular dependency... we + // need to fix this + const T block_boundary = offset + block_size; + if (block_boundary > max) + return max - offset; + else + return block_size; +} + +template +inline T nd_iterator_init(T start) { return start; } +template +inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) { + start = nd_iterator_init(start, utils::forward(tuple)...); + x = start % X; + return start / X; +} + +inline bool nd_iterator_step() { return true; } +template +inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) { + if (nd_iterator_step(utils::forward(tuple)...) ) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) +{ + U max_jump = end - cur; + U dim_jump = X - x; + if (dim_jump <= max_jump) { + x = 0; + cur += dim_jump; + return true; + } else { + cur += max_jump; + x += max_jump; + return false; + } +} +template +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X, + Args &&... tuple) +{ + if (nd_iterator_jump(cur, end, utils::forward(tuple)...)) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template +inline T pick(size_t i, const T &x0) { return x0; } +template +inline T pick(size_t i, const T &x0, Args &&... args) { + return i == 0 ? x0 : pick(i - 1, utils::forward(args)...); +} + +template +T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference, + const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) { + switch (prop_kind) { + case prop_kind::forward_inference: return val_fwd_inference; + case prop_kind::forward_training: return val_fwd_training; + case prop_kind::backward_data: return val_bwd_d; + case prop_kind::backward_weights: return val_bwd_w; + default: assert(!"unsupported prop_kind"); + } + return T(); +} + +template +T pick_by_prop_kind(prop_kind_t prop_kind, + const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w) +{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); } + +template +struct array_offset_calculator { + template + array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... } + { + _base_ptr = base; + } + template + inline Telem &operator()(Targs... 
Fargs) + { + return *(_base_ptr + _offset(1, Fargs...)); + } + +private: + template + inline size_t _offset(size_t const dimension, size_t element) + { + return element; + } + + template + inline size_t _offset(size_t const dimension, size_t theta, size_t element) + { + return element + (_dims[dimension] * theta); + } + + template + inline size_t _offset(size_t const dimension, size_t theta, size_t element, + Targs... Fargs) + { + size_t t_prime = element + (_dims[dimension] * theta); + return _offset(dimension + 1, t_prime, Fargs...); + } + + Telem *_base_ptr; + const int _dims[Tdims]; +}; + +} + +int32_t fetch_and_add(int32_t *dst, int32_t val); +inline void yield_thread() {} + +// Reads an environment variable 'name' and stores its string value in the +// 'buffer' of 'buffer_size' bytes on success. +// +// - Returns the length of the environment variable string value (excluding +// the terminating 0) if it is set and its contents (including the terminating +// 0) can be stored in the 'buffer' without truncation. +// +// - Returns negated length of environment variable string value and writes +// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to +// store the value (including the terminating 0) without truncation. +// +// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment +// variable is not set. +// +// - Returns INT_MIN if the 'name' is NULL. +// +// - Returns INT_MIN if the 'buffer_size' is negative. +// +// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than +// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to +// retrieve the length of the environment variable value string. +// +int getenv(const char *name, char *buffer, int buffer_size); +// Reads an integer from the environment +int getenv_int(const char *name, int default_value = 0); +bool jit_dump_enabled(); +FILE *fopen(const char *filename, const char *mode); + +constexpr int msan_enabled = MSAN_ENABLED; +inline void msan_unpoison(void *ptr, size_t size) { +#if MSAN_ENABLED + __msan_unpoison(ptr, size); +#endif +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp new file mode 100644 index 000000000000..89a57772cf4c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp @@ -0,0 +1,665 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
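The getenv contract documented above encodes several outcomes in its return value: the value length on success, a negated length when the buffer is too small (the buffer then holds an empty string), 0 when the variable is unset, and INT_MIN for invalid arguments. A hedged usage sketch, assuming the header above is included; reading a single-digit level therefore only needs a two-byte buffer:

#include <cstdlib>
#include "utils.hpp"   // assumed include path for the header above

static int read_int_env_level(const char *name) {
    char val[2] = {0};
    int rc = mkldnn::impl::getenv(name, val, sizeof(val));
    return rc > 0 ? std::atoi(val) : 0;   // unset, truncated or invalid -> 0
}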
+*******************************************************************************/ + +#include +#ifndef _WIN32 +#include +#endif + +#include "mkldnn.h" +#include "mkldnn_version.h" +#include "c_types_map.hpp" +#include "verbose.hpp" +#include "cpu/cpu_isa_traits.hpp" + +#include "batch_normalization_pd.hpp" +#include "pooling_pd.hpp" +#include "concat_pd.hpp" +#include "reorder_pd.hpp" +#include "convolution_pd.hpp" +#include "rnn_pd.hpp" +#include "deconvolution_pd.hpp" +#include "shuffle_pd.hpp" +#include "eltwise_pd.hpp" +#include "softmax_pd.hpp" +#include "inner_product_pd.hpp" +#include "sum_pd.hpp" +#include "lrn_pd.hpp" + +/* MKL-DNN CPU ISA info */ +#define ISA_ANY "No instruction set specific optimizations" +#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)" +#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)" +#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)" +#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512)" +#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions" +#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \ + "AVX512-DL Boost)" +#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions" +#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions" + +namespace mkldnn { +namespace impl { + +static verbose_t verbose; +static bool initialized; +static bool version_printed = false; + +const verbose_t *mkldnn_verbose() { +#if !defined(DISABLE_VERBOSE) + if (!initialized) { + const int len = 2; + char val[len] = {0}; + if (getenv("MKLDNN_VERBOSE", val, len) == 1) + verbose.level = atoi(val); + initialized = true; + } + if (!version_printed && verbose.level > 0) { + printf("mkldnn_verbose,info," + "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n", + mkldnn_version()->major, mkldnn_version()->minor, + mkldnn_version()->patch, mkldnn_version()->hash, + get_isa_info()); + version_printed = true; + } +#else + verbose.level = 0; +#endif + return &verbose; +} + +double get_msec() { +#ifdef _WIN32 + static LARGE_INTEGER frequency; + if (frequency.QuadPart == 0) + QueryPerformanceFrequency(&frequency); + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + return 1e+3 * now.QuadPart / frequency.QuadPart; +#else + struct timeval time; + gettimeofday(&time, NULL); + return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; +#endif +} + +const char *get_isa_info() { + using namespace mkldnn::impl::cpu; + if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS; + if (mayiuse(avx512_mic)) return AVX512_MIC; + if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI; + if (mayiuse(avx512_core)) return AVX512_CORE; + if (mayiuse(avx512_common)) return AVX512_COMMON; + if (mayiuse(avx2)) return AVX2; + if (mayiuse(avx)) return AVX; + if (mayiuse(sse42)) return SSE42; + return ISA_ANY; +} + +/* init_info section */ +namespace { +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_DAT_LEN 256 +#define MKLDNN_VERBOSE_AUX_LEN 384 +#define MKLDNN_VERBOSE_PRB_LEN 384 + +#define DECL_DAT_AUX_PRB_STRS() \ + int dat_written = 0, aux_written = 0, prb_written = 0; \ + MAYBE_UNUSED((dat_written * aux_written * prb_written)); \ + char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \ + char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; 
MAYBE_UNUSED(aux_str); \ + char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str) + +#define DFMT "%" PRId64 + +void clear_buf(char *buf, int &written) { + /* TODO: do it better */ + buf[0] = '#'; + buf[1] = '\0'; + written = 1; +} + +#define DPRINT(buf, buf_len, written, ...) do { \ + int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \ + if (l < 0 || written + l > buf_len) { \ + clear_buf(buf, written); \ + } else { \ + written += l; \ + } \ +} while(0) + +// XXX: Outputs strings corresponding to memory formats used for data tensors. +void format_prb_desc_str(char *str, int len, const memory_desc_t *md) { + const auto dims = md->dims; + int written = 0; + if (md->ndims == 1) + DPRINT(str, len, written, + "x" DFMT, dims[0]); + else if (md->ndims == 2) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT, dims[0], dims[1]); + else if (md->ndims == 3) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "iw" DFMT, + dims[0], dims[1], dims[2]); + else if (md->ndims == 4) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3]); + else if (md->ndims == 5) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3], dims[4]); + else + mkldnn_md2dim_str(str, len, md); +} + +void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind, + const char *impl_str, mkldnn_prop_kind_t prop_kind, + const char *data_str, const char *aux_str, const char *prb_str) { + MAYBE_UNUSED(verbose_templ); + int written = 0; + DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s", + mkldnn_prim_kind2str(prim_kind), impl_str, + mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str); +} + +template static void init_info_bnorm(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "flags:%u", s->desc()->flags); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_conv(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? 
s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->ndims() == 5) { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } else { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_shuffle(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + auto md = s->is_fwd() ? 
s->src_md() : s->diff_dst_md(); + + if (1) { // data + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "axis:%d group_size:" DFMT, s->axis(), s->group_size()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_eltwise(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_iprod(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? 
s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_lrn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_mem(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "num:%d", s->n_inputs()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str, + aux_str, prb_str); +} + +template static void init_info_pool(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->is_fwd() ? s->src_md() : s->diff_src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->is_fwd() ? 
s->dst_md() : s->diff_dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // ws + auto md = s->workspace_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->is_3d()) { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "", + s->MB(), s->C(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } else { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT, + s->MB(), s->C(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_softmax(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_rnn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src layer + auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // src iter + auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_layer + auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_iter + auto md = s->is_fwd() ? 
s->weights_md(1) : s->diff_weights_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bias + auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst layer + auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst iter + auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + alg_kind_t alg_kind = s->cell_kind(); + rnn_direction_t rnn_dir = s->direction(); + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s_%s", mkldnn_alg_kind2str(alg_kind), + mkldnn_rnn_direction2str(rnn_dir)); + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "l" DFMT "t" DFMT "mb" DFMT + "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT, + s->L(), s->T(), s->MB(), + s->SIC(), s->SLC(), s->DIC(), s->DLC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +#undef DPRINT + +#else // !defined(DISABLE_VERBOSE) + +#define DEFINE_STUB(name) \ + template \ + static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \ + { UNUSED(s); UNUSED(buffer); } + +DEFINE_STUB(bnorm); +DEFINE_STUB(conv); +DEFINE_STUB(eltwise); +DEFINE_STUB(iprod); +DEFINE_STUB(lrn); +DEFINE_STUB(mem); +DEFINE_STUB(pool); +DEFINE_STUB(softmax); +DEFINE_STUB(rnn); +DEFINE_STUB(shuffle); +#undef DEFINE_STUB + +#endif // !defined(DISABLE_VERBOSE) +} + +void init_info(batch_normalization_pd_t *s, char *b) +{ init_info_bnorm(s, b); } +void init_info(concat_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(convolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(deconvolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(eltwise_pd_t *s, char *b) +{ init_info_eltwise(s, b); } +void init_info(inner_product_pd_t *s, char *b) +{ init_info_iprod(s, b); } +void init_info(lrn_pd_t *s, char *b) +{ init_info_lrn(s, b); } +void init_info(pooling_pd_t *s, char *b) +{ init_info_pool(s, b); } +void init_info(reorder_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(rnn_pd_t *s, char *b) +{ init_info_rnn(s, b); } +void init_info(shuffle_pd_t *s, char *b) +{ init_info_shuffle(s, b); } +void init_info(softmax_pd_t *s, char *b) +{ init_info_softmax(s, b); } +void init_info(sum_pd_t *s, char *b) +{ init_info_mem(s, b); } + +} +} + +mkldnn_status_t mkldnn_set_verbose(int level) { + using namespace mkldnn::impl::status; + if (level < 0 || level > 2) return invalid_arguments; + mkldnn::impl::verbose.level = level; + mkldnn::impl::initialized = true; + return success; +} + +const mkldnn_version_t *mkldnn_version() { + static mkldnn_version_t ver = { + MKLDNN_VERSION_MAJOR, + MKLDNN_VERSION_MINOR, + 
MKLDNN_VERSION_PATCH, + MKLDNN_VERSION_HASH}; + return &ver; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp new file mode 100644 index 000000000000..e3049750cb16 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef VERBOSE_HPP +#define VERBOSE_HPP + +#include +#include + +#include "mkldnn_debug.h" +#include "c_types_map.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +struct verbose_t { + int level; +}; + +const verbose_t *mkldnn_verbose(); +double get_msec(); +const char *get_isa_info(); + +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_BUF_LEN 1024 +#else +#define MKLDNN_VERBOSE_BUF_LEN 1 +#endif + +void init_info(batch_normalization_pd_t *s, char *buffer); +void init_info(concat_pd_t *s, char *buffer); +void init_info(convolution_pd_t *s, char *buffer); +void init_info(deconvolution_pd_t *s, char *buffer); +void init_info(eltwise_pd_t *s, char *buffer); +void init_info(inner_product_pd_t *s, char *buffer); +void init_info(lrn_pd_t *s, char *buffer); +void init_info(pooling_pd_t *s, char *buffer); +void init_info(reorder_pd_t *s, char *buffer); +void init_info(rnn_pd_t *s, char *buffer); +void init_info(shuffle_pd_t *s, char *buffer); +void init_info(softmax_pd_t *s, char *buffer); +void init_info(sum_pd_t *s, char *buffer); + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp new file mode 100644 index 000000000000..520bd4710b3b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef Z_MAGIC_HPP +#define Z_MAGIC_HPP + +#define CHAIn2(a,b) a b +#define CHAIN2(a,b) CHAIn2(a,b) + +#define CONCAt2(a,b) a ## b +#define CONCAT2(a,b) CONCAt2(a,b) + +#define STRINGIFy(s) #s +#define STRINGIFY(s) STRINGIFy(s) + +#ifdef _MSC_VER +# define PRAGMA_MACRo(x) __pragma(x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#else +# define PRAGMA_MACRo(x) _Pragma(#x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#endif + +#define UNUSED(x) ((void)x) +#define MAYBE_UNUSED(x) UNUSED(x) + +#if defined(_WIN32) && !defined(__GNUC__) +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp new file mode 100644 index 000000000000..7cf7822d9029 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr) { +# define BAR_CTR_OFF offsetof(ctx_t, ctr) +# define BAR_SENSE_OFF offsetof(ctx_t, sense) + using namespace Xbyak; + + Xbyak::Reg64 reg_tmp = [&]() { + /* returns register which is neither reg_ctx nor reg_nthr */ + Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx }; + for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i) + if (!utils::one_of(regs[i], reg_ctx, reg_nthr)) + return regs[i]; + return regs[0]; /* should not happen */ + }(); + + Label barrier_exit_label, barrier_exit_restore_label, spin_label; + + code.cmp(reg_nthr, 1); + code.jbe(barrier_exit_label); + + code.push(reg_tmp); + + /* take and save current sense */ + code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.push(reg_tmp); + code.mov(reg_tmp, 1); + + if (mayiuse(avx512_mic)) { + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + } + + code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp); + code.add(reg_tmp, 1); + code.cmp(reg_tmp, reg_nthr); + code.pop(reg_tmp); /* restore previous sense */ + code.jne(spin_label); + + /* the last thread {{{ */ + code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset ctx + + // notify waiting threads + code.not_(reg_tmp); + code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp); + code.jmp(barrier_exit_restore_label); + /* }}} the last thread */ + + code.CodeGenerator::L(spin_label); + code.pause(); + code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.je(spin_label); + + code.CodeGenerator::L(barrier_exit_restore_label); + code.pop(reg_tmp); + + 
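+    // Common exit: reached via the jbe() above when nthr <= 1, or by falling
+    // through after the saved sense value has been restored.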
code.CodeGenerator::L(barrier_exit_label); +# undef BAR_CTR_OFF +# undef BAR_SENSE_OFF +} + +/** jit barrier generator */ +struct jit_t: public jit_generator { + void (*barrier)(ctx_t *ctx, size_t nthr); + + jit_t() { + generate(*this, abi_param1, abi_param2); + ret(); + barrier = reinterpret_cast(const_cast( + this->getCode())); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t) +}; + +void barrier(ctx_t *ctx, int nthr) { + static jit_t j; /* XXX: constructed on load ... */ + j.barrier(ctx, nthr); +} + +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp new file mode 100644 index 000000000000..0f55e33aa84e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BARRIER_HPP +#define CPU_BARRIER_HPP + +#include + +#include "jit_generator.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +STRUCT_ALIGN(64, +struct ctx_t { + enum { CACHE_LINE_SIZE = 64 }; + volatile size_t ctr; + char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; + volatile size_t sense; + char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; +}); + +inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero(); } +void barrier(ctx_t *ctx, int nthr); + +/** injects actual barrier implementation into another jitted code + * @params: + * code -- jit_generator object where the barrier is to be injected + * reg_ctx -- read-only register with pointer to the barrier context + * reg_nnthr -- read-only register with the # of synchronizing threads + */ +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr); + +} + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp new file mode 100644 index 000000000000..1ed5ad57b9f0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_PD_HPP +#define CPU_BATCH_NORMALIZATION_PD_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t { + using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t; +}; + +struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t { + using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp new file mode 100644 index 000000000000..b8d5c4fcaf6d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "cpu_batch_normalization_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters) { + int nthrs = mkldnn_get_max_threads(); + int l3_size = get_cache_size(3, true) * nthrs / 2; + + C_blks_per_iter = l3_size / working_set_size; + + if (C_blks_per_iter == 0) + C_blks_per_iter = 1; + if (C_blks_per_iter > C_blks) + C_blks_per_iter = C_blks; + + iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter; +} + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) { + if (nthr <= C_blks || !mkldnn_thr_syncable()) { + C_ithr = ithr; C_nthr = nthr; + N_ithr = 0; N_nthr = 1; + S_ithr = 0; S_nthr = 1; + N_s = 0; N_e = N; S_s = 0; S_e = SP; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + } else { + if (do_blocking) { + N_nthr = (int)nstl::min(N, nthr); + C_nthr = (int)nstl::min(C_blks, nthr / N_nthr); + S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + C_nthr = (int)math::gcd((dim_t)nthr, C_blks); + N_nthr = (int)nstl::min(N, nthr / C_nthr); + S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + if (!spatial_thr_allowed) + S_nthr = 1; + + if (S_nthr < 1) S_nthr = 1; + if (ithr < C_nthr * N_nthr * S_nthr) { + N_ithr = (ithr / S_nthr) % N_nthr ; + C_ithr = ithr / (N_nthr * S_nthr); + S_ithr = ithr % S_nthr; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + balance211(N, N_nthr, N_ithr, N_s, N_e); + balance211(SP, S_nthr, S_ithr, S_s, S_e); 
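+            // Each participating thread now owns disjoint sub-ranges over the
+            // C blocks, the minibatch N and the spatial extent SP.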
+ } else { + S_ithr = N_ithr = C_ithr = -ithr; + S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1; + } + } + + // spatial_thr_allowed is meant to help maintain + // consistent decisions about spatial threading + // between mutiple invocations of this routine. + // It is caller's responsibility to check the + // return value and pass it as a flag to the + // next call if needed. + if (S_nthr == 1) + spatial_thr_allowed = false; + + return spatial_thr_allowed; +} + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size) { + if (!mkldnn_thr_syncable()) return false; + + dim_t nthr = mkldnn_get_max_threads(); + dim_t SP = bdesc->W() * bdesc->D() * bdesc->H(); + dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md()) + .padded_dims()[1]; + assert(C_PADDED % simd_w == 0); + + size_t data = bdesc->MB() * C_PADDED * SP * data_size; + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0); + dim_t C_blks_per_iter{ 1 }, iters{ 1 }; + dim_t C_blks = C_PADDED / simd_w; + + if (do_blocking) { + int num_tensors = bdesc->is_fwd() ? 1 : 2; + size_t working_set_size + = (bdesc->MB() * SP * simd_w * data_size) * num_tensors; + cache_balance(working_set_size, C_blks, C_blks_per_iter, iters); + } + + // Spatial threading decision made in this function shall be consistent + // with thread_balance() behavior. + C_blks = do_blocking ? C_blks_per_iter : C_blks; + + if (nthr <= C_blks) return false; + + dim_t S_nthr = 1; + if (do_blocking) { + dim_t N_nthr = nstl::min(bdesc->MB(), nthr); + dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + dim_t C_nthr = math::gcd(nthr, C_blks); + dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + return S_nthr > 1; +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp new file mode 100644 index 000000000000..0daef0716c91 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp @@ -0,0 +1,43 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP +#define CPU_BATCH_NORMALIZATION_UTILS_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters); + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e); + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size); + +} +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp new file mode 100644 index 000000000000..b926491202ae --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_concat.hpp" +#include "cpu/simple_concat.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const cpd_create_f cpu_concat_impl_list[] = { + /* + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(ref_concat_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const { + return cpu_concat_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp new file mode 100644 index 000000000000..0b01bcf1633a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_CONCAT_PD_HPP +#define CPU_CONCAT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "concat_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_concat_pd_t: public concat_pd_t { + using concat_pd_t::concat_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp new file mode 100644 index 000000000000..52a38a229472 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONVOLUTION_PD_HPP +#define CPU_CONVOLUTION_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t { + using convolution_fwd_pd_t::convolution_fwd_pd_t; + + bool has_padded_dst() const { + memory_desc_wrapper dst_d(&dst_md_); + return OC() != dst_d.padded_dims()[1]; + } + + bool wants_padded_bias() const { + if (!with_bias()) return false; + return has_padded_dst(); + } + + bool wants_zero_pad_dst(bool jit_impl = true) const { + if (!has_padded_dst()) return false; + const auto &po = attr()->post_ops_; + int idx; + if ((idx = po.find(primitive_kind::eltwise)) == -1) return false; + return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg, + jit_impl); + } +}; + +struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t { + using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t; +}; + +struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t { + using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t; + + bool wants_padded_bias() const { + if (!with_bias()) return false; + memory_desc_wrapper diff_dst_d(&diff_dst_md_); + return OC() != diff_dst_d.padded_dims()[1]; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp new file mode 100644 index 000000000000..164c8601d708 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_DECONVOLUTION_PD_HPP +#define CPU_DECONVOLUTION_PD_HPP + +#include + +#include "deconvolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t { + using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t; +}; + +struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t { + using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t; +}; + +struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t { + using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp new file mode 100644 index 000000000000..c52f00026ecc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ELTWISE_PD_HPP +#define CPU_ELTWISE_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "eltwise_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t { + using eltwise_fwd_pd_t::eltwise_fwd_pd_t; +}; + +struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t { + using eltwise_bwd_pd_t::eltwise_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp new file mode 100644 index 000000000000..ce0a3667ad58 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp @@ -0,0 +1,324 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "type_helpers.hpp" +#include "verbose.hpp" + +#include "cpu_engine.hpp" +#include "cpu_memory.hpp" + +//#include "cpu/rnn/ref_rnn.hpp" + +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" +//#include "cpu/jit_avx512_common_1x1_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp" +#include "cpu/jit_avx512_common_convolution_winograd.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp" +#include "cpu/jit_avx512_common_convolution.hpp" +//#include "cpu/jit_avx2_1x1_convolution.hpp" +//#include "cpu/jit_sse42_1x1_convolution.hpp" +#include "cpu/jit_avx2_convolution.hpp" +#include "cpu/jit_sse42_convolution.hpp" +//#include "cpu/gemm_convolution.hpp" +//#include "cpu/gemm_x8s8s32x_convolution.hpp" +//#include "cpu/ref_convolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" +//#include "cpu/ref_deconvolution.hpp" +//#include "cpu/ref_shuffle.hpp" +//#include "cpu/jit_uni_eltwise.hpp" +//#include "cpu/ref_eltwise.hpp" +//#include "cpu/ref_softmax.hpp" +#include "cpu/jit_uni_pooling.hpp" +//#include "cpu/jit_uni_i8i8_pooling.hpp" +//#include "cpu/ref_pooling.hpp" +//#include "cpu/nchw_pooling.hpp" +//#include "cpu/nhwc_pooling.hpp" +//#include "cpu/jit_avx512_common_lrn.hpp" +//#include "cpu/jit_uni_lrn.hpp" +//#include "cpu/ref_lrn.hpp" +//#include "cpu/jit_uni_batch_normalization.hpp" +//#include "cpu/ref_batch_normalization.hpp" +//#include "cpu/ncsp_batch_normalization.hpp" +//#include "cpu/nspc_batch_normalization.hpp" +//#include "cpu/ref_inner_product.hpp" +//#include "cpu/gemm_inner_product.hpp" +//#include "cpu/gemm_x8s8s32x_inner_product.hpp" +//#include "cpu/jit_uni_dw_convolution.hpp" +//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +status_t cpu_engine_t::memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) { + auto _memory = new cpu_memory_t(this, md, handle); + if (_memory == nullptr) + return status::out_of_memory; + + status_t status = _memory->init(); + if (status != status::success) { + delete _memory; + return status; + } + + return safe_ptr_assign(*memory, _memory); +} + +using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; + +#define INSTANCE(...) 
&primitive_desc_t::create<__VA_ARGS__::pd_t> +static const pd_create_f cpu_impl_list[] = { + /* RNN */ + /* + INSTANCE(ref_rnn_fwd_f32_t), + INSTANCE(ref_rnn_fwd_u8s8_t), + INSTANCE(ref_rnn_bwd_f32_t), + */ + /* conv */ + /* + INSTANCE(jit_avx512_common_dw_convolution_fwd_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t), + */ + INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t), + INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_winograd_fwd_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_fwd_t), + //INSTANCE(jit_avx512_common_convolution_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_bwd_weights_t), + /* + INSTANCE(jit_avx2_dw_convolution_fwd_t), + INSTANCE(jit_avx2_dw_convolution_bwd_data_t), + INSTANCE(jit_avx2_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx2_1x1_convolution_fwd_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_data_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t), + INSTANCE(jit_sse42_dw_convolution_fwd_t), + INSTANCE(jit_sse42_dw_convolution_bwd_data_t), + INSTANCE(jit_sse42_dw_convolution_bwd_weights_t), + INSTANCE(jit_sse42_1x1_convolution_fwd_t), + */ + INSTANCE(jit_avx2_convolution_fwd_t), + //INSTANCE(jit_avx2_convolution_bwd_data_t), + //INSTANCE(jit_avx2_convolution_bwd_weights_t), + INSTANCE(jit_sse42_convolution_fwd_t), + /* + INSTANCE(gemm_convolution_fwd_t), + INSTANCE(gemm_convolution_bwd_data_t), + INSTANCE(gemm_convolution_bwd_weights_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_weights_t), + */ + /* conv (int) */ + /* + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + 
INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + */ + /* deconv */ + /* + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(ref_deconvolution_bwd_weights_t), + INSTANCE(ref_deconvolution_bwd_data_t), + INSTANCE(ref_deconvolution_fwd_t), + */ + /* shuffle */ + /* + INSTANCE(ref_shuffle_t<4>), // f32 or s32 + INSTANCE(ref_shuffle_t<1>), // s8 or u8 + */ + /* eltwise */ + /* + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_bwd_t), + */ + /* eltwise (int) */ + /* + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_bwd_t), + */ + /* softmax */ + /* + INSTANCE(ref_softmax_fwd_t), + INSTANCE(ref_softmax_bwd_t), + */ + /* pool */ + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + /* + INSTANCE(nchw_pooling_fwd_t), + INSTANCE(nchw_pooling_bwd_t), + INSTANCE(nhwc_pooling_fwd_t), + INSTANCE(nhwc_pooling_bwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_bwd_t), + */ + /* pool (int) */ + /* + INSTANCE(jit_uni_i8i8_pooling_fwd_t), + INSTANCE(jit_uni_i8i8_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_bwd_t), + */ + /* lrn */ + /* + INSTANCE(jit_avx512_common_lrn_fwd_t), + INSTANCE(jit_avx512_common_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t), + INSTANCE(jit_uni_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t), + INSTANCE(ref_lrn_fwd_t), + INSTANCE(ref_lrn_bwd_t), + */ + /* batch normalization */ + /* + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(jit_uni_batch_normalization_fwd_t), + 
INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(ncsp_batch_normalization_fwd_t), + INSTANCE(ncsp_batch_normalization_bwd_t), + INSTANCE(nspc_batch_normalization_fwd_t), + INSTANCE(nspc_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t), + INSTANCE(ref_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t), + */ + /* inner product */ + /* + INSTANCE(gemm_inner_product_fwd_t), + INSTANCE(gemm_inner_product_bwd_data_t), + INSTANCE(gemm_inner_product_bwd_weights_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_bwd_data_t), + INSTANCE(ref_inner_product_bwd_weights_t), + */ + /* inner product (int) */ + /* + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + */ + /* eol */ + nullptr, +}; +#undef INSTANCE +} + +const pd_create_f* cpu_engine_t::get_implementation_list() const { + return cpu_impl_list; +} + +cpu_engine_factory_t engine_factory; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp new file mode 100644 index 000000000000..e4c877ee05a8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp @@ -0,0 +1,70 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_ENGINE_HPP +#define CPU_ENGINE_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "../common/engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class cpu_engine_t: public engine_t { +public: + cpu_engine_t(): engine_t(engine_kind::cpu) {} + + /* implementation part */ + + virtual status_t memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) override; + + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const override; + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const override; + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const override; + virtual const primitive_desc_create_f* + get_implementation_list() const override; +}; + +class cpu_engine_factory_t: public engine_factory_t { +public: + virtual size_t count() const override { return 1; } + virtual engine_kind_t kind() const override { return engine_kind::cpu; } + virtual status_t engine_create(engine_t **engine, + size_t index) const override { + assert(index == 0); + *engine = new cpu_engine_t(); + return status::success; + }; +}; + +extern cpu_engine_factory_t engine_factory; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp new file mode 100644 index 000000000000..5880d3450cc7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_INNER_PRODUCT_PD_HPP +#define CPU_INNER_PRODUCT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "inner_product_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { +inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { + using namespace utils; + + auto strides_compatible = [&]() { + bool ok = true; + auto w_str = wei_d.blocking_desc().strides; + auto d_str = src_d.blocking_desc().strides; + for (int i = 1; i < src_d.ndims() - 1; i++) { + ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; + } + return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); + }; + return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc() + && src_d.ndims() == wei_d.ndims() + && src_d.blocking_desc().inner_nblks + == wei_d.blocking_desc().inner_nblks + && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1) + && array_cmp(src_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_nblks) + && array_cmp(src_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_nblks) + && strides_compatible() + && dst_d.matches_tag(format_tag::nc) + && src_d.only_padded_dim(1) + && wei_d.only_padded_dim(1) + && src_d.padded_dims()[1] == wei_d.padded_dims()[1] + && src_d.is_dense(true) + && dst_d.is_dense() + && wei_d.is_dense(true); +} +} + +struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t { + using inner_product_fwd_pd_t::inner_product_fwd_pd_t; +}; + +struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t { + using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; +}; + +struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t { + using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp new file mode 100644 index 000000000000..da6e9dac8ea7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp @@ -0,0 +1,151 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_ISA_TRAITS_HPP +#define CPU_ISA_TRAITS_HPP + +#include + +#define XBYAK64 +#define XBYAK_NO_OP_NAMES +/* in order to make selinux happy memory that would be marked with X-bit should + * be obtained with mmap */ +#define XBYAK_USE_MMAP_ALLOCATOR +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +/* turn off `size_t to other-type implicit casting` warning + * currently we have a lot of jit-generated instructions that + * take uint32_t, but we pass size_t (e.g. due to using sizeof). + * FIXME: replace size_t parameters with the appropriate ones */ +#pragma warning (disable: 4267) +#endif +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + isa_any, + sse41, + sse42, + avx, + avx2, + avx512_common, + avx512_core, + avx512_core_vnni, + avx512_mic, + avx512_mic_4ops, +} cpu_isa_t; + +template struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */ + +template <> struct cpu_isa_traits { + typedef Xbyak::Xmm Vmm; + static constexpr int vlen_shift = 4; + static constexpr int vlen = 16; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits { + typedef Xbyak::Ymm Vmm; + static constexpr int vlen_shift = 5; + static constexpr int vlen = 32; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits { + typedef Xbyak::Zmm Vmm; + static constexpr int vlen_shift = 6; + static constexpr int vlen = 64; + static constexpr int n_vregs = 32; +}; +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +namespace { + +static Xbyak::util::Cpu cpu; +static inline bool mayiuse(const cpu_isa_t cpu_isa) { + using namespace Xbyak::util; + + switch (cpu_isa) { + case sse41: + case sse42: + // FIXME: SSE4.2 is actually NOT required + //return cpu.has(Cpu::tSSE42); + return cpu.has(Cpu::tSSE41); + case avx: + return cpu.has(Cpu::tAVX); + case avx2: + return cpu.has(Cpu::tAVX2); + case avx512_common: + return cpu.has(Cpu::tAVX512F); + case avx512_core: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ); + case avx512_core_vnni: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ) + && cpu.has(Cpu::tAVX512_VNNI); + case avx512_mic: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512CD) + && cpu.has(Cpu::tAVX512ER) + && cpu.has(Cpu::tAVX512PF); + case avx512_mic_4ops: + return true + && mayiuse(avx512_mic) + && cpu.has(Cpu::tAVX512_4FMAPS) + && cpu.has(Cpu::tAVX512_4VNNIW); + case isa_any: + return true; + } + return false; +} +} + +/* whatever is required to generate string literals... */ +#include "z_magic.hpp" +#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \ + (isa == sse42 ? prefix STRINGIFY(sse42) : \ + (isa == avx ? prefix STRINGIFY(avx) : \ + (isa == avx2 ? prefix STRINGIFY(avx2) : \ + (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \ + (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \ + (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \ + (isa == avx512_mic_4ops ? 
prefix STRINGIFY(avx512_mic_4ops) : \ + prefix suffix_if_any))))))) + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp new file mode 100644 index 000000000000..49988f4c2dd4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_LRN_PD_HPP +#define CPU_LRN_PD_HPP + +#include + +#include "lrn_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t { + using lrn_fwd_pd_t::lrn_fwd_pd_t; +}; + +struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t { + using lrn_bwd_pd_t::lrn_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp new file mode 100644 index 000000000000..3c0624cf46d9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn_traits.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; + +enum blk_kind_t { a, b, c, ab, ba, bc, cb }; + +template +void typed_zero_pad_blk( + const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { + using data_t = typename prec_traits
::type; + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + const auto &blk = m_d.blocking_desc(); + auto dim_is_blocked = [&](int dim) { + for (int i = 0; i < blk.inner_nblks; i++) + if (blk.inner_idxs[i] == dim) + return true; + return false; + }; + bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1), + C_blocked = dim_is_blocked(2); + + assert(blk.inner_nblks < 4); + assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked) + || (C_blocked && B_blocked)); + + const int a_tail_s = A_blocked ? dims[0] % blksize : 0; + const int b_tail_s = B_blocked ? dims[1] % blksize : 0; + const int c_tail_s = C_blocked ? dims[2] % blksize : 0; + assert(a_tail_s || b_tail_s || c_tail_s); + + const int A = A_blocked ? pdims[0] / blksize : dims[0]; + const int B = B_blocked ? pdims[1] / blksize : dims[1]; + const int C = C_blocked ? pdims[2] / blksize : dims[2]; + const int D = m_d.ndims() > 3 ? dims[3] : 1; + const int E = m_d.ndims() > 4 ? dims[4] : 1; + const int F = m_d.ndims() > 5 ? dims[5] : 1; + const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1; + + auto zeroize_tail = [&](data_t *d, const int tail_s) { + for (int b = tail_s; b < blksize; ++b) + d[b] = 0; + }; + auto zeroize_tail_inner = [&](data_t *d, const int tail_s) { + for (int b1 = 0; b1 < blksize; ++b1) + for (int b2 = tail_s; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + auto zeroize_tail_outer = [&](data_t *d, const int tail_s) { + for (int b1 = tail_s; b1 < blksize; ++b1) + for (int b2 = 0; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + + if (c_tail_s) { + parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)]; + if (blk_kind == c) + zeroize_tail(x, c_tail_s); + else if (blk_kind == bc) + zeroize_tail_inner(x, c_tail_s); + else if (blk_kind == cb) + zeroize_tail_outer(x, c_tail_s); + }); + } + + if (b_tail_s) { + parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)]; + if (blk_kind == b) + zeroize_tail(x, b_tail_s); + else if (blk_kind == ab || blk_kind == cb) + zeroize_tail_inner(x, b_tail_s); + else if (blk_kind == ba || blk_kind == bc) + zeroize_tail_outer(x, b_tail_s); + }); + } + + if (a_tail_s) { + parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)]; + if (blk_kind == a) + zeroize_tail(x, a_tail_s); + else if (blk_kind == ba) + zeroize_tail_inner(x, a_tail_s); + else if (blk_kind == ab) + zeroize_tail_outer(x, a_tail_s); + }); + } +} + +/* + * all + */ +template +void typed_zero_pad_generic_blocked( + const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { + const int ndims = m_d.ndims(); + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + + const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true); + + /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1] + * | \ / + * | --------------------- + * has contiguous + * padding + * + * step <-- D_k+1 * ... * D_ndims-1 + * step_dim <-- k + */ + + ptrdiff_t step = 1; + int step_dim = ndims - 1; + for (; step_dim >= 0; --step_dim) { + if (dims[step_dim] != pdims[step_dim]) + break; + step *= dims[step_dim]; + } + + assert(step_dim >= 0 && "no zero padding is required"); + if (step_dim < 0) + return; + + parallel_nd(nelems / step, [&](ptrdiff_t e1) { + bool need_zero = false; + + ptrdiff_t idx = e1; + for (int d = step_dim; d >= 0; --d) { + if (idx % pdims[d] >= dims[d]) { + need_zero = true; + break; + } + idx /= pdims[d]; + } + + if (need_zero) { + for (ptrdiff_t e0 = 0; e0 < step; ++e0) + data[m_d.off_l(e1 * step + e0, true)] = 0; + } + }); +} + +template +status_t cpu_memory_t::typed_zero_pad() const { + const memory_desc_wrapper mdw(md()); + + if (mdw.format_kind() != format_kind::blocked) + return unimplemented; + + if (mdw.nelems(false) == mdw.nelems(true)) + return success; + + auto *data = (typename prec_traits
::type *)data_; + auto blk = mdw.blocking_desc(); + + auto get_blksize = [&](int ind) { + int blksize = 1; + for (int i = 0; i < blk.inner_nblks; i++) { + if (blk.inner_idxs[i] == ind) + blksize *= blk.inner_blks[i]; + } + return blksize; + }; + const int blksize = get_blksize(blk.inner_idxs[0]); + +# define CASE(blksize_, blk_kind) \ + do { \ + if (blksize == blksize_) { \ + typed_zero_pad_blk(mdw, data); \ + return success; \ + } \ + } while(0) + + switch (blk.inner_nblks) { + case 1: + if (blk.inner_idxs[0] == 0) { + CASE(4, a); + CASE(8, a); + CASE(16, a); + } else if (blk.inner_idxs[0] == 1) { + CASE(4, b); + CASE(8, b); + CASE(16, b); + } + break; + case 2: + case 3: + if (!IMPLICATION(blk.inner_nblks == 3, + blk.inner_idxs[0] == blk.inner_idxs[2])) + break; + + if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) { + CASE(4, ab); + CASE(8, ab); + CASE(16, ab); + } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) { + CASE(4, ba); + CASE(8, ba); + CASE(16, ba); + } + if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) { + CASE(4, bc); + CASE(8, bc); + CASE(16, bc); + } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) { + CASE(4, cb); + CASE(8, cb); + CASE(16, cb); + } + break; + default: break; + } + +# undef CASE + + // the last line of defence + typed_zero_pad_generic_blocked
(mdw, data); + return success; +} + +status_t cpu_memory_t::zero_pad() const { + memory_desc_wrapper mdw(md()); + const bool skip_zeroing = false + || data_ == nullptr + || mdw.is_zero() + || !mdw.is_blocking_desc(); + if (skip_zeroing) return success; + + switch (mdw.data_type()) { + case f32: return typed_zero_pad(); + case s32: return typed_zero_pad(); + case s8: return typed_zero_pad(); + case u8: return typed_zero_pad(); + default: assert(!"memory is undefined"); return unimplemented; + } + return unimplemented; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp new file mode 100644 index 000000000000..2c01bcc6af7d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_MEMORY_HPP +#define CPU_MEMORY_HPP + +#include + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "memory_desc_wrapper.hpp" + +#include "cpu_engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t: public memory_t { + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle) + : memory_t(engine, md) + , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE) + , data_((char *)handle) {} + + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md) + : cpu_memory_t(engine, md, nullptr) {} + + ~cpu_memory_t() { if (own_data_) free(data_); } + + virtual status_t init() override { + if (own_data_) { + data_ = nullptr; + const size_t size = memory_desc_wrapper(this->md()).size(); + if (size) { + data_ = (char *)malloc(size, 64); + if (data_ == nullptr) + return status::out_of_memory; + } + } + return zero_pad(); + } + + cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); } + + virtual status_t get_data_handle(void **handle) const override { + *handle = static_cast(data_); + return status::success; + } + + virtual mkldnn::impl::status_t set_data_handle(void *handle) override { + if (own_data_) { free(data_); own_data_ = false; } + data_ = static_cast(handle); + return zero_pad(); + } + + virtual mkldnn::impl::status_t zero_pad() const override; + +private: + bool own_data_; + char *data_; + + template + mkldnn::impl::status_t typed_zero_pad() const; + + cpu_memory_t(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(cpu_memory_t &&) = delete; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp new file mode 100644 index 000000000000..ac2daa415e18 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp @@ -0,0 +1,40 @@ 
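The cpu_memory.hpp header just above ties buffer ownership to how the handle is passed at construction. A minimal usage sketch, illustrative only (eng, md, user_buf and other_buf are assumed placeholder objects, not symbols from this patch), using just the constructor and methods shown above:

    // Engine-allocated buffer: init() does an aligned malloc(size, 64) and then
    // zero-pads the tail elements that blocked/padded layouts introduce.
    cpu_memory_t owned(eng, &md, MKLDNN_NATIVE_HANDLE_ALLOCATE);
    owned.init();

    // User-provided buffer: the object never frees it, but set_data_handle()
    // still re-runs zero_pad() on the newly attached pointer.
    cpu_memory_t wrapped(eng, &md, user_buf);
    wrapped.set_data_handle(other_buf);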
+/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_POOLING_PD_HPP +#define CPU_POOLING_PD_HPP + +#include "pooling_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t { + using pooling_fwd_pd_t::pooling_fwd_pd_t; +}; + +struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t { + using pooling_bwd_pd_t::pooling_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp new file mode 100644 index 000000000000..56127f36c274 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PRIMITIVE_HPP +#define CPU_PRIMITIVE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "primitive.hpp" +#include "scratchpad.hpp" + +#define CTX_IN_MEM(type, arg) static_cast(ctx.input(arg)) +#define CTX_OUT_MEM(type, arg) static_cast(ctx.output(arg)) + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t; + +struct cpu_primitive_t: public primitive_t { + cpu_primitive_t(const primitive_desc_t *pd, + bool use_global_scratchpad = false) + : primitive_t(pd) + , scratchpad_buffer_(nullptr) + , global_scratchpad_(nullptr) + { + const size_t scratchpad_size = + this->pd()->scratchpad_size(scratchpad_mode::library); + + if (scratchpad_size) { + if (use_global_scratchpad) + global_scratchpad_ = create_scratchpad(scratchpad_size); + else + scratchpad_buffer_ = malloc(scratchpad_size, 64); + } + } + + virtual ~cpu_primitive_t() { + delete global_scratchpad_; + free(scratchpad_buffer_); + } + +protected: + memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const { + void *ptr = nullptr; + if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) { + ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD); + } else { + ptr = global_scratchpad_ + ? 
global_scratchpad_->get() : scratchpad_buffer_; + } + + return pd()->scratchpad_registry().grantor(ptr); + } + +private: + void *scratchpad_buffer_; + scratchpad_t *global_scratchpad_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp new file mode 100644 index 000000000000..1d41ac5ceab7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp @@ -0,0 +1,544 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "cpu_reducer.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void reduce_balancer_t::balance() { + using namespace nstl; + using namespace utils; + + assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0); + + const int job_complexity = 1; + + const int min_njobs_per_group = max(1, njobs_ / nthr_); + const int max_njobs_per_group = max(1, + static_cast(max_buffer_size_ / (nthr_ * job_size_))); + + /* initial guess */ + int ngroups = min(njobs_ / min_njobs_per_group, nthr_); + int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1; + int njobs_per_group_ub = div_up(njobs_, ngroups); + + /* rough upper-bound estimation, will be fixed during brute force */ + size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_; + + /* brute force parameters for the best balance... */ + for (int c_njobs_per_group = min_njobs_per_group; + c_njobs_per_group < njobs_; ++c_njobs_per_group) { + /* current assumption */ + int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_); + int c_nthr_per_group = syncable_ + ? 
min(nthr_ / c_ngroups, reduction_size_) : 1; + int c_njobs_per_group_ub = div_up(njobs_, c_ngroups); + + if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group) + continue; + + int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group); + size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub; + size_t c_thread_complexity_ub = c_group_size_ub * ( + job_complexity * c_thread_reduction_ub + + (c_nthr_per_group != 1)); + + if (c_thread_complexity_ub < thread_complexity_ub) { + ngroups = c_ngroups; + nthr_per_group = c_nthr_per_group; + njobs_per_group_ub = c_njobs_per_group_ub; + thread_complexity_ub = c_thread_complexity_ub; + } + } + + assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1); + assert(ngroups * nthr_per_group <= nthr_); + assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_ + || nthr_per_group == 1); /* no reduction buffer overflow */ + assert(IMPLICATION(!syncable_, nthr_per_group == 1)); + + ngroups_ = ngroups; + nthr_per_group_ = nthr_per_group; + njobs_per_group_ub_ = njobs_per_group_ub; +} + +/* reducer jit-ted driver */ + +using namespace Xbyak; + +template +struct reducer_2d_driver_t: public c_compatible { + typedef typename prec_traits::type data_t; + + reducer_2d_driver_t(int n_src, size_t src_ld, + size_t src_step, size_t dst_step, bool nullify_dst) + : n_src_(n_src), src_ld_(src_ld), src_step_(src_step) + , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {} + virtual ~reducer_2d_driver_t() {} + void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx) + { assert(ker_); ker_(dst, srcs, ny, nx); } + +protected: + int n_src_; + size_t src_ld_, src_step_, dst_step_; + bool nullify_dst_; + void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx); +}; + +template +struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + const AddressFrame &vmmword = (isa == avx2) ? yword : zword; + void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op) + { if (data_type == data_type::f32) vaddps(x1, x2, op); + else vpaddd(x1, x2, op); } + void uni_add(const Xmm& x1, const Operand& op) + { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); } + + const int vlen = cpu_isa_traits::vlen; + const int typesize + = sizeof(typename mkldnn::impl::prec_traits::type); + Xbyak::Reg64 reg_dst = abi_param1; + Xbyak::Reg64 reg_src = abi_param2; + Xbyak::Reg64 reg_ny = abi_param3; + Xbyak::Reg64 reg_nx = abi_param4; + + Xbyak::Reg64 reg_x = rax; + Xbyak::Reg64 reg_src_id = r10; + + reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step, + size_t dst_step, bool nullify_dst) + : reducer_2d_driver_t(n_src, src_ld, src_step, + dst_step, nullify_dst) + { generate(); } + + void nullify_dst(int nloads, int load_len) { + UNUSED(load_len); + for (int i = 0; i < nloads; ++i) + uni_vpxor(Vmm(i), Vmm(i), Vmm(i)); + /* prefetches[dst] ? 
*/ + } + + void load_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(Xmm(i), ptr[reg_dst + i * load_len]); + else if (load_len == vlen) + vmovups(Vmm(i), ptr[reg_dst + i * load_len]); + else + assert(!"unsupported"); + } + } + + void store_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(ptr[reg_dst + i * load_len], Xmm(i)); + else if (load_len == vlen) + vmovups(ptr[reg_dst + i * load_len], Vmm(i)); + else + assert(!"unsupported"); + } + } + + void accumulate(int nloads, int load_len, size_t base_off) { + for (int i = 0; i < nloads; ++i) { + size_t off = base_off + i * load_len; + + if (load_len == typesize) + uni_add(Xmm(i), ptr[reg_src + off]); + else if (load_len == vlen) + uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]); + else + assert(!"unsupported"); + } + } + + void loop_x() { + const int nloads[] = {cpu_isa_traits::n_vregs, 1, 1}; + const int nbranches = sizeof(nloads) / sizeof(nloads[0]); + + const int load_len[nbranches] = {vlen, vlen, typesize}; + Label loop_x_label[nbranches + 1]; + + mov(reg_x, reg_nx); + + for (int id = 0; id < nbranches; ++id) { + L(loop_x_label[id]); + + cmp(reg_x, nloads[id] * load_len[id]); + jl(loop_x_label[id + 1], T_NEAR); + + if (this->nullify_dst_) + nullify_dst(nloads[id], load_len[id]); + else + load_dst(nloads[id], load_len[id]); + + if (nloads[id] > 1) { + Label loop_srcs; + mov(reg_src_id, this->n_src_); + L(loop_srcs); + + accumulate(nloads[id], load_len[id], 0); + add(reg_src, this->src_ld_ * typesize); + + dec(reg_src_id); + jnz(loop_srcs, T_NEAR); + + sub(reg_src, this->n_src_ * this->src_ld_ * typesize); + } else { + for (int src_id = 0; src_id < this->n_src_; ++src_id) { + const size_t base_off = src_id * this->src_ld_ * typesize; + accumulate(nloads[id], load_len[id], base_off); + } + } + + store_dst(nloads[id], load_len[id]); + + add(reg_src, nloads[id] * load_len[id]); + add(reg_dst, nloads[id] * load_len[id]); + + sub(reg_x, nloads[id] * load_len[id]); + + jmp(loop_x_label[id], T_NEAR); + } + + L(loop_x_label[nbranches]); + + /* restore address registers */ + sub(reg_src, reg_nx); + sub(reg_dst, reg_nx); + } + + void generate() { + assert(isa == avx2 || isa == avx512_common || isa == avx512_mic); + + preamble(); + + shl(reg_nx, 2); + + Label ny_loop; + L(ny_loop); + + loop_x(); + + add(reg_dst, this->dst_step_ * typesize); + add(reg_src, this->src_step_ * typesize); + + dec(reg_ny); + jnz(ny_loop, T_NEAR); + + postamble(); + this->ker_ = reinterpret_castker_)>( + const_cast(this->getCode())); + } +}; + +template +inline reducer_2d_driver_t *create_reduce_2d_drv(int n_src, + size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) { + if (mayiuse(avx512_common)) + return new reducer_2d_driver_f_s_32_t(n_src, + src_ld, src_step, dst_step, nullify_dst); + else if (mayiuse(avx2)) + return new reducer_2d_driver_f_s_32_t(n_src, src_ld, + src_step, dst_step, nullify_dst); + assert(!"unimplemented"); + return nullptr; +} + +/* cpu_reducer_t */ + +template +void cpu_reducer_t::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ + * (balancer_.nthr_per_group_ - 1) + * cpu_reducer_t::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template 
+cpu_reducer_t::cpu_reducer_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv(balancer().nthr_per_group_ - 1, + space_per_thread(balancer()), 0, 0, false); +} + +template +cpu_reducer_t::~cpu_reducer_t() { delete drv_; } + +template +typename cpu_reducer_t::data_t * +cpu_reducer_t::get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + + /* threads 0 from each group writes directly to the destination */ + if (id_in_grp == 0) + return dst + balancer().ithr_job_off(ithr) * balancer().job_size_; + + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1) + + (id_in_grp - 1); + + auto space = scratchpad.template get(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template +void cpu_reducer_t::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + +#ifdef SIMPLE_IMPL + if (balancer().id_in_group(ithr) != 0) + return; /* only threads 0 do the reduction */ + + const int njobs_in_grp = balancer().ithr_njobs(ithr); + data_t *d = get_local_ptr(ithr, dst, scratchpad); + for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp) + { + const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad); + for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i) + d[i] += space[i]; + } +#else + using namespace utils; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const size_t cl = 64 / sizeof(data_t); + + const size_t reduction_size = njobs_in_grp * balancer().job_size_; + size_t start{0}, end{0}; + balance211(div_up(reduction_size, cl), balancer().nthr_per_group_, + id_in_grp, start, end); + + if (start == end) return; + + data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl; + const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad) + + start * cl; + const size_t len = nstl::min(end * cl, reduction_size) - start * cl; + + (*drv_)(d, space, 1, len); +#endif +} + +template struct cpu_reducer_t; +template struct cpu_reducer_t; + +/* cpu_reducer_2d_t */ + +template +void cpu_reducer_2d_t::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_ + * cpu_reducer_2d_t::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template +cpu_reducer_2d_t::cpu_reducer_2d_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv(balancer().nthr_per_group_, + space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_, + true); +} + +template +cpu_reducer_2d_t::~cpu_reducer_2d_t() { delete drv_; } + +template +typename cpu_reducer_2d_t::data_t *cpu_reducer_2d_t:: +get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * 
balancer().nthr_per_group_ + id_in_grp; + auto space = scratchpad.template get(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template +int cpu_reducer_2d_t::choose_x_blocking(int nx, int ny, + int nthr_per_grp) const { + // find x_blocking for better balance reducing work between threads + assert(conf_.x_block_ > 0 && nx > conf_.x_block_ + && nx % conf_.x_block_ == 0); + int x_blocking = nx / conf_.x_block_; + int min_x_blocking = + utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny)); + while (true) { + if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2) + x_blocking /= 2; + else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3) + x_blocking /= 3; + else + break; + } + if (x_blocking >= min_x_blocking * 4) x_blocking = 1; + x_blocking *= conf_.x_block_; + return x_blocking; +} + +template +void cpu_reducer_2d_t::reduce_block(const data_t* space_base, + data_t *dst, int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const { + data_t *d = dst + (start_y + ny_start) * conf_.dst_x_ + + start_x + nx_start; + const data_t *space = space_base + job * balancer().job_size_ + + ny_start * conf_.job_size_x_ + nx_start; +#ifdef SIMPLE_IMPL + for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) { + const data_t *w = &space[idg * space_per_thread(balancer())]; + for (int y = 0; y < ny_step; ++y) + for (int x = 0; x < nx_step; ++x) { + d[y * conf_.dst_x_ + x] + = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x]) + + w[y * conf_.job_size_x_ + x]; + } + } +#else + (*drv_)(d, space, ny_step, nx_step); +#endif +} + +template +void cpu_reducer_2d_t::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_); + const int global_job_start = balancer().ithr_job_off(ithr); + + const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad); + + const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_); + const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps; + + if (id_in_grp >= pr_grps * pr_nthr_per_grp) + return; /* idle */ + + const int pr_my_grp = id_in_grp / pr_nthr_per_grp; + const int pr_my_id = id_in_grp % pr_nthr_per_grp; + + int pr_job_start{0}, pr_job_end{0}; + balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end); + + for (int j = pr_job_start; j < pr_job_end; ++j) { + const int global_job = global_job_start + j; + const int j_y = global_job / njobs_x; + const int j_x = global_job % njobs_x; + const int start_y = j_y * conf_.job_size_y_; + const int start_x = j_x * conf_.job_size_x_; + const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_); + const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_); + int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp); + + int nxy_start{0}, nxy_end{0}; + balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id, + nxy_start, nxy_end); + if (nxy_start == nxy_end) continue; + nxy_start *= x_blocking; + nxy_end *= x_blocking; + + int nxy = nxy_start; + if (nxy % nx != 0) { + int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy); + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nx_step); + nxy += nx_step; + } + if ((nxy_end 
- nxy) > nx) { + int ny_step = (nxy_end - nxy) / nx; + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, ny_step, nx); + nxy += nx * ny_step; + } + if ((nxy_end - nxy) > 0) { + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nxy_end - nxy); + } + } +} + +template struct cpu_reducer_2d_t; +template struct cpu_reducer_2d_t; + +/* accumulator section */ + +template +cpu_accumulator_1d_t::cpu_accumulator_1d_t(): drv_(nullptr) { + drv_ = create_reduce_2d_drv(1, 0, 0, 0, false); +} + +template +cpu_accumulator_1d_t::~cpu_accumulator_1d_t() { + delete drv_; +} + +template +void cpu_accumulator_1d_t::accumulate(data_t *dst, + const data_t *src, size_t size) { + (*drv_)(dst, src, 1, size); +} + +template struct cpu_accumulator_1d_t; +template struct cpu_accumulator_1d_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp new file mode 100644 index 000000000000..27f5939cd2a3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp @@ -0,0 +1,334 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REDUCER_HPP +#define CPU_REDUCER_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/** class to perform balancing over 3D array + * + * Conceptually the reduction happens according to the picture below: + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ result + * + * In a simple case the result must be contiguous in memory. + * @class cpu_reducer_t is an implementation. + * + * Threads are divided into groups. The groups are independent of each other. + * Each group may work on several jobs (the distribution is not uniform, since + * njobs might be not a multiple of groups). Threads within a group work on + * different parts of the reduction dimension. Thread 0 in each group is called + * master (@sa reduce_balancer_t::master()). + * + * If threading driver does not allow sync between sub-group of threads (e.g. + * Intel(R) TBB) the # of thread per group is enforced to be 1. 
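+ *
+ * To make the bookkeeping concrete, a small worked example with hypothetical
+ * numbers (picked for illustration only, not produced by balance()): suppose
+ * ngroups_ = 2, nthr_per_group_ = 4 and njobs_ = 7. Following the accessors
+ * defined below,
+ *   group_id(5)    == 5 / 4 == 1 and id_in_group(5) == 5 % 4 == 1,
+ *   master(4)      == true (thread 4 is thread 0 of group 1),
+ *   grp_njobs(0)   == 7 / 2 + (0 < 7 % 2) == 4 and grp_njobs(1) == 3,
+ *   grp_job_off(1) == 7 / 2 * 1 + min(1, 7 % 2) == 4,
+ * and every ithr >= nthr_per_group_ * ngroups_ == 8 reports idle().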
+ */ +struct reduce_balancer_t { + reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */ + reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size, + size_t max_buffer_size) + { init(nthr, job_size, njobs, reduction_size, max_buffer_size); } + + reduce_balancer_t &init(int nthr, int job_size, int njobs, + int reduction_size, size_t max_buffer_size) + { + syncable_ = mkldnn_thr_syncable(); + nthr_ = nthr; + job_size_ = job_size; + njobs_ = njobs; + reduction_size_ = reduction_size; + max_buffer_size_ = max_buffer_size; + balance(); + return *this; + } + + bool syncable_; + int nthr_; + int job_size_, njobs_, reduction_size_; + + int ngroups_; /** number of independent work (thread) groups */ + int nthr_per_group_; /** number of threads within a single work group */ + int njobs_per_group_ub_; /** the max # of jobs within a work group */ + + bool master(int ithr) const { return id_in_group(ithr) == 0; } + bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; } + + int group_id(int ithr) const { return ithr / nthr_per_group_; } + int id_in_group(int ithr) const { return ithr % nthr_per_group_; } + + int grp_njobs(int grp) const { + if (grp >= ngroups_) return 0; + return njobs_ / ngroups_ + (grp < njobs_ % ngroups_); + } + int grp_job_off(int grp) const { + if (grp >= ngroups_) return njobs_; + return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_); + } + + int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); } + int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); } + +private: + size_t max_buffer_size_; + void balance(); +}; + +/** forward declaration of reduce driver */ +template struct reducer_2d_driver_t; + +/** class to perform a reduction over 3D array + * + * Balancing is based on @class reduce_balancer_t. + * Restrictions: the result of the reduction must be contiguous in memory. * + * The reduction happens according to the picture below (once more): + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ (contiguous) result + * + * An example how work might be shared is shown below. + * + * In this example group 0 owns 2 (independent) jobs -- 2 big squares. + * The number of threads per group is also 2 (thread 0 of group 0 and thread 1 + * of group 0). Master threads (i.e. threads with id 0 in corresponding group) + * from each group put the partial result directly into destination memory, + * while all the other threads with-in the group use workspace (on the picture + * the only thread 1). Once intermediate results obtained each group reduces + * corresponding part (own jobs) to the destination memory. 
+ * + * <------- group 0 -------> + * + * +-----------+ +-----------+ ^ + * | | | | | thread 0 of reduces to the dest-memory + * | | | | | group 0 +-----------+ +-----------+ + * |- - - - - -| |- - - - - -| X + * | | | | | thread 1 of reduces to workspace[tid=1]: + * | | | | | group 0 +-----------+ +-----------+ + * +-----------+ +-----------+ v + * | | | | | | + * v v v v v v + * ((barrier)) ============================= + * + * dest-memory: +-----------+ +-----------+ + */ +template +struct cpu_reducer_t { + typedef typename prec_traits::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer) + { balancer_ = balancer; return *this; } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + }; + + cpu_reducer_t(const conf_t &conf); + ~cpu_reducer_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results. + * Reduction destination @p dst must be provided as well (master threads + * from each group will use it for partial result to reduce memory + * pressure). + * + * @note: job offset is already applied by get_local_ptr(), which means all + * threads should start writing from the very beginning of returned + * address. + */ + data_t *get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t *drv_; + + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +template +struct cpu_reducer_2d_t { + typedef typename prec_traits::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer, int job_size_x, + int job_size_y, int x_block, int dst_x, int dst_y) { + balancer_ = balancer; + job_size_x_ = job_size_x; + job_size_y_ = job_size_y; + x_block_ = x_block; + dst_x_ = dst_x; + dst_y_ = dst_y; + return *this; + } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_; + }; + + cpu_reducer_2d_t(const conf_t &conf); + ~cpu_reducer_2d_t(); + + /** initializes reducer. 
+ * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results */ + data_t *get_local_ptr(int ithr, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t *drv_; + + int choose_x_blocking(int nx, int ny, int nthr_per_grp) const; + void reduce_block(const data_t* space_base, data_t *dst, + int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const; + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +/** simple 1d accumulator: y[:] += x[:] */ +template +struct cpu_accumulator_1d_t { + typedef typename prec_traits::type data_t; + + cpu_accumulator_1d_t(); + ~cpu_accumulator_1d_t(); + void accumulate(data_t *dst, const data_t *src, size_t size); + + reducer_2d_driver_t *drv_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp new file mode 100644 index 000000000000..82be70353de9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp @@ -0,0 +1,262 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "cpu_engine.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_memory.hpp" +#include "type_helpers.hpp" + +#include "cpu/jit_uni_reorder.hpp" +#include "cpu/simple_reorder.hpp" +#include "cpu/wino_reorder.hpp" +#include "cpu/rnn/rnn_reorders.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; + +#define REG_SR(idt, ifmt, odt, ofmt, ...) \ + simple_reorder_t::pd_t::create + +#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) + +#define REG_SR_DIRECT_COPY(idt, odt) \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) + +static const rpd_create_f cpu_reorder_impl_list[] = { + /* winograd */ + wino_reorder_t::pd_t::create, + //wino_reorder_t::pd_t::create, + + /* rnn reorders */ + rnn_data_reorder_t::pd_t::create, + rnn_weights_reorder_t::pd_t::create, + rnn_weights_reorder_t::pd_t::create, + + /* conv reorders w/ compensation */ + REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + + /* regular reorders */ + +#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) + /* Direct copy for icc which is faster than jitted code; + * Direct copy for gcc which might or might not be faster than jitted + * code, but still worth it because doesn't require jitting, i.e. much + * faster creation time. This is tentative solution and should be removed + * later (when we will cache jitted code?...). 
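+     *
+     * (For orientation, and as an editorial reconstruction rather than upstream
+     * text: every entry of cpu_reorder_impl_list is an rpd_create_f function
+     * pointer, and REG_SR presumably expands to
+     *   simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create
+     * so that REG_SR_DIRECT_COPY(f32, f32) below contributes the f32->f32
+     * direct-copy specializations. When a reorder is created the engine is
+     * expected to try each entry in order and keep the first create call that
+     * succeeds, stopping at the trailing nullptr sentinel.)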
*/ + REG_SR_DIRECT_COPY(f32, f32), +#endif + +#ifdef __INTEL_COMPILER + /* direct copy for icc, which is faster than jitted code */ + /* + REG_SR_DIRECT_COPY(f32, s32), + REG_SR_DIRECT_COPY(f32, s8), + REG_SR_DIRECT_COPY(f32, u8), + REG_SR_DIRECT_COPY(s32, f32), + REG_SR_DIRECT_COPY(s32, s32), + REG_SR_DIRECT_COPY(s32, s8), + REG_SR_DIRECT_COPY(s32, u8), + REG_SR_DIRECT_COPY(s8, f32), + REG_SR_DIRECT_COPY(s8, s32), + REG_SR_DIRECT_COPY(s8, s8), + REG_SR_DIRECT_COPY(s8, u8), + REG_SR_DIRECT_COPY(u8, f32), + REG_SR_DIRECT_COPY(u8, s32), + REG_SR_DIRECT_COPY(u8, s8), + REG_SR_DIRECT_COPY(u8, u8), + */ +#endif + + /* jit */ + jit_uni_reorder_create, + + /* fp32: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, f32, nCw4c), + REG_SR_BIDIR(f32, any, f32, nCw8c), + REG_SR_BIDIR(f32, any, f32, OIw4i4o), + REG_SR_BIDIR(f32, any, f32, OIw8i8o), + REG_SR_BIDIR(f32, any, f32, OIw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCw16c), + REG_SR_BIDIR(f32, any, f32, OIw16o16i), + REG_SR_BIDIR(f32, any, f32, OIw16i16o), + REG_SR_BIDIR(f32, any, f32, IOw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOw16o16i), + + REG_SR_BIDIR(f32, any, f32, nChw4c), + REG_SR_BIDIR(f32, any, f32, nChw8c), + REG_SR_BIDIR(f32, any, f32, OIhw4i4o), + REG_SR_BIDIR(f32, any, f32, Ohwi8o), + + REG_SR_BIDIR(f32, any, f32, OIhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIhw4o4i), + REG_SR_BIDIR(f32, any, f32, gOhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nChw16c), + REG_SR_BIDIR(f32, any, f32, Oihw4o), + REG_SR_BIDIR(f32, any, f32, Oihw16o), + REG_SR_BIDIR(f32, any, f32, Ohwi4o), + REG_SR_BIDIR(f32, any, f32, Ohwi16o), + REG_SR_BIDIR(f32, any, f32, OIhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIhw16i16o), + REG_SR_BIDIR(f32, any, f32, IOhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOihw4o), + REG_SR_BIDIR(f32, any, f32, gOihw16o), + REG_SR_BIDIR(f32, any, f32, gOhwi4o), + REG_SR_BIDIR(f32, any, f32, gOhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIhw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOhw16o16i), + + REG_SR_BIDIR(f32, any, f32, nCdhw4c), + REG_SR_BIDIR(f32, any, f32, nCdhw8c), + REG_SR_BIDIR(f32, any, f32, OIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, Odhwi8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOdhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCdhw16c), + REG_SR_BIDIR(f32, any, f32, Oidhw4o), + REG_SR_BIDIR(f32, any, f32, Oidhw16o), + REG_SR_BIDIR(f32, any, f32, Odhwi16o), + REG_SR_BIDIR(f32, any, f32, OIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIdhw16i16o), + REG_SR_BIDIR(f32, any, f32, gOidhw4o), + REG_SR_BIDIR(f32, any, f32, gOidhw16o), + REG_SR_BIDIR(f32, any, f32, gOdhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o), + */ + + /* fp32: blocked <-> blocked with tail */ + REG_SR_BIDIR(f32, nCw8c, f32, nCw16c), + REG_SR_BIDIR(f32, nChw8c, f32, nChw16c), + REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c), + + /* int: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, s32, nChw16c), + 
REG_SR_BIDIR(f32, any, s8, nChw16c), + REG_SR_BIDIR(f32, any, u8, nChw16c), + REG_SR_BIDIR(s32, any, f32, nChw16c), + REG_SR_BIDIR(s32, any, s32, nChw16c), + REG_SR_BIDIR(s32, any, s8, nChw16c), + REG_SR_BIDIR(s32, any, u8, nChw16c), + REG_SR_BIDIR(s8, any, f32, nChw16c), + REG_SR_BIDIR(s8, any, s32, nChw16c), + REG_SR_BIDIR(s8, any, s8, nChw16c), + REG_SR_BIDIR(s8, any, u8, nChw16c), + REG_SR_BIDIR(u8, any, f32, nChw16c), + REG_SR_BIDIR(u8, any, s32, nChw16c), + REG_SR_BIDIR(u8, any, s8, nChw16c), + REG_SR_BIDIR(u8, any, u8, nChw16c), + + REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i), + */ + + /* reference: the last line of defence */ + /* + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(f32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s8, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(u8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, u8, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s8, any, fmt_order::any, spec::reference), + */ + + /* eol */ + nullptr, +}; +} + +const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { + return cpu_reorder_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp new file mode 100644 index 000000000000..1622eb68495b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_REORDER_PD_HPP +#define CPU_REORDER_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "reorder_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_reorder_pd_t: public reorder_pd_t { + using reorder_pd_t::reorder_pd_t; + + status_t init() { + const auto &post_ops = attr()->post_ops_; + bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1 + && post_ops.entry_[0].kind == primitive_kind::sum); + scratchpad_engine_ = src_engine_; + return args_ok ? status::success : status::unimplemented; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp new file mode 100644 index 000000000000..f16587b99f4c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SHUFFLE_PD_HPP +#define CPU_SHUFFLE_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "shuffle_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_shuffle_pd_t: public shuffle_pd_t { + using shuffle_pd_t::shuffle_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp new file mode 100644 index 000000000000..3a39eab974b5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_SOFTMAX_PD_HPP +#define CPU_SOFTMAX_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "softmax_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t { + using softmax_fwd_pd_t::softmax_fwd_pd_t; +}; + +struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t { + using softmax_bwd_pd_t::softmax_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp new file mode 100644 index 000000000000..1ab5d9f17435 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_sum.hpp" +#include "cpu/simple_sum.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const spd_create_f cpu_sum_impl_list[] = { + /* + INSTANCE(simple_sum_t), + INSTANCE(ref_sum_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const spd_create_f *cpu_engine_t::get_sum_implementation_list() const { + return cpu_sum_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp new file mode 100644 index 000000000000..0965129f9b02 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
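One detail worth noting in cpu_sum.cpp above: INSTANCE(...) is declared variadic so that template-ids containing commas survive macro expansion as a single argument, since commas inside angle brackets are not protected the way commas inside parentheses are. A toy sketch of the pattern, with a made-up some_sum_t type standing in for the real (commented-out) implementations:

// Sketch only: why the macro uses __VA_ARGS__ instead of one named parameter.
#define INSTANCE(...) __VA_ARGS__::pd_t::create

template <int A, int B>
struct some_sum_t {
    struct pd_t {
        static int create() { return A + B; }  // stand-in for the real creator
    };
};

// Expands to some_sum_t<1, 2>::pd_t::create; the comma inside <> does not
// break the expansion because both pieces are collected by __VA_ARGS__.
int (*creator)() = INSTANCE(some_sum_t<1, 2>);

int main() { return creator() == 3 ? 0 : 1; }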
+*******************************************************************************/ + +#ifndef CPU_SUM_PD_HPP +#define CPU_SUM_PD_HPP + +#include "c_types_map.hpp" +#include "sum_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_sum_pd_t: public sum_pd_t { + using sum_pd_t::sum_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp new file mode 100644 index 000000000000..a9810dec28f2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp @@ -0,0 +1,372 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "gemm_utils_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace gemm_utils { +#define BM_NOCOPY_AVX 64 +#define BN_NOCOPY_AVX 48 +#define BK_NOCOPY_AVX 384 +#define BN_LARGE_NOCOPY_AVX 192 +#define BM_SMALL_NOCOPY_AVX 16 +#define BN_SMALL_NOCOPY_AVX 1 +#define BK_SMALL_NOCOPY_AVX 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + + nthr = nthrs; + nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX; + nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX; + nthr_k = 1; + + // Partition along K dimension + // - if threading allows having barriers (e.g. 
OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + int nthr_other = nthr_k = 1; + while ((nthr_m * nthr_n * nthr_other < nthr) + && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) { + nthr_other++; + if ((nthr / nthr_other) * nthr_other > 0.9 * nthr) + nthr_k = nthr_other; + } + } + nthr /= nthr_k; + + if (nthr_m == 1) + nthr_n = nthr; + if (nthr_n == 1) + nthr_m = nthr; + + // Simple partition reduction + while (nthr_m * nthr_n > nthr) + if (nthr_m > nthr_n) + nthr_m--; + else + nthr_n--; + while (nthr_m * nthr_n < nthr) + if (nthr_m < nthr_n) + nthr_m++; + else + nthr_n++; + + if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX) + nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX) + nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX +#undef BN_NOCOPY_AVX +#undef BK_NOCOPY_AVX +#undef BN_LARGE_NOCOPY_AVX +#undef BM_SMALL_NOCOPY_AVX +#undef BN_SMALL_NOCOPY_AVX +#undef BK_SMALL_NOCOPY_AVX + +#define BM_NOCOPY_AVX512_COMMON 32 +#define BN_NOCOPY_AVX512_COMMON 64 +#define BK_NOCOPY_AVX512_COMMON 192 +#define BN_LARGE_NOCOPY_AVX512_COMMON 192 +#define BM_SMALL_NOCOPY_AVX512_COMMON 16 +#define BN_SMALL_NOCOPY_AVX512_COMMON 1 +#define BK_SMALL_NOCOPY_AVX512_COMMON 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k = 1; + int MB, NB, KB; + nthr = nthrs; + + int counter = 0; + float ratio_float = 1.; + int ratio = 1; + nthr = nthrs; + int nthr_m_gt_n; + + // Partition along K dimension + // - if threading allows having barriers (e.g. 
OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + if (n <= 2 * BN_NOCOPY_AVX512_COMMON && + m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) { + nthr_k = k / BK_NOCOPY_AVX512_COMMON; + if (nthr_k > nthr / 4) + nthr_k = nthr / 4; + if (nthr_k < 1) + nthr_k = 1; + + while ((nthr_k > 1) && (nthr % nthr_k)) { + nthr_k--; + } + nthr /= nthr_k; + } else { + nthr_k = 1; + } + } + nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON; + nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON; + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0; + ratio_float = (float)nthr_m / nthr_n; + + if (nthr_m_gt_n) + ratio = (int)ratio_float; + else + ratio = (int)(1. / ratio_float); + + // scale down nthr_m and nthr_n if they are too large + while (nthr_m * nthr_n > 4 * nthr) { + nthr_m /= 2; + nthr_n /= 2; + } + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + // Simple partition reduction + counter = 0; + while (nthr_m * nthr_n > nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m--; + else { + nthr_n--; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n--; + else { + nthr_m--; + counter = -1; + } + } + counter++; + } + + // Simple partition increment + counter = 0; + while (nthr_m * nthr_n < 0.95 * nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m++; + else { + nthr_n++; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n++; + else { + nthr_m++; + counter = -1; + } + } + counter++; + } + + // if nothing works out, then this should work + if ((nthr_m * nthr_n > nthr)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON) + nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON) + nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX512_COMMON +#undef BN_NOCOPY_AVX512_COMMON +#undef BK_NOCOPY_AVX512_COMMON +#undef BN_LARGE_NOCOPY_AVX512_COMMON +#undef BM_SMALL_NOCOPY_AVX512_COMMON +#undef BN_SMALL_NOCOPY_AVX512_COMMON +#undef BK_SMALL_NOCOPY_AVX512_COMMON + +// Partition n values as equally as possible among nthr threads +// and set the offset (t_offset) and number of values (t_block) for ithr +// Assumption: 0 <= ithr < nthr +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block) +{ + int band = n / nthr; + if (band == 0) + band 
= 1; + int tail = n - band * nthr; + if (tail < 0) + tail = 0; + + if (ithr < tail) { + band++; + *t_offset = band * ithr; + *t_block = band; + } else { + *t_offset = band * ithr + tail; + *t_block = band; + } + + if (*t_offset >= n) { + *t_offset = 0; + *t_block = 0; + } + + if (*t_offset + *t_block > n) { + *t_block = n - *t_offset; + } +} + +// Sum the m*n values from p_src into p_dst, assuming the two-dimensional +// arrays have leading dimensions ld_src and ld_dst, respectively +template +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst) +{ + int i, j; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + p_dst[i + j * ld_dst] += p_src[i + j * ld_src]; + } + } +} + +template +void sum_two_matrices(int m, int n, + float * __restrict p_src, dim_t ld_src, + float * __restrict p_dst, dim_t ld_dst); + +template +void sum_two_matrices(int m, int n, + double * __restrict p_src, dim_t ld_src, + double * __restrict p_dst, dim_t ld_dst); +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp new file mode 100644 index 000000000000..3352298b4aa6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_UTILS_HPP +#define GEMM_UTILS_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace gemm_utils { +// Alias for any dimension related variable. +typedef ptrdiff_t dim_t; + +template +struct gemm_traits {}; + +template +struct gemm_traits { + static constexpr int m = 8; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 192; + static constexpr int BK = isTransB ? 96 : 512; +}; + +template +struct gemm_traits { + static constexpr int m = 16; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 48; + static constexpr int BK = isTransB ? 
96 : 256; +}; + +template +using unroll_factor = gemm_traits; + +template +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst); + +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK); + +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK); + +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp new file mode 100644 index 000000000000..d7be43e39228 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp @@ -0,0 +1,2131 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx512_common_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#ifdef _WIN32 +#define STACK_K_CAPACITY 32 +#else +#define STACK_K_CAPACITY 2048 +#endif +#define SIZE 4 +#define OFFSET 128 +#define BASE_SHIFT 2 +#define SECOND_FETCH unroll_n +#define UNROLL_M 48 +#define UNROLL_N 8 + +namespace avx512_common_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + enum { ver_avx512_core, ver_avx512_mic } ver = + mayiuse(avx512_core) ? 
ver_avx512_core : ver_avx512_mic; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto ORIG_SP = qword[rsp + 120]; + + auto ZSTRIDE = zmm4; + auto VALPHA = zmm6; + auto VBETA = zmm7; + auto VBIAS1 = zmm1; + auto VBIAS2 = zmm2; + auto VBIAS3 = zmm3; + + auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80; + auto PREFETCHSIZEB = 16; + + Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15, + zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, + zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 }; + + // Function for packing if needed + auto do_pack = [&](int unroll_m) { + Label pack2, pack3, pack4, pack10; + + mov(BO1, A); + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE] + | k1, + zmm0); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE] + | k2, + zmm1); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE] + | k3, + zmm2); + } + } else { + for (int i = 0; i < 4; i++) { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, + ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + 
LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + add(BO1, 4 * SIZE); + } + add(AO1, unroll_m * 4 * SIZE); + + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + } else { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + add(BO1, SIZE); + } + + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], + zmm2 | k3); + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Function to update C, covering masking and other considerations + auto update = [&](Zmm reg, bool useCO1, int offset, int mask, + bool useScale = false) { + vmulps(reg, reg, VALPHA); + if (!isBeta0) { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 3: + if (useCO1) 
+ vmovups(zmm0 | k3 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(zmm0, reg, zmm0); + } else { + vfmadd132ps(zmm0, reg, VBETA); + } + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3); + break; + } + } + } else { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg); + else + vmovups(ptr[CO2 + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3); + break; + } + } + } + vpxorq(reg, reg, reg); + }; + + // Loop with unroll_n - 2 FMAs; called by innerkernel + auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) { + for (int i = 2; i < unroll_n; i++) { + if (ver == ver_avx512_core) { + if (!isTransB) { + switch (i) { + case 2: + vbroadcastss( + zmm3, + ptr[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, + ptr[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, + ptr[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, + ptr[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], 
zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } else { + if (!isTransB) { + switch (i) { + case 2: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vfmadd231ps( + regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } + } + } + }; + + // Innerkernel; called by kernel + auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool doCPrefetch, bool isUnmasked = true) { + for (int i = 0; i < 8; i++) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]); + } + + if (!isDirect) { + if (i != 0) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } else { 
+ vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } + } + } + } else { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[0], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + } + } + + if (unroll_n >= i + 1) { + if (!isTransB) { + switch (i) { + case 0: + prefetcht0( + ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 1: + prefetcht0(ptr[BO1 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 2: + prefetcht0(ptr[BO1 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 3: + prefetcht0(ptr[BO1 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 4: + prefetcht0( + ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 5: + prefetcht0(ptr[BO2 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 6: + prefetcht0(ptr[BO2 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 7: + prefetcht0(ptr[BO2 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + } + } + } + + if (unroll_n >= 2) { + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, + ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[1], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + } + } + } + + if 
(isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + if (i == 7) + sub(LDA4, -unroll_m * 8 * SIZE); + } + fmaloop(unroll_m, unroll_n, i); + + if (i == 1) { + if (doCPrefetch) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 0 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 0 * 16 * SIZE]); + } + } + if (i == 3) { + if (doCPrefetch && unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 1 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 1 * 16 * SIZE]); + } + if (!isTransA) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 0 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 0 * SIZE]); + } + } + if (i == 5) { + if (doCPrefetch) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 2 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 2 * 16 * SIZE]); + } + add(CO2, LDC); + } + if (!isTransA) { + if (unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 1 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 1 * SIZE]); + } + } + } + + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + } // end of for loop + + if (!isTransB) { + sub(BO1, -8 * SIZE); + if (unroll_n >= 4) + sub(BO2, -8 * SIZE); + } + if (!isTransA) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 2 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 2 * SIZE]); + } + lea(AA, ptr[AA + LDA]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * 8 * SIZE); + } + + sub(LL, 1); + }; + + // Main kernel; does prefetching and calls innerkernel + // After calculating results in registers, writes back to C matrix by + // calling update + auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool isUnmasked = true) { + if (!isDirect) { + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + auto step = ver == ver_avx512_core ? 
2 : 4; + lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } + } + } + + Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; + + mov(LL, K); + sar(LL, 3); + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel( + unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); + add(LL, unroll_n); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + mov(LL, K); + and_(LL, 7); + jle(kernel18, T_NEAR); + align(16); + + L(kernel16); + if (isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + for (int i = 0; i < unroll_n; i++) { + if (!isTransB) { + switch (i) { + case 0: + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + break; + case 1: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 2: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) { + vfmadd231ps(regs[i + 8], zmm3, zmm1); + } + if (unroll_m >= 48) { + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + 
vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + sub(LL, 1); + jg(kernel16, T_NEAR); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta cases need to be handled + if (hasBias) { + mov(BIAS1, BIAS); + if (isUnmasked || unroll_m > 16) + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + else + vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]); + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) + vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]); + else + vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]); + } + if (unroll_m >= 48) { + if (isUnmasked) + vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]); + else + vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + bool useScale = i % 2 != 0; + bool useCO1 = i < 2; + if (i == 2) + lea(CO2, ptr[CO1 + LDC * 2]); + if (i == 4 || i == 6) + lea(CO2, ptr[CO2 + LDC * 2]); + if (hasBias) + vaddps(regs[i], VBIAS1, regs[i]); + if (isUnmasked || unroll_m > 16) { + update(regs[i], useCO1, 0, 0, useScale); + } else { + update(regs[i], useCO1, 0, 1, useScale); + } + if (unroll_m >= 32) { + if (hasBias) + vaddps(regs[i + 8], VBIAS2, regs[i + 8]); + if (isUnmasked || unroll_m > 32) { + update(regs[i + 8], useCO1, 16, 0, useScale); + } else { + update(regs[i + 8], useCO1, 16, 2, useScale); + } + } + if (unroll_m >= 48) { + if (hasBias) + vaddps(regs[i + 16], VBIAS3, regs[i + 16]); + if (isUnmasked) { + update(regs[i + 16], useCO1, 32, 0, useScale); + } else { + update(regs[i + 16], useCO1, 32, 3, useScale); + } + } + } + + switch (unroll_n) { + case 1: add(CO1, LDC); break; + case 2: lea(CO1, ptr[CO1 + LDC * 2]); break; + case 3: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 4: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 5: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 6: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 7: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 8: lea(CO1, ptr[CO2 + LDC * 2]); break; + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + 
lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + case 7: + lea(BO1, ptr[BO1 + LDB * 8]); + sub(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 8]); + sub(BO2, LDB); + break; + case 8: + lea(BO1, ptr[BO1 + LDB * 8]); + lea(BO2, ptr[BO2 + LDB * 8]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 48 rows, 8 columns at a time (handling tail + // cases appropriately by doing 32 or 16 rows, and/or with masking, + // and/or fewer columns). + auto subloop = [&](int unroll_m) { + Label l_subloop_20x[8], l_subloop_mask_20x[8]; + Label l_subloop_30x[8], l_subloop_mask_30x[8]; + + Label subloop11, subloop11mask; + Label subloop30, subloop30mask; + Label subloop31, subloop31mask; + Label subloop96; + Label subloop98, subloop98mask; + Label subloop99; + + // Create mask + mov(BO1, rcx); + mov(rcx, M); + sub(rcx, unroll_m - 16); + mov(CO1, 16); + cmp(rcx, 16); + + cmovg(rcx, CO1); + mov(rax, 1); + sal(rax, cl); + sub(rax, 1); + mov(rcx, 0xffff); + + if (unroll_m == 16) { + kmovw(k1, eax); + } else if (unroll_m == 32) { + kmovw(k1, ecx); + kmovw(k2, eax); + } else { + kmovw(k1, ecx); + kmovw(k2, ecx); + kmovw(k3, eax); + } + mov(rcx, BO1); + + and_(rax, 0xffff); + cmp(rax, 0xffff); + jne(subloop96, T_NEAR); + + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true); + } else { + kernel(unroll_m, UNROLL_N, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + align(16); + + L(subloop11); + kernel(unroll_m, UNROLL_N, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(l_subloop_30x[1], T_NEAR); + align(16); + + L(subloop31); + kernel(unroll_m, UNROLL_N, true, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop96); + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98mask, T_NEAR); + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 
- OFFSET) * SIZE]); + L(subloop98mask); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30mask, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30mask, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true, false); + } else { + kernel(unroll_m, UNROLL_N, false, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + align(16); + + L(subloop11mask); + kernel(unroll_m, UNROLL_N, false, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30mask); + cmp(I, UNROLL_N); + jl(l_subloop_mask_30x[1], T_NEAR); + align(16); + + L(subloop31mask); + kernel(unroll_m, UNROLL_N, true, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + imul(rax, rax, 0x30); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + if (isTransA) { + vpbroadcastq(zmm2, LDA); + vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE); + mov(rax, -2); + kmovw(k4, eax); + + for (int i = 0; i < 6; i++) { + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + kshiftlw(k4, k4, 1); + } + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, ver == ver_avx512_core ? 
0x07 : 0x3f); + mov(FLAG, rax); + + for (int i = 8; i < 16; i++) { + for (int j = 0; j < 3; j++) { + vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j)); + } + } + + Label main0, main1, main2, main999; + + cmp(M, 32); + jle(main0, T_NEAR); + align(16); + + L(main1); + subloop(48); + sub(M, UNROLL_M); + cmp(M, 32); + jg(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 16); + jle(main2, T_NEAR); + + subloop(32); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 0); + jle(main999, T_NEAR); + subloop(16); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032, BN, BK; + if (mayiuse(avx512_core)) { + BN = isTransA ? 384 : 64; + BK = 384; + } else { + BN = isTransA ? 96 : 64; + BK = isTransB ? 
96 : 192; + if (!isTransA && !isTransB) + BK = 128; + } + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx512_common_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 
1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx512_common( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 48 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while 
(ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp new file mode 100644 index 000000000000..d581b7fd7164 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP +#define JIT_AVX512_COMMON_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp new file mode 100644 index 000000000000..60d4220837af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp @@ -0,0 +1,2705 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#if _WIN32 +#define STACK_K_CAPACITY 128 +#else +#define STACK_K_CAPACITY 8192 +#endif +#define SIZE 4 +#define OFFSET 32 +#define BASE_SHIFT 2 +#define SECOND_FETCH 14 + +namespace avx_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + const bool is_avx2 = mayiuse(avx2); + assert(IMPLICATION(!is_avx2, mayiuse(avx))); + + const int UNROLL_M = is_avx2 ? 
16 : 8; + const int UNROLL_N = 6; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto MASK = dword[rsp + 88]; + auto STRIDE = qword[rsp + 120]; + auto ORIG_SP = qword[rsp + 152]; + + auto VALPHA = ymm1; + auto VBETA = ymm2; + auto VMASK = ymm3; + auto VBIAS1 = ymm2; + auto VBIAS2 = ymm4; + + auto PREFETCHSIZEA = 128; + auto PREFETCHSIZEB = (!isTransB) ? -16 : 0; + + // Function for packing if needed + auto do_pack = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + Label pack2, pack3, pack4, pack10; + + int regIdx; + Reg64 reg; + + mov(BO1, A); + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + + if (isTransA) { + lea(BO2, ptr[BO1 + LDA * 4]); + lea(CO1, ptr[LDA + LDA * 2]); + vmovupd(ymm7, STRIDE); + } + + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + regIdx = (i % 2 == 0) ? 4 : 6; + if (isLoad1Unmasked) { + vmovups(Ymm(regIdx), + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx), VMASK, + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(Ymm(regIdx + 1), + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx + 1), VMASK, + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + Ymm(regIdx)); + if (unroll_m > 8) { + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + Ymm(regIdx + 1)); + } + } + + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? 
BO1 : BO2; + vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + i * 4 - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + i * 4 - OFFSET) + * SIZE], + xmm1); + } + } else if (is_avx2) { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 0 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 0 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO1 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 1 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 1 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? 
BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovups(Xmm(ld_step % 2), ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + for (int off = 0; off < 4; ++off) + pextrd(ptr[dst_addr + unroll_m * off * SIZE], + Xmm(ld_step % 2), off); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + if (i == 0) + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + } + } else { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 2 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 2 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 3 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 3 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } + } + add(BO1, (4 * SIZE)); + } + + add(AO1, unroll_m * 4 * SIZE); + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + if (isLoad1Unmasked) { + vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm5, VMASK, + ptr[BO1 + (1 + 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm4); + if (unroll_m > 8) { + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm5); + } + } else { + if (isLoad1Unmasked) { + for (int i = 0; i 
< 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else if (is_avx2) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovss(xmm1, ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + movss(ptr[dst_addr], xmm1); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE], + xmm1); + + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE], + xmm1); + } + add(BO1, SIZE); + } + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Fused multiply add; may become one or two instructions + auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2, + bool overWrite = false) { + if (useFma) { + if (is_avx2) { + vfmadd231ps(reg2, reg1, reg0); + } else { + assert(UNROLL_M == 8); + auto tent_vreg = overWrite ? 
reg1 : ymm1; + vmulps(tent_vreg, reg1, reg0); + vaddps(reg2, reg2, tent_vreg); + } + } else { + if (!overWrite) { + vmulps(ymm15, reg1, reg0); + vaddps(reg2, reg2, ymm15); + } else { + vmulps(reg1, reg1, reg0); + vaddps(reg2, reg2, reg1); + } + } + }; + + // Inner kernel with k=8 + auto innerkernel8 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 8; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 7) { + sub(LDA4, -unroll_m * 8 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? 
reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3 || i == 4 || i == 5 || i == 6) { + if (unroll_m >= 16) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 7) { + if (!isTransB) { + if (unroll_n >= 4) { + sub(BO2, -8 * SIZE); + } + } + if (!isTransA) { + prefetcht2(ptr[AA]); + lea(AA, ptr[AA + LDA]); + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 8 * SIZE); + } + sub(LL, 1); + + }; + + // Inner kernel with k=4 + auto innerkernel4 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 4; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? 
reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 3) { + sub(LDA4, -unroll_m * 4 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? 
reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3) { + if (!isTransB) { + sub(BO1, -4 * SIZE); + if (unroll_n >= 4) { + sub(BO2, -4 * SIZE); + } + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 4 * SIZE); + } + + }; + + // Inner kernel with k=2 + auto innerkernel2 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + for (int i = 0; i < 2; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? 
reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + } + + }; + + // Inner kernel with k=1 + auto innerkernel1 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) { + + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg00); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg06); + } + + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg01); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg07); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg02); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg08); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg03); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg09); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); 
+ } + fma(useFma, ymm0, ymm2, reg04); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg10); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg05); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg11); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + }; + + // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as + // appropriate + // After calculating results in registers, writes back to C matrix + auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma, + Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6), + Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9), + Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12), + Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15), + Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6), + Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9), + Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12), + Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) { + if (!isDirect) { + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDB * 2]); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } + } + } + + for (int i = 4; i < 10; i++) { + vxorps(Ymm(i), Ymm(i), Ymm(i)); + vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6)); + } + + mov(LL, K); + sar(LL, 3); + + Label kernel12, kernel13, kernel14, kernel15; + Label kernel16, kernel17, kernel18; + + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 2) + prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]); + if 
(unroll_n >= 3) + prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 4) + prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 5) + prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 6) + prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]); + + add(LL, SECOND_FETCH); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + test(K, 4); + jle(kernel16, T_NEAR); + innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + + L(kernel16); + test(K, 2); + jle(kernel17, T_NEAR); + innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + align(16); + + L(kernel17); + if (unroll_m == 16) { + if (unroll_n <= 3) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg06, reg06, reg18); + vaddps(reg07, reg07, reg19); + vaddps(reg08, reg08, reg20); + } + } + + if (unroll_m <= 8) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg03, reg03, reg15); + vaddps(reg04, reg04, reg16); + vaddps(reg05, reg05, reg17); + } + + test(K, 1); + jle(kernel18, T_NEAR); + innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta and bias cases need to be + // handled + switch (unroll_n) { + case 1: mov(rax, LDC); break; + case 2: lea(rax, ptr[LDC * 2]); break; + case 3: lea(rax, ptr[LDC + LDC * 2]); break; + case 4: lea(rax, ptr[LDC + LDC * 4]); break; + case 5: + lea(rax, ptr[LDC * 4]); + add(rax, LDC); + break; + case 6: + lea(rax, ptr[LDC + LDC * 2]); + add(rax, rax); + break; + } + + if (hasBias) { + mov(BIAS1, BIAS); + if (isLoad1Unmasked) { + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + } else { + vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA); + if (!isBeta0) { + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break; + case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break; + case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]); + break; + case 1: + vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]); + break; + case 2: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]); 
+ break; + case 4: + vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]); + break; + case 5: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } + + if (!isBetaN) { + vaddps(Ymm(i + 4), ymm0, Ymm(i + 4)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 4), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4)); + } + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break; + case 1: + vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break; + case 4: + vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 1: + vmaskmovps( + ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + case 3: + vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 4: + vmaskmovps( + ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + } + } + + if (unroll_m >= 16) { + // Re-use ymm4 (VBIAS2) + if (i == 0) { + if (hasBias) { + if (isLoad1Unmasked) { + vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]); + } else { + vmaskmovps( + VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]); + } + } + } + vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA); + if (!isBeta0) { + if (isLoad2Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break; + case 1: + vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break; + case 4: + vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]); + break; + case 1: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmaskmovps(ymm0, VMASK, + ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]); + break; + case 4: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmaskmovps(ymm0, VMASK, + ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(Ymm(i + 10), ymm0, Ymm(i + 10)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 10), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10)); + } + if (isLoad2Unmasked) { + switch (i) { + case 0: + vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10)); + break; + case 1: + vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + case 3: + vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10)); + break; + case 4: + vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 1: + vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 3: + vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 4: + vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK, + 
Ymm(i + 10)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + } + } + } + if (i == 2) + add(CO1, rax); + } + if (unroll_n >= 4) { + add(CO2, rax); + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15)); + }; + + auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, 
isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 16 rows, 6 columns at a time (handling tail + // cases appropriately). + // Masking is used for tail cases where M is not divisible by 8. + auto subloop = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + if (isTransA) { + do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked); + } + + Label subloop11, subloop11mask; + Label subloop20, subloop21, subloop22, subloop23; + Label subloop24, subloop25; + Label subloop30, subloop31, subloop32, subloop33; + Label subloop34, subloop35; + Label subloop98, subloop98mask; + Label subloop99, subloop99mask; + + mov(CO1, C); + lea(CO2, ptr[CO1 + LDC * 2]); + add(CO2, LDC); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, qword[B + LDB3]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(subloop20, T_NEAR); + } + align(16); + + if (!isTransA) { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } + } else { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(subloop20, T_NEAR); + align(16); + + L(subloop11); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + L(subloop20); + cmp(I, 1); + jne(subloop21, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop21); + cmp(I, 2); + jne(subloop22, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop22); + cmp(I, 3); + jne(subloop23, T_NEAR); + if (unroll_m == 16) { + 
kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop23); + cmp(I, 4); + jne(subloop24, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop24); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(subloop25, T_NEAR); + align(16); + + L(subloop31); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + L(subloop25); + cmp(I, 1); + jne(subloop32, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop32); + cmp(I, 2); + jne(subloop33, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop33); + cmp(I, 3); + jne(subloop34, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop34); + cmp(I, 4); + jne(subloop35, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop35); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + align(16); + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + sal(rax, 4); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, 
ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + for (int i = 0; i < 8; i++) { + mov(dword[rsp + 88 + i * 4], i); + } + + if (isTransA && is_avx2) { + movq(xmm0, LDA); + vpbroadcastq(ymm1, xmm0); + vinsertf128(ymm0, ymm0, xmm0, 1); + vpermilpd(ymm0, ymm0, 5); + vpaddq(ymm1, ymm1, ymm1); + vperm2f128(ymm1, ymm1, ymm1, 8); + vpaddq(ymm0, ymm0, ymm1); + vmovups(STRIDE, ymm0); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, 0x1f); + mov(FLAG, rax); + + Label main0, main1, main2, main3, main999; + + cmp(M, UNROLL_M); + jl(main0, T_NEAR); + align(16); + + L(main1); + subloop(UNROLL_M, true, true); + sub(M, UNROLL_M); + cmp(M, UNROLL_M); + jge(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 0); + jle(main999, T_NEAR); + + if (UNROLL_M > 8) { + cmp(M, 8); + jle(main2, T_NEAR); + + sub(M, 8); + vbroadcastss(VMASK, M); + vpcmpgtd(VMASK, VMASK, MASK); + + subloop(16, true, false); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 8); + jne(main3, T_NEAR); + subloop(8, true, true); + jmp(main999, T_NEAR); + } + + align(16); + + L(main3); + vbroadcastss(VMASK, M); + if (is_avx2) { + vpcmpgtd(VMASK, VMASK, MASK); + } else { + auto xmask = Xmm(VMASK.getIdx()); + auto xmm_tmp = xmm4; + + vextractf128(xmm_tmp, VMASK, 1); + vpcmpgtd(xmask, xmask, MASK); + vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4 + vinsertf128(VMASK, VMASK, xmm_tmp, 1); + } + subloop(8, false, false); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 
1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032; + int BN = isTransA ? 96 : 48; + int BK = isTransB ? 
96 : 256; + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 16 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? 
+ ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], 
ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp new file mode 100644 index 000000000000..aabf520a3c10 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX_GEMM_F32_HPP +#define JIT_AVX_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp new file mode 100644 index 000000000000..5147885a89e4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp @@ -0,0 +1,346 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "gemm_utils_f32.hpp" +#include "ref_gemm_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace gemm_utils; + +namespace { + +template +void copy_A( + bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) { + for (int k = 0; k < K; k++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + ws[i] = isTransA ? 
A[i * lda + k] : A[i + k * lda]; + } + ws += unroll_factor::m; + } +} + +template +void kernel_mxn(int K, const data_t *A, const dim_t lda, + const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc, + const data_t alpha, const data_t beta) { + data_t c[unroll_factor::m * unroll_factor::n] = + { static_cast(0.) }; + for (int k = 0; k < K; k++) { + for (int j = 0; j < unroll_factor::n; j++) { + data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; + c[i + unroll_factor::m * j] += a * b; + } + } + } + for (int j = 0; j < unroll_factor::n; j++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + C[i + j * ldc] = (beta == static_cast(0.)) + ? alpha * c[i + unroll_factor::m * j] + : alpha * c[i + unroll_factor::m * j] + + beta * C[i + j * ldc]; + } + } +} + +template +void block_ker(const int M, const int N, const int K, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + data_t *C, const dim_t ldc, const data_t alpha, const data_t beta, + data_t *ws, bool do_copy) { + int Nu = rnd_dn(N, unroll_factor::n); + int Mu = rnd_dn(M, unroll_factor::m); + for (int i = 0; i < Mu; i += unroll_factor::m) { + for (int j = 0; j < Nu; j += unroll_factor::n) { + const data_t *b = isTransB ? &B[j] : &B[j * ldb]; + const data_t *a = isTransA ? &A[i * lda] : &A[i]; + if (do_copy) { + if (j == 0) { + copy_A(isTransA, K, a, lda, ws); + } + kernel_mxn( + K, ws, unroll_factor::m, b, ldb, + &C[i + j * ldc], ldc, alpha, beta); + } else { + kernel_mxn( + K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); + } + } + } + // tail processing + for (int i = 0; i < M; i++) { + for (int j = Nu; j < N; j++) { + data_t c = beta == static_cast(0.) + ? static_cast(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } + for (int i = Mu; i < M; i++) { + for (int j = 0; j < Nu; j++) { + data_t c = beta == static_cast(0.) + ? static_cast(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } +} + +template +void gemm_ithr(const int M, const int N, const int K, const data_t alpha, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + const data_t beta, data_t *C, const dim_t ldc, bool do_copy, + data_t *ws) { + constexpr int BM = gemm_traits::BM; + constexpr int BN = gemm_traits::BN; + constexpr int BK = gemm_traits::BK; + + const data_t *curA; + const data_t *curB; + data_t *curC; + + if ((M <= 0) || (N <= 0)) + return; + + if ((K <= 0) || (alpha == static_cast(0))) { + dim_t MN = N * M; + if (beta == static_cast(0.)) { + for (dim_t j = 0; j < MN; j++) + C[j] = static_cast(0.); + } else if (beta != static_cast(1.)) { + for (dim_t j = 0; j < MN; j++) + C[j] *= beta; + } + return; + } + + for (int Bk = 0; Bk < K; Bk += BK) { + int kb = nstl::min(K - Bk, BK); + for (int Bm = 0; Bm < M; Bm += BM) { + int mb = nstl::min(M - Bm, BM); + for (int Bn = 0; Bn < N; Bn += BN) { + int nb = nstl::min(N - Bn, BN); + curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; + curB = isTransB ? 
B + Bn + Bk * ldb : B + Bk + Bn * ldb; + curC = C + Bm + Bn * ldc; + if (Bk == 0) { + block_ker(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, beta, ws, do_copy); + } else { + block_ker(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, static_cast(1.0), + ws, do_copy); + } + } + } + } +} + +} + +template +mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, const int *M_, + const int *N_, const int *K_, const data_t *alpha_, const data_t *A, + const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_, + data_t *C, const int *ldc_, const data_t *bias) { + + bool isTransA = (*transa_ == 'T' || *transa_ == 't'); + bool isTransB = (*transb_ == 'T' || *transb_ == 't'); + const int M = *M_, N = *N_, K = *K_; + const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; + const data_t alpha = *alpha_, beta = *beta_; + + int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads(); + int nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + // thread balancing over M, N, K & size of blocking dimensions + calc_nthr_nocopy_avx( + M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + data_t *c_buffers = nullptr; + data_t *ws_buffers = nullptr; + if (nthr_k > 1) { + c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(data_t), PAGE_4K); + if (!c_buffers) { + nthr_k = 1; + KB = K; + } + } + + bool do_copy = (NB / unroll_factor::n > 3); + const int nthr_mn = nthr_m * nthr_n; + const int nthr = nthr_mn * nthr_k; + const size_t ws_elems_per_thr = K * unroll_factor::m; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); + if (do_copy) { + ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K); + if (!ws_buffers) + do_copy = false; + } + + auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N, + int ithr) { + from = NB * (ithr); + to = NB * (ithr + 1); + if (to > N) + to = N; + myN = to - from; + }; + + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_n = ithr_mn / nthr_m; + int ithr_k = ithr / nthr_mn; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + data_t *ws = do_copy + ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) + : nullptr; + + int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, + k_from = 0, k_to = 0, myK = 0; + + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(k_from, k_to, myK, KB, K, ithr_k); + + if (myM > 0 && myN > 0) { + data_t myBeta, *myC; + dim_t ld; + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0f; + ld = MB; + } + const data_t *myA = isTransA + ? &(A[k_from + m_from * lda]) + : &(A[m_from + k_from * lda]); + const data_t *myB = isTransB + ? 
&(B[n_from + k_from * ldb]) + : &(B[k_from + n_from * ldb]); + + if (!isTransA) { + if (!isTransB) { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } else { + if (!isTransB) { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } + } + }); + + if (nthr_k > 1) { + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_k = ithr / nthr_mn; + int ithr_n = ithr_mn / nthr_m; + + int n_from = 0, n_to = 0, myN = 0; + int m_from = 0, m_to = 0, myM = 0; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + + // sum matrices partitioned along K dimension + int offset = 0, block = 0; + gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset, + &block); + for (int ik = 1; ik < nthr_k; ++ik) { + data_t *myC = c_buffers + + MB * ((dim_t)NB * (cbase + ik - 1) + offset); + + gemm_utils::sum_two_matrices(myM, block, myC, MB, + &C[m_from + (n_from + offset) * ldc], ldc); + } + }); + } + + if (bias) { + parallel_nd(N, M, [&](int i, int j) { + C[i*ldc + j] += bias[j]; + }); + } + + free(ws_buffers); + free(c_buffers); + + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const float *alpha_, + const float *A, const int *lda_, const float *B, const int *ldb_, + const float *beta_, float *C, const int *ldc_, const float *bias); + +template mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const double *alpha_, + const double *A, const int *lda_, const double *B, const int *ldb_, + const double *beta_, double *C, const int *ldc_, const double *bias); +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp new file mode 100644 index 000000000000..7c90ba62775a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
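In ref_gemm above, the do_copy path packs an unroll_factor-sized strip of A into a contiguous workspace so that kernel_mxn reads a dense, cache-friendly strip regardless of whether A is transposed. A minimal stand-alone restatement of that packing, assuming a fixed strip height M (an illustrative sketch, not the library's helper):

#include <cstddef>

// Pack a strip of M rows of (possibly transposed) column-major A into ws so
// that element (i, k) of the strip lives at ws[k * M + i], mirroring copy_A().
template <int M>
void pack_a_strip(bool transA, int K, const float *A, std::ptrdiff_t lda, float *ws) {
    for (int k = 0; k < K; ++k)
        for (int i = 0; i < M; ++i)
            ws[(std::size_t)k * M + i] = transA ? A[(std::size_t)i * lda + k]
                                                : A[i + (std::size_t)k * lda];
}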
+*******************************************************************************/ + +#ifndef REF_GEMM_F32_HPP +#define REF_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M, + const int *N, const int *K, const data_t *alpha, const data_t *A, + const int *lda, const data_t *B, const int *ldb, const data_t *beta, + data_t *C, const int *ldc, const data_t *bias); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp new file mode 100644 index 000000000000..3dbe07d74375 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "mkldnn_traits.hpp" +#include "nstl.hpp" + +#include "jit_generator.hpp" + +#include "gemm.hpp" + +#include "f32/jit_avx512_common_gemm_f32.hpp" +#include "f32/jit_avx_gemm_f32.hpp" +#include "f32/ref_gemm_f32.hpp" + +#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp" +#include "s8x8s32/simple_gemm_s8s8s32.hpp" +#include "s8x8s32/ref_gemm_s8x8s32.hpp" + +#include "os_blas.hpp" + +/* USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes use Intel(R) MKL CBLAS + * yes no use jit + * no yes system-dependent CBLAS + * no no use jit + */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t check_gemm_input(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const int *lda, + const int *ldb, const int *ldc, const float *alpha, const float *beta, + const bool with_bias) { + if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta)) + return mkldnn_invalid_arguments; + if (with_bias && *beta != 0) + return mkldnn_unimplemented; + bool consistency = true + && utils::one_of(*transa, 'T', 't', 'N', 'n') + && utils::one_of(*transb, 'T', 't', 'N', 'n') + && *M >= 0 + && *N >= 0 + && *K >= 0; + + if (!consistency) + return mkldnn_invalid_arguments; + bool isTransA = utils::one_of(*transa, 'T', 't'); + bool isTransB = utils::one_of(*transb, 'T', 't'); + int nrowA = isTransA ? *K : *M; + int nrowB = isTransB ? 
*N : *K; + consistency = true + && *lda >= nstl::max(1, nrowA) + && *ldb >= nstl::max(1, nrowB) + && *ldc >= nstl::max(1, *M); + if (!consistency) + return mkldnn_invalid_arguments; + + return mkldnn_success; +} + +mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, + const char *transa, const char *transb, const int *M, const int *N, + const int *K, const int *lda, const int *ldb, const int *ldc, + const float *alpha, const float *beta, const bool with_bias) { + if (offsetc == nullptr) + return mkldnn_invalid_arguments; + if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) + return mkldnn_invalid_arguments; + + return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, + beta, with_bias); +} + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias, const bool force_jit_gemm) { + mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, + lda, ldb, ldc, alpha, beta, bias != nullptr); + if (status != mkldnn_success) + return status; + +#ifdef USE_CBLAS + if (!force_jit_gemm) { + bool trA = *transa == 't' || *transa == 'T'; + bool trB = *transb == 't' || *transb == 'T'; + CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; + cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, + *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); + + if (bias) { + // Add bias if necessary (bias is applied to columns of C) + cblas_int incx = 1, incy = 1; + parallel_nd(*N, [&](int n) { + ptrdiff_t offset = (ptrdiff_t)n * (*ldc); + cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); + }); + } + return mkldnn_success; + } +#endif + + if (mayiuse(avx512_common)) + return jit_avx512_common_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else if (mayiuse(avx)) + return jit_avx_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else + return ref_gemm(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, + M, N, K, LDA, LDB, LDC, alpha, beta, false); + if (status != mkldnn_success) + return status; + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + +#if USE_MKL_IGEMM + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + if (data_traits::data_type == data_type::u8) { + CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; + CBLAS_OFFSET Cblas_offsetc = + OCisR + ? CblasRowOffset + : OCisC + ? 
CblasColOffset + : CblasFixOffset; + cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, + *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo, + *beta, C, *LDC, co); + return mkldnn_success; + } else { + assert(data_traits::data_type == data_type::s8); + // TODO CBLAS implementation of gemm_s8s8s32 goes here. + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if (utils::everyone_is(0, *ao, *bo)) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#else + cpu_isa_t isa = isa_any; + if (mayiuse(avx512_core_vnni)) { + isa = avx512_core_vnni; + } else if (mayiuse(avx512_core)) { + isa = avx512_core; + } + + if (data_traits::data_type == data_type::u8) { + switch (isa) { + case avx512_core: + case avx512_core_vnni: + return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta, + C, LDC, co); + default: + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } else { + assert(data_traits::data_type == data_type::s8); + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni)) + && *ao == 0 && *bo == 0) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#endif +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::cpu; + +mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, + const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha, + const float *A, const int64_t *lda, const float *B, const int64_t *ldb, + const float *beta, float *C, const int64_t *ldc) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32); +} + +mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + return 
gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} + +mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp new file mode 100644 index 000000000000..dc15ff713040 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp @@ -0,0 +1,58 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_HPP +#define GEMM_HPP + +#include "mkldnn_types.h" +#include "os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias = nullptr, bool force_jit_gemm = false); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const b_dt *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co); + +#ifdef USE_CBLAS +#define GEMM_IMPL_STR "gemm:blas" +#else +#define GEMM_IMPL_STR "gemm:jit" +#endif + +#if USE_MKL_IGEMM +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas" +#else +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit" +#endif + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp new file mode 100644 index 000000000000..4d34ede0bd6d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
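The public entry points above only narrow the 64-bit sizes to int and forward to extended_sgemm / gemm_s8x8s32. A hypothetical call computing C = alpha*A*B + beta*C for small column-major matrices, using the pointer-based signature from this patch (sizes and data are illustrative; the header is assumed to be on the include path):

#include <cstdint>
#include <vector>
#include "mkldnn.h"   // assumed to declare mkldnn_sgemm as defined in this file

void sgemm_example() {
    const std::int64_t M = 2, N = 2, K = 3, lda = 2, ldb = 3, ldc = 2;
    const float alpha = 1.0f, beta = 0.0f;
    std::vector<float> A = {1, 2, 3, 4, 5, 6};  // 2x3: columns [1 2], [3 4], [5 6]
    std::vector<float> B = {1, 0, 0, 0, 1, 0};  // 3x2: columns [1 0 0], [0 1 0]
    std::vector<float> C(4, 0.0f);              // 2x2 result
    mkldnn_status_t st = mkldnn_sgemm("N", "N", &M, &N, &K, &alpha, A.data(), &lda,
                                      B.data(), &ldb, &beta, C.data(), &ldc);
    (void)st;
    // On success C holds {1, 2, 3, 4}: the first two columns of A.
}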
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef OS_BLAS_HPP +#define OS_BLAS_HPP + +/** \file + * Common stuff respecting USE_MKL and USE_CBLAS compile flags + * + * USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS + * yes no jit calls OK; assert if cblas is ever called + * no yes system-dependent CBLAS + * no no gemm convolution (or other blas) N/A; create stubs + */ + +#if defined(USE_MKL) + +#include "mkl_version.h" + +#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001) +#define USE_MKL_IGEMM \ + (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628) + +#include "mkl_cblas.h" +#if !defined(USE_CBLAS) +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif + +#else /* defined(USE_MKL) */ + +#define USE_MKL_PACKED_GEMM 0 +#define USE_MKL_IGEMM 0 + +#if defined(_SX) +/* TODO: _SX should also define USE_CBLAS in case the later is available */ +extern "C" { +#include "cblas.h" // CHECK: does SX also have a fortran API sgemm? +} + +#elif defined(USE_CBLAS) +#include "cblas.h" // Maybe a system/cmake cblas works for you? +#else +/* put the stubs to make a code compilable but not workable */ +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif /* defined(_SX) */ + +#endif /* defined(USE_MKL) */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +#if defined(USE_MKL) && defined(USE_CBLAS) +typedef MKL_INT cblas_int; + +#elif defined(USE_CBLAS) +typedef int cblas_int; + +#if defined(_SX) +/* this cblas.h is peculiar... */ +typedef CBLAS_ORDER CBLAS_LAYOUT; +#endif +#endif + +} +} +} + +#endif /* OS_BLAS_HPP */ + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp new file mode 100644 index 000000000000..dde72f4a1794 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp @@ -0,0 +1,206 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef COMMON_H +#define COMMON_H + +#define GEMM_CODE_SIZE (4096L * 32) + +#define AVX512_UNROLL_M 48 +#define AVX512_UNROLL_N 8 +#define AVX512_UNROLL_K 1 +#define AVX512_BM 9984 +#define AVX512_BN 384 +#define AVX512_BK 768 +#define AVX512_BK_VNNI 1536 +#define AVX512_BK_TRADITIONAL 384 +#define AVX512_BLOCKING_SMALL_K 48 +#define AVX512_BN_SMALL_K 24 + + +#define PAGESIZE 4096 + +#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE +#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +enum { + PARTITION_1D_ROW, + PARTITION_1D_COL, + PARTITION_2D_COL_MAJOR, + PARTITION_2D = PARTITION_2D_COL_MAJOR, +}; + +enum { + COPY_NONE, + COPY_A, +}; + +enum { + NO_OFFSET, + FIX_OFFSET, + COL_OFFSET, + ROW_OFFSET, +}; + +// Alias for any dimension related variable. +typedef long long int dim_t; + +typedef struct { + // Interface arguments. + int transa, transb, offsetc; + dim_t m, n, k; + dim_t lda, ldb, ldc; + const int8_t *a; + const uint8_t *b; + int32_t *c; + const float *alpha, *beta; + + int8_t ao, bo; + const int32_t *co; + + // Kernel parameters. + dim_t um, un, uk, bm, bn, bk; + dim_t bn_small_k, bk_traditional, blocking_small_k; + + int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + // Gemv kernels + void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const 
float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + // Gemv parameters + int swap; + +} blas_t; + + +class jit_avx512_core_u8_copy_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern); + + public: + jit_avx512_core_u8_copy_an_kern(); +}; + +class jit_avx512_core_u8_copy_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern); + + public: + jit_avx512_core_u8_copy_at_kern(); +}; + +class jit_avx512_core_u8_copy_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern); + + public: + jit_avx512_core_u8_copy_bn_kern(); +}; + +class jit_avx512_core_u8_copy_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern); + + public: + jit_avx512_core_u8_copy_bt_kern(); +}; + +class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern); + + public: + jit_avx512_core_u8_copy_sum_an_kern(); +}; + +class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern); + + public: + jit_avx512_core_u8_copy_sum_at_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern); + + public: + jit_avx512_core_u8_copy_sum_bn_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern); + + public: + jit_avx512_core_u8_copy_sum_bt_kern(); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp new file mode 100644 index 000000000000..db9dd9ef97a4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg); +int gemv_threading_driver(blas_t *arg); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp new file mode 100644 index 000000000000..e4b8e1cde231 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp @@ -0,0 +1,1409 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "common.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_gemm_s8u8s32.hpp" +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" +#include "gemv.hpp" + +#if defined(_MSC_VER) +#include +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef struct { + int nthrs_m, nthrs_n; + int partition; + int copy_type; +} blas_thread_t; + +static inline void round_to_nearest(int32_t *rounded_val, double fp_val) { + if (fp_val >= 0.) 
{ + fp_val += 0.5; + if (fp_val > INT32_MAX) { + fp_val = INT32_MAX; + } + } else { + fp_val -= 0.5; + if (fp_val < INT32_MIN) { + fp_val = INT32_MIN; + } + } + *rounded_val = (int32_t) fp_val; +} + +static inline void add_results(const dim_t m, const dim_t n, const dim_t k, + const float alpha, const float beta, const int32_t *c_partial_sum, + const dim_t ldcp, int32_t *c_data, const dim_t ldc, + const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao, + const int8_t bo, const int32_t *co, const int offsetc) +{ + for (dim_t j = 0; j < n; ++j) { + for (dim_t i = 0; i < m; ++i) { + int32_t ctemp = c_partial_sum[i + j * ldcp]; + + if (alpha == 1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float += (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else if (alpha == -1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = -ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float -= (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else { + if (beta == 0.0f) { + double c_float = alpha * (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + double c_float = alpha * (double) ctemp + + beta * (double) c_data[i + j * ldc]; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } + + if (offsetc == FIX_OFFSET) { + c_data[i + j * ldc] += co[0]; + } else if (offsetc == ROW_OFFSET) { + c_data[i + j * ldc] += co[j]; + } else if (offsetc == COL_OFFSET) { + c_data[i + j * ldc] += co[i]; + } + } + } +} + +// TODO Find a better place for those functions. +static inline dim_t ld_padd(const dim_t x) +{ + return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t))) + * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t)); +} + +void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, float beta, int32_t *c, + const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum, + const int32_t *co, const int offsetc, const blas_t *arg) +{ + int8_t ao = arg->ao; + int8_t bo = arg->bo; + int32_t co_0 = (offsetc == NO_OFFSET)? 
0 : co[0]; + + // Since m and n are limited by blocking, stack overflow may not happen; + // it's up to 32kB +#if !defined(_MSC_VER) + int32_t col_offset[m]; + int32_t row_offset[n]; +#else + int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m); + int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n); +#endif + + int col_req = 0; + int row_req = 0; + + if ((bo != 0) || (offsetc == COL_OFFSET)) + col_req = 1; + if ((ao != 0) || (offsetc == ROW_OFFSET)) + row_req = 1; + + // It needs one of colum or row offsets, but it doesn't need both + if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) { + if ((col_req == 0) && (row_req == 0)) { + if (m <= n) { + col_req = 1; + } else { + row_req = 1; + } + } + } + + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] = 0; + + if (offsetc == COL_OFFSET) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co[i]; + } + + if (bo != 0) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += bo * a_row_sum[i]; + } + } + + if (row_req) { + for (dim_t i = 0; i < n; i++) + row_offset[i] = 0; + + if (offsetc == ROW_OFFSET) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co[i]; + } + + if (ao != 0) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += ao * b_col_sum[i]; + } + } + + if ((offsetc == FIX_OFFSET) && (co_0 != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co_0; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co_0; + } + } + + if ((ao != 0) && (bo != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += (int32_t) k * ao * bo; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += (int32_t) k * ao * bo; + } + } + + if (col_req == 0) { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } else { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } +} + +static inline void *align(void *ptr, size_t alignment) +{ + return (void *) utils::rnd_up((uintptr_t) ptr, alignment); +} + +static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co, + const blas_t *arg) +{ + dim_t lda = arg->lda; + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + float beta = *arg->beta; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along N dimension. 
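The col_offset / row_offset bookkeeping in igemm_inner_kernel above follows from expanding the offset GEMM definition: with scalar zero points ao and bo, (A + ao)·(B + bo) = A·B + bo·rowsum(A) + ao·colsum(B) + k·ao·bo, which is why the kernel only needs per-row and per-column correction vectors plus one constant term rather than widened copies of A and B. A tiny scalar check of that identity for a single dot product (hypothetical helper, int8 x uint8 as in this kernel):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns sum_k (a[k] + ao) * (b[k] + bo), computed via the compensation terms
// that the driver accumulates into col_offset / row_offset.
std::int32_t offset_dot(const std::vector<std::int8_t> &a,
                        const std::vector<std::uint8_t> &b,
                        std::int8_t ao, std::int8_t bo) {
    std::int32_t raw = 0, a_sum = 0, b_sum = 0;
    for (std::size_t k = 0; k < a.size(); ++k) {
        raw   += std::int32_t(a[k]) * std::int32_t(b[k]);
        a_sum += a[k];
        b_sum += b[k];
    }
    const std::int32_t k = std::int32_t(a.size());
    return raw + bo * a_sum + ao * b_sum + k * std::int32_t(ao) * std::int32_t(bo);
}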
+ dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m_padd); + + dim_t strideAm = (arg->transa == 0)? 1 : lda; + dim_t strideAn = (arg->transa != 0)? 1 : lda; + dim_t strideBm = (arg->transb == 0)? 1 : ldb; + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t a_buf_nelems = m_padd * k_padd; + size_t b_buf_nelems = k_padd * n_padd; + size_t a_row_sum_nelems = m_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K + + b_buf_nelems * sizeof(*b) + PAGE_4K + + a_row_sum_nelems * sizeof(*c) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems, + PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + float beta_saved = beta; + + int a_block_copied = 0; + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first time + if (Bk == 0) + beta = beta_saved; + else + beta = 1.0f; + + // Apply C offset when to the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn; + arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL, + NULL, b_col_sum); + + dim_t sizeUM = 0; + for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { + sizeUM = sizeM - Um; + if (sizeUM > arg->um) + sizeUM = arg->um; + + /* + * Use the whole A buffer only if we have multiple B blocks + * for k-dimension, otherwise we are wasting cache to store + * B and C blocks. + */ + dim_t Um_forA = 0; + if (sizeN < n) + Um_forA = Um; + + const int8_t *a_block = a + (Bm + Um) * strideAm + + Bk * strideAn; + if (!a_block_copied) { + arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL, + bufferA + Um_forA * sizeK, NULL, NULL, + a_row_sum + Um_forA); + } + + int32_t *c_block = c + (Bm + Um) + Bn * ldc; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm + Um; + } + if (need_c_buffer) { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, 0.0f, + bufferC + Um, ldc_buf, a_row_sum + Um_forA, + b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta + // and offsets. 
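add_results (called just below) is where alpha, beta, the zero-point corrections and the C offsets are finally applied whenever the kernel had to accumulate into the int32 scratch buffer with beta = 0. The alpha/beta arithmetic is carried out in double and converted back with a round-half-away-from-zero, saturating conversion to int32; a compact restatement of that conversion (a sketch mirroring round_to_nearest, not a drop-in replacement):

#include <cstdint>
#include <limits>

// Round half away from zero and clamp to the int32 range, as round_to_nearest() does.
std::int32_t to_int32_saturated(double v) {
    v += (v >= 0.0) ? 0.5 : -0.5;
    if (v > std::numeric_limits<std::int32_t>::max()) return std::numeric_limits<std::int32_t>::max();
    if (v < std::numeric_limits<std::int32_t>::min()) return std::numeric_limits<std::int32_t>::min();
    return static_cast<std::int32_t>(v);
}

// A scratch-buffer element is then folded into C roughly as
//   c = to_int32_saturated(alpha * (double)c_scratch + beta * (double)c_old) + co;
// with the co term selected per offsetc (fixed, per-row, or per-column).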
+ add_results(sizeUM, sizeN, sizeK, alpha, beta, + bufferC + Um, ldc_buf, c_block, ldc, + a_row_sum + Um_forA, b_col_sum, ao, bo, + co + co_stride, offsetc); + } else { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, beta, + c_block, ldc, a_row_sum + Um_forA, b_col_sum, + co + co_stride, offsetc, arg); + } + } + a_block_copied = 1; + } + a_block_copied = 0; + } + } + + free(mem); + + return 0; +} + +static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n, + const dim_t k, const int8_t *bufferA, const uint8_t *b, + const float beta, int32_t *c, const int offsetc, const int32_t *co, + const int32_t *a_row_sum, const blas_t *arg) +{ + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m); + + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t b_buf_nelems = k * n_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + // Implement the kernel here. + const uint8_t *b_block = b + Bn * strideBn; + arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL, + b_col_sum); + + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + int32_t *c_block = c + Bn * ldc; + if (need_c_buffer) { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC, + ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta and offsets. + add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block, + ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride, + offsetc); + } else { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block, + ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg); + } + } + + free(mem); + + return 0; + +} + +#define N2D_MAX_AVX512 384 +#define M2D_MIN_AVX512 384 +#define VECLEN 16 +#define NCONS 1 +static inline void set_thread_opts_avx512(int *p_nthrs, + blas_thread_t *thread_info, const blas_t *arg) +{ + int nthrs = *p_nthrs; + dim_t m = arg->m; + dim_t n = arg->n; + + thread_info->nthrs_m = 0; + thread_info->nthrs_n = 0; + thread_info->copy_type = COPY_NONE; // By default don't do parallel copy. 
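+    // (Summary of the heuristics below: prefer a 2D m-by-n decomposition when
+    // m and n are reasonably balanced for the thread count, switch to the
+    // parallel-copy-A path for large or offset-carrying problems, and
+    // otherwise fall back to a 1D split along rows or columns.)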
+ + int condition_2D_bsrc = -1; + if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) { + condition_2D_bsrc = 1; + } else { + condition_2D_bsrc = 0; + } + + int condition_1D_copya = 0; + if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + // If offset is non-zero, we need to keep 1D_copya to reduce update overhead + if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0 + || arg->offsetc != FIX_OFFSET) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + if (condition_2D_bsrc == 1) { + int nthrs_m = 1; + int nthrs_n = nthrs; + + while ((nthrs_n % 2 == 0) && + (n / nthrs > N2D_MAX_AVX512 || + n / nthrs_n <= N2D_MAX_AVX512 / 2) && + (m / nthrs_m >= 2 * M2D_MIN_AVX512) && + (nthrs_m < 4)) { + nthrs_m *= 2; + nthrs_n /= 2; + } + + thread_info->nthrs_m = nthrs_m; + thread_info->nthrs_n = nthrs_n; + thread_info->partition = PARTITION_2D; + + // Reset the total number of threads that will be used. + *p_nthrs = nthrs_m * nthrs_n; + + } else if (condition_1D_copya && mkldnn_thr_syncable()) { + // Use parallel copy A algorithm + thread_info->copy_type = COPY_A; + thread_info->partition = PARTITION_1D_COL; + } else { + if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) { + thread_info->partition = PARTITION_1D_ROW; + } else { + thread_info->partition = PARTITION_1D_COL; + } + } +} +#undef N2D_MAX_AVX512 +#undef M2D_MIN_AVX512 +#undef VECLEN +#undef NCONS + +static inline void partition_1d(const int ithr, const int nthrs, const dim_t n, + dim_t *t_offset, dim_t *t_block) +{ + dim_t band = n / nthrs; + + dim_t tail = n - (nthrs - 1) * band; + if (tail > (band + 1)) + band++; + tail = n - (nthrs - 1) * band; + + if (ithr < (nthrs - 1)) + *t_block = band; + else + *t_block = tail; + + *t_offset = ithr * band; + + if (*t_offset >= n) { + *t_block = 0; + *t_offset = 0; + } else if ((*t_offset + *t_block) > n) { + *t_block = n - *t_offset; + } +} + +static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i, + const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m, + const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp, + dim_t *p_n_band) +{ + dim_t m_disp = 0, n_disp = 0; + dim_t m_band = 0, n_band = 0; + + int mdiv = nthrs_m; + int ndiv = nthrs_n; + + dim_t m_bandt = m / mdiv; /* size per thread */ + dim_t n_bandt = n / ndiv; /* size per thread */ + int firstmgroup = mdiv - 1; + int firstngroup = ndiv - 1; + dim_t firstmval = m_bandt; + dim_t firstnval = n_bandt; + + int mthr_used = mdiv; + if (m - (mdiv - 1) * m_bandt > m_bandt + 1) { + if (m - (mdiv - 1) * m_bandt > mdiv) + ++m_bandt; + + firstmval = m_bandt + 1; + mthr_used = (int) (m / firstmval); + + if (mthr_used * firstmval < m) + ++mthr_used; + + firstmgroup = mthr_used - 1; + } + + int nthr_used = ndiv; + if (n - (ndiv - 1) * n_bandt > n_bandt + 1) { + firstnval = n_bandt + 1; + nthr_used = (int) (n / firstnval); + + if (nthr_used * firstnval < n) + ++nthr_used; + + firstngroup = nthr_used - 1; + } + + *nthrs = mthr_used * nthr_used; + + if (ithr < *nthrs) { + if (ithr_i < firstmgroup) { + m_band = firstmval; + m_disp = ithr_i * firstmval; + } else if (ithr_i <= mthr_used - 2) { + m_band = m_bandt; + m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt; + } else { + m_disp = firstmgroup * firstmval + + (mthr_used - 1 - firstmgroup) * m_bandt; + m_band = nstl::max(0LL, m - m_disp); + } + + if (ithr_j < firstngroup) { + n_band = firstnval; + n_disp = ithr_j * firstnval; + } else if (ithr_j <= nthr_used - 2) { + 
n_band = n_bandt; + n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt; + } else { + n_disp = firstngroup * firstnval + + (nthr_used - 1 - firstngroup) * n_bandt; + n_band = nstl::max(0LL, n - n_disp); + } + m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL); + n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL); + } + + if (ithr < *nthrs) { + *p_m_disp = m_disp; + *p_n_disp = n_disp; + *p_m_band = m_band; + *p_n_band = n_band; + } else { + *p_m_disp = 0; + *p_n_disp = 0; + *p_m_band = 0; + *p_n_band = 0; + } + + return; +} + +static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m, + dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c, + const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg) +{ + dim_t strideAm = (arg->transa == 0)? 1 : arg->lda; + dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb; + int offsetc = arg->offsetc; + + switch (thread_info->partition) { + case PARTITION_1D_ROW: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->m, &offset, &block); + + *m = block; + *n = arg->n; + *k = arg->k; + + // Set matrix A. + *a = arg->a + offset * strideAm; + + // Set matrix B. + *b = arg->b; + + // Set matrix C. + *c = arg->c + offset; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = offset; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_1D_COL: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->n, &offset, &block); + + *m = arg->m; + *n = block; + *k = arg->k; + + // Set matrix A. + *a = arg->a; + + // Set matrix B. + *b = arg->b + offset * strideBn; + + // Set matrix C. + *c = arg->c + offset * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = offset; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_2D_COL_MAJOR: + { + int nthrs_m = thread_info->nthrs_m; + int nthrs_n = thread_info->nthrs_n; + int ithr_i = ithr % nthrs_m; + int ithr_j = ithr / nthrs_m; + + dim_t m_disp = 0; + dim_t m_band = 0; + dim_t n_disp = 0; + dim_t n_band = 0; + + partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n, + arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band); + + *m = m_band; + *n = n_band; + *k = arg->k; + + // Set matrix A. + *a = arg->a + m_disp * strideAm; + + // Set matrix B. + *b = arg->b + n_disp * strideBn; + + // Set matrix C. + *c = arg->c + m_disp + n_disp * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = n_disp; + } else if (offsetc == COL_OFFSET) { + co_stride = m_disp; + } + *co = arg->co + co_stride; + break; + } + } +} + +#define MULTIPLIER 10 +static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m, + const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b, + int32_t *c, const int32_t *co, const blas_t *arg, + char **p_shared_mem) +{ + const dim_t lda = arg->lda; + const dim_t ldb = arg->ldb; + const dim_t strideAm = (arg->transa == 0)? 1 : lda; + const dim_t strideAn = (arg->transa != 0)? 1 : lda; + const dim_t strideBm = (arg->transb == 0)? 1 : ldb; + + // Padding along M dimension. 
+ dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs; + if (m_padd > m) { + m_padd = utils::rnd_up(m, arg->um); + } + + size_t a_buf_nelems = m_padd * k_padd; + + // Allocate shared memory for A and its row sum buffers in master thread. + if (ithr == 0) { // If thread master + size_t a_row_sum_nelems = m_padd; + + size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K) + + a_row_sum_nelems * sizeof(*c) + PAGE_4K; + + *p_shared_mem = (char *) malloc(mem_size, 128); + + } + mkldnn_thr_barrier(); + + char *mem = *p_shared_mem; + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K); + + if (!mem) { + return -1; + } + + int result = 0; // Return status + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first term of partial sum. + float beta = 1.0f; + if (Bk == 0) + beta = *(arg->beta); + + // Apply C offset for the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + if (ithr < nthrs) { + dim_t band = (sizeM + nthrs - 1) / nthrs; + band = utils::rnd_up(band, arg->um); + + dim_t offset = band * ithr; + + // If offset is too large don't use that thread for copying. + if (offset >= sizeM) { + offset = 0; + band = 0; + } + + // Handle the tail of the copy. + if (offset + band > sizeM) { + band = sizeM - offset; + } + + if (band > 0) { + const int8_t *a_block = a + (Bm + offset) * strideAm + + Bk * strideAn; + arg->copyA(&sizeK, &band, a_block, &lda, NULL, + bufferA + offset * sizeK, NULL, NULL, + a_row_sum + offset); + } + } + mkldnn_thr_barrier(); // Wait for finishing parallel copy. + + const uint8_t *b_block = b + Bk * strideBm; + int32_t *c_block = c + Bm; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm; + } + + result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK, + bufferA, b_block, beta, c_block, offsetc, co + co_stride, + a_row_sum, arg); + + mkldnn_thr_barrier(); // Wait for kernel computations to finish. 
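+            // (Both barriers guard the shared bufferA/a_row_sum: no thread may
+            // start computing before the copy is complete, and no thread may
+            // start the next copy before every thread is done computing.)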
+        }
+    }
+
+    // Free memory allocated in master thread
+    if (ithr == 0) {
+        free(mem);
+    }
+
+    return result;
+}
+#undef MULTIPLIER
+
+static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k,
+        double fp_per_cycle, int *nthrs)
+{
+    double omp_overhead_small_core = 3.0e+3;
+    double omp_intercept_big_core = 4.0e+3;
+    double omp_slope_big_core = 5.0e+2;
+
+    double gemm_cycles = 8.0 * m * n * k / fp_per_cycle;
+
+    int i = *nthrs;
+
+    // Use a different model for omp overheads if nthrs is <= 4
+    if (*nthrs <= 4 && omp_overhead_small_core > 0) {
+        double omp_cycles = omp_overhead_small_core;
+        if (gemm_cycles < omp_cycles) {
+            *nthrs = 1;
+            return;
+        } else {
+            while (i > 1) {
+                if (omp_cycles * i < gemm_cycles * (i - 1)) break;
+                --i;
+            }
+        }
+    } else {
+        if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
+            *nthrs = 1;
+            return;
+        }
+
+        // Adaptively decrement in larger steps to converge faster.
+        while (i > 1) {
+            double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
+            if (omp_cycles * i < gemm_cycles * (i - 1))
+                break;
+
+            if (i < 10)
+                i -= 2;
+            else if (i < 30)
+                i -= 4;
+            else
+                i -= 8;
+        }
+    }
+
+    if (i < 1)
+        i = 1;
+
+    *nthrs = i;
+}
+
+#define CACHE_LINE_SIZE 64
+static int gemm_threading_driver(blas_t *arg)
+{
+    if ((arg->m <= 0) || (arg->n <= 0))
+        return mkldnn_success;
+
+    if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) {
+        return mkldnn_success;
+    }
+
+    int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+    get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr);
+
+    if (nthr == 1) {
+        return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b,
+                arg->c, arg->co, arg);
+    }
+
+    int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE,
+            PAGE_4K);
+
+    if (!results) {
+        return -1;
+    }
+
+    for (int i = 0; i < nthr; i++) {
+        results[i * CACHE_LINE_SIZE] = 0; // Initialize to success
+    }
+
+    char *shared_mem = NULL;
+
+    parallel(nthr, [&](const int ithr, const int nthr) {
+        int nthrs = nthr;
+        if (nthrs == 1) {
+            results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a,
+                    arg->b, arg->c, arg->co, arg);
+        } else {
+            blas_thread_t thread_info;
+            set_thread_opts_avx512(&nthrs, &thread_info, arg);
+
+            const int8_t *a = NULL;
+            const uint8_t *b = NULL;
+            int32_t *c = NULL;
+            const int32_t *co = NULL;
+            dim_t m = -1;
+            dim_t n = -1;
+            dim_t k = -1;
+            decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co,
+                    &thread_info, arg);
+
+            if (ithr < nthrs) {
+                switch (thread_info.copy_type) {
+                case COPY_A:
+                    results[ithr * CACHE_LINE_SIZE] =
+                        parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg,
+                                &shared_mem);
+                    break;
+
+                default:
+                case COPY_NONE:
+                    results[ithr * CACHE_LINE_SIZE] =
+                        gemm_kernel_driver(m, n, k, a, b, c, co, arg);
+                    break;
+                }
+            }
+        }
+    });
+
+    int result = 0; // Initialize to success
+    for (int i = 0; i < nthr; i++) {
+        if (results[i * CACHE_LINE_SIZE] != 0) {
+            result = results[i * CACHE_LINE_SIZE];
+            break;
+        }
+    }
+
+    free(results);
+
+    return result;
+}
+#undef CACHE_LINE_SIZE
+
+static jit_avx512_core_u8_copy_an_kern *copy_an;
+static jit_avx512_core_u8_copy_at_kern *copy_at;
+static jit_avx512_core_u8_copy_bn_kern *copy_bn;
+static jit_avx512_core_u8_copy_bt_kern *copy_bt;
+static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an;
+static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at;
+static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn;
+static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b;
+static
jit_avx512_core_gemm_s8u8s32_kern *kernel_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_c; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel; + +static void jit_init(blas_t *arg) +{ + static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, 
+ const int32_t *row_offset); + + static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + if (mayiuse(avx512_core_vnni)) { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK_VNNI; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } else { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } + + static std::once_flag initialized; + std::call_once(initialized, []{ + + copy_an = new jit_avx512_core_u8_copy_an_kern(); + copy_at = new jit_avx512_core_u8_copy_at_kern(); + copy_bn = new jit_avx512_core_u8_copy_bn_kern(); + copy_bt = new jit_avx512_core_u8_copy_bt_kern(); + + copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern(); + copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern(); + copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern(); + copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern(); + + kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false); + kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true); + kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true); + kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false); + kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false); + kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true); + kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true); + kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false); + + gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + + + copyAn = copy_an->getCode(); + + copyAt = copy_at->getCode(); + + copyBn = copy_bn->getCode(); + + copyBt = copy_bt->getCode(); + + copySumAn = copy_sum_an->getCode(); + + copySumAt = copy_sum_at->getCode(); + + copySumBn = copy_sum_bn->getCode(); + + copySumBt = copy_sum_bt->getCode(); + + kern = kernel->getCode(); + + kern_b = kernel_b->getCode(); + + kern_r = kernel_r->getCode(); + + kern_c = kernel_c->getCode(); + + kern_b0 = kernel_b0->getCode(); + + kern_b0_b = kernel_b0_b->getCode(); + + kern_b0_r = kernel_b0_r->getCode(); + + kern_b0_c = kernel_b0_c->getCode(); + + gemv_s8u8s32_kern = + gemv_s8u8s32_kernel -> generate + (mayiuse(avx512_core_vnni)); + gemv_u8s8s32_kern = + gemv_u8s8s32_kernel -> generate + (mayiuse(avx512_core_vnni)); + }); + + if (arg->bo == 0) { // No need to compute A row sum if bo is zero + if (arg->transa == 0) { + arg->copyA = copyAn; + } else { + arg->copyA = copyAt; + } + } else { + if (arg->transa == 0) { + arg->copyA = copySumAn; + } else { + arg->copyA = copySumAt; + } + } + + if (arg->ao == 0) { // No need to compute B column sum if ao is zero + if (arg->transb == 0) { + 
arg->copyB = copyBn; + } else { + arg->copyB = copyBt; + } + } else { + if (arg->transb == 0) { + arg->copyB = copySumBn; + } else { + arg->copyB = copySumBt; + } + } + + arg->kernel = kern; + arg->kernel_b = kern_b; + arg->kernel_r = kern_r; + arg->kernel_c = kern_c; + arg->kernel_b0 = kern_b0; + arg->kernel_b0_b = kern_b0_b; + arg->kernel_b0_r = kern_b0_r; + arg->kernel_b0_c = kern_b0_c; + arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern; + arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern; +} + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) +{ + char transa = *transA; + char transb = *transB; + char offsetc = *offsetC; + + blas_t args; + + // Initialize blas structure + args.m = *m; + args.n = *n; + args.k = *k; + args.alpha = alpha; + args.a = a; + args.lda = *lda; + args.b = b; + args.ldb = *ldb; + args.beta = beta; + args.c = c; + args.ldc = *ldc; + args.transa = (transa == 'N' || transa == 'n') ? 0 : 1; + args.transb = (transb == 'N' || transb == 'n') ? 0 : 1; + args.um = 0; + args.un = 0; + args.bm = 0; + args.bn = 0; + args.bk = 0; + args.copyA = NULL; + args.copyB = NULL; + args.kernel = NULL; + args.kernel_b0 = NULL; + args.ao = *oa; + args.bo = *ob; + args.co = oc; + + if (offsetc == 'F' || offsetc == 'f') { + args.offsetc = FIX_OFFSET; + } else if (offsetc == 'R' || offsetc == 'r') { + args.offsetc = ROW_OFFSET; + } else { // offsetc == 'C' || offsetc == 'c' + args.offsetc = COL_OFFSET; + } + + jit_init(&args); + int result = gemm_threading_driver(&args); + + return (result < 0) ? mkldnn_out_of_memory : mkldnn_success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp new file mode 100644 index 000000000000..b2e2902a1208 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP +#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP + +#include +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); + +} +} +} + +#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp new file mode 100644 index 000000000000..57554a185248 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp @@ -0,0 +1,539 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" + + +#ifdef _WIN32 +static const bool is_windows = 1; +#else +static const bool is_windows = 0; +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + + + + +// Convert between vector register lengths. +static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); } +static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); } + +// Load from or store to C. +void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst, + const Xbyak::Address &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(make_ymm(dst), src); break; + case 4: vmovups(make_xmm(dst), src); break; + case 2: vmovlps(make_xmm(dst), src); break; + case 1: vmovss(make_xmm(dst), src); break; + } +} +void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst, + const Xbyak::Xmm &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(dst, make_ymm(src)); break; + case 4: vmovups(dst, make_xmm(src)); break; + case 2: vmovsd(dst, make_xmm(src)); break; + case 1: vmovss(dst, make_xmm(src)); break; + } +} + +// Perform length-4 dot product accumulations of unsigned and signed bytes +// in parallel. +// Use vpdpbusd if VNNI available, otherwise emulate. +void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst, + const Xmm &src1, const Xmm &src2) +{ + if (vnni) + vpdpbusd(dst, src1, src2); + else { + vpmaddubsw(dp_scratch, src1, src2); + vpmaddwd(dp_scratch, ones, dp_scratch); + vpaddd(dst, dst, dp_scratch); + } +} + +// Inner kernel. 
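+// Each iteration of the generated loop below consumes 16 values of k per C
+// accumulator (four rounds of length-4 byte dot products), keeping the whole
+// unroll_m x unroll_n tile of C in c_regs and interleaving A/B/C prefetches
+// with the multiply-accumulates.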
+void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n, + bool cfetch) +{ + int um_vecs = (unroll_m + 15) >> 4; + Label label_kernel_loop; + + L_aligned(label_kernel_loop); { + for (int h = 0; h < 4; h++) { + for (int j = 0; j < unroll_n; j++) { + const Zmm b = b_regs[j & 1]; + + vpbroadcastd(b, ptr[BO + isize * + (2 * j + 2 * h * unroll_n - offset_b)]); + dot_product(c_regs[0][j], b, a_regs[0]); + + if (j == 1 && !(h & 1)) + prefetch_b(ptr[BO + isize * (prefetch_size_b + + 2 * h * unroll_n - offset_b)]); + else if (j % 3 == 0) + prefetch_a(ptr[AO + isize * (prefetch_size_a + + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]); + + for (int i = 1; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + + if (cfetch && (j == std::min(1, unroll_n - 1))) { + if (h == 3) + lea(CO2, ptr[CO2 + LDC]); + else if (h < um_vecs) + prefetch_c(ptr[CO2 + (16 * h * size)]); + } + + if (h == 3 && j == std::min(3, unroll_n - 1)) + lea(AA, ptr[AA + (32 * isize)]); + } + + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * + (32 * i + 2 * (h + 1) * unroll_m - offset_a)]); + + if (h == 2) + prefetch_x(ptr[AA - (offset_a * isize)]); + } + + add(AO, 8 * isize * unroll_m); + add(BO, 8 * isize * unroll_n); + sub(LoopCount, 1); + jg(label_kernel_loop, T_NEAR); + } +} + +// k remainder loop for kernel. +void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m, + int unroll_n, int unroll_k, int bwidth) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + + for (int h = 0; h < unroll_k; h++) { + for (int j = 0; j < unroll_n; j++) { + Zmm b = b_regs[j & 1]; + auto b_src = ptr[BO + (-isize * offset_b + + bwidth * (j + h * unroll_n))]; + + switch (bwidth) { + case 4: + vpbroadcastd(b, b_src); + break; + case 2: + vpbroadcastw(b, b_src); + break; + case 1: + vpbroadcastb(b, b_src); + break; + } + for (int i = 0; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + } + + if (unroll_k > 1) { + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i + + (h + 1) * 2 * unroll_m - offset_a)]); + } + } + + add(AO, unroll_k * unroll_m * bwidth); + add(BO, unroll_k * unroll_n * bwidth); +} + +// Inner loop. +void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + int stage1 = unroll_n, stage2 = unroll_n; + + Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2; + Label label_k_main_loop_3, label_kernel_loop_3; + Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2; + Label label_k_rem_1, label_update_begin; + + mov(AO, A); + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]); + + mov(LoopCount, K); + sar(LoopCount, 4); + jle(label_k_remainder_loop_begin, T_NEAR); + + // Main k loops, broken into three parts to time C prefetching. 
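+    // (The first part runs without touching C; the later parts pass
+    // cfetch = true so the C tile is prefetched through CO2 ahead of the
+    // final update, with stage1/stage2 iterations reserved for those phases.)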
+ sub(LoopCount, stage1 + stage2); + jle(label_k_main_loop_2, T_NEAR); + + kernel_loop(unroll_m, unroll_n, false); + + L_aligned(label_k_main_loop_2); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage1); + jle(label_k_main_loop_3, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + L_aligned(label_k_main_loop_3); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage2); + jle(label_k_remainder_loop_begin, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + // k remainder handling + L_aligned(label_k_remainder_loop_begin); + mov(LoopCount, K); + test(LoopCount, 8); + je(label_k_rem_4, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 2, 4); + + L_aligned(label_k_rem_4); + mov(LoopCount, K); + test(LoopCount, 4); + je(label_k_rem_2, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 1, 4); + + L_aligned(label_k_rem_2); + mov(LoopCount, K); + test(LoopCount, 2); + je(label_k_rem_1, T_NEAR); + + Zmm zero = zmm6; + Zmm tmp = zmm5; + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 2); + + L_aligned(label_k_rem_1); + mov(LoopCount, K); + test(LoopCount, 1); + je(label_update_begin, T_NEAR); + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]); + vpunpcklbw(tmp, a, zero); + vpunpckhbw(a, a, zero); + vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 1); + + // Add offsets and update C. + L_aligned(label_update_begin); + + if (enable_offset_r) { + // Add row offsets. + mov(rax, coffset_ry); + for (int j = 0; j < unroll_n; j++) { + Zmm row_offset = zmm0; + + vbroadcastss(row_offset, ptr[rax + size * j]); + + for (int i = 0; i < um_vecs; i++) + vpaddd(c_regs[i][j], c_regs[i][j], row_offset); + } + add(coffset_ry, size * unroll_n); + } + + if (enable_offset_c) { + // Add column offsets. + mov(rax, coffset_cy); + for (int i = 0; i < um_vecs; i++) { + Zmm col_offset = zmm0; + + c_load(col_offset, ptr[rax + size * 16 * i], unroll_m); + + for (int j = 0; j < unroll_n; j++) + vpaddd(c_regs[i][j], c_regs[i][j], col_offset); + } + } + + Reg64 LDC3 = rax; + lea(LDC3, ptr[LDC + LDC * 2]); + + // C updates. + int c_off_j = 0; + for (int j = 0; j < unroll_n; j++) { + if (j > 0 && (j & 3) == 0) { + lea(CO1, ptr[CO1 + LDC * 4]); + c_off_j += 4; + } + + int jj = j - c_off_j; + + for (int i = 0; i < um_vecs; i++) { + Zmm c = c_regs[i][j]; + Zmm c_old = zmm0; + decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj; + + auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i]; + + if (beta_zero) + c_store(c_mem, c, unroll_m); + else { + c_load(c_old, c_mem, unroll_m); + vpaddd(c_old, c, c_old); + c_store(c_mem, c_old, unroll_m); + } + + vpxorq(c, c, c); + } + } + + lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]); +} + +// Outer loop. 
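+// outerloop() emits the loop over M for one unroll_x: rows of C are processed
+// in unroll_x-high panels, each panel loops over N in unroll_y chunks via
+// innerloop(), n-remainders fall through progressively smaller unroll_y cases,
+// and m-remainders are handled by the smaller-unroll_x calls in generate().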
+void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y, + Label *&cur_outerloop_label) +{ + Label label_m_loop, label_n_loop, label_n_remainder_loops[6]; + + L(*cur_outerloop_label); + cur_outerloop_label++; + if (unroll_x >= IGEMM_UNROLL_M) { + mov(J, M); + cmp(J, unroll_x); + jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label. + } else { + test(J, unroll_x); + jle(*cur_outerloop_label, T_NEAR); + } + + L_aligned(label_m_loop); { + mov(CO1, C); + add(C, unroll_x * size); + + mov(BO, B); + + mov(AA, K); + imul(AA, AA, unroll_x * isize); + lea(AA, ptr[A + AA + isize * prefetch_size_a]); + + if (enable_offset_c) { + mov(rax, coffset_cx); + mov(coffset_cy, rax); + add(rax, unroll_x * size); + mov(coffset_cx, rax); + } + + if (enable_offset_r) { + mov(rax, coffset_rx); + mov(coffset_ry, rax); + } + + mov(I, N); + cmp(I, unroll_y); + jl(label_n_remainder_loops[0], T_NEAR); + + L_aligned(label_n_loop); { + innerloop(unroll_x, unroll_y); + sub(I, unroll_y); + cmp(I, unroll_y); + jge(label_n_loop, T_NEAR); + } + + align(16); + + int label_idx = 0; + for (int uy = 16; uy > 0; uy >>= 1) { + L(label_n_remainder_loops[label_idx++]); + if (unroll_y > uy) { + test(I, uy); + jle(label_n_remainder_loops[label_idx], T_NEAR); + + innerloop(unroll_x, uy); + align(16); + } + } + L(label_n_remainder_loops[label_idx]); + + mov(A, AO); + if (unroll_x >= IGEMM_UNROLL_M) { + sub(J, unroll_x); + cmp(J, unroll_x); + jge(label_m_loop); + } + } + + align(16); +} + +void jit_avx512_core_gemm_s8u8s32_kern::generate() +{ + // Prologue + preamble(); + sub(rsp, stack_alloc_size); + + if (is_windows) { + mov(A, arg_a); + mov(B, arg_b); + } + + mov(C, arg_c); + mov(LDC, arg_ldc); + + sub(A, -offset_a * isize); + sub(B, -offset_b * isize); + + mov(M, qword[M]); + mov(N, qword[N]); + mov(K, qword[K]); + + lea(LDC, ptr[LDC * size]); + + if (enable_offset_c) { + mov(rax, arg_coffset_c); + mov(coffset_cx, rax); + } + if (enable_offset_r) { + mov(rax, arg_coffset_r); + mov(coffset_rx, rax); + } + + for (int i = 0; i < (max_unroll_m >> 4); i++) { + for (int j = 0; j < max_unroll_n; j++) { + auto &c = c_regs[i][j]; + vpxorq(c, c, c); + } + } + + if (!vnni) { + mov(rax, 1); + movq(make_xmm(ones), rax); + vpbroadcastw(ones, make_xmm(ones)); + } + + Label outerloop_labels[8]; + Label *cur_outerloop_label = &outerloop_labels[0]; + + // Main m loop. + outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label); + + // m remainder loops. + for (int um = 32; um > 0; um >>= 1) + if (IGEMM_UNROLL_M > um) + outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label); + + L(*cur_outerloop_label); + + // Epilogue. + add(rsp, stack_alloc_size); + postamble(); +} + + +jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool + beta_zero_, bool enable_offset_c_, bool enable_offset_r_) : + jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0), + arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0), + coffset_rx(0), coffset_ry(0) +{ + beta_zero = beta_zero_; + enable_offset_c = enable_offset_c_; + enable_offset_r = enable_offset_r_; + vnni = mayiuse(avx512_core_vnni); + + // Assign integer registers + M = is_windows ? rcx : rdi; + N = is_windows ? rdx : rsi; + K = is_windows ? r8 : rdx; + A = is_windows ? rsi : r8; + B = r9; + C = r10; + LDC = r11; + I = r12; + J = r13; + LoopCount = rax; + AO = r14; + BO = r15; + CO1 = rbx; + CO2 = rbp; + AA = is_windows ? 
rdi : rcx; + + // Assign vector registers + dp_scratch = zmm6; + ones = zmm7; + for (int i = 0; i < (max_unroll_m >> 4); i++) + a_regs[i] = Zmm(i); + b_regs[0] = zmm4; + b_regs[1] = zmm5; + + int rn = 0; + for (int i = 0; i < (max_unroll_m >> 4); i++) + for (int j = 0; j < max_unroll_n; j++) + c_regs[i][j] = Zmm(8 + rn++); + + // Assign stack variables. + stack_alloc_size = 32; + auto args_offset = stack_alloc_size + get_size_of_abi_save_regs() + + 8 + (is_windows ? 48 : 0); + + arg_a = ptr[rsp + (args_offset - 16)]; + arg_b = ptr[rsp + (args_offset - 8)]; + arg_c = ptr[rsp + (args_offset + 0)]; + arg_ldc = ptr[rsp + (args_offset + 8)]; + arg_coffset_c = ptr[rsp + (args_offset + 16)]; + arg_coffset_r = ptr[rsp + (args_offset + 24)]; + + coffset_cx = qword[rsp + 0]; + coffset_cy = qword[rsp + 8]; + coffset_rx = qword[rsp + 16]; + coffset_ry = qword[rsp + 24]; + + generate(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp new file mode 100644 index 000000000000..e8efcc1cc87b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef IGEMM_KERNEL_GENERATOR_HPP +#define IGEMM_KERNEL_GENERATOR_HPP + +#include "jit_generator.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator { +public: + jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_, + bool enable_offset_r_); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern); + +protected: + bool beta_zero; + bool enable_offset_c, enable_offset_r; + bool vnni; + + void prefetch_a(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_b(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_c(const Xbyak::Address &src) { + prefetchw(src); + } + void prefetch_x(const Xbyak::Address &src) { + prefetcht0(src); + } + + void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems); + void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems); + + void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1, + const Xbyak::Xmm &src2); + void kernel_loop(int unroll_m, int unroll_n, bool cfetch); + void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth); + void innerloop(int unroll_m, int unroll_n); + void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); + + void generate(); + + +private: + static const int IGEMM_UNROLL_M = 48; + static const int IGEMM_UNROLL_N = 8; + + static const int isize = 2; + static const int size = 4; + + // Prefetch configuration + static const int prefetch_size_a = 32 * 5; + static const int prefetch_size_b = 32 * 4; + + static const int offset_a = 256, offset_b = 256; + static const int max_unroll_m = 48, max_unroll_n = 8; + + // Integer register assignments + Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount; + Xbyak::Reg64 AO, BO, CO1, CO2, AA; + + // Vector register assignments + Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2]; + Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n]; + + // Stack variable assignments + int stack_alloc_size; + Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r; + Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry; + + void L_aligned(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } +}; + +} +} +} + +#endif /* header guard */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp new file mode 100644 index 000000000000..4f0b10dadd44 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ + +#include "gemv.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) { + + blas_t arg_gemv = *arg; + + if ((arg -> offsetc == FIX_OFFSET) && // Fix offset + (arg -> ao == 0) && + (arg -> bo == 0) && + (arg -> co[0] == 0) && + (*(arg -> alpha) == 1.0f) && + ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) { + + if (arg -> n == 1) { + + if (arg -> transa == 1) { // A transpose + arg_gemv.n = arg -> k; + arg_gemv.ldc = 1; + arg_gemv.swap = 0; + if (arg -> transb == 0) { // B non transpose + arg_gemv.ldb = 1; + } + // B transpose arg_gemv.ldb = arg -> ldb + gemv_threading_driver(&arg_gemv); + return 1; + } + } + + if (arg -> m == 1) { + + if (arg -> transb == 0) { // B non transpose + arg_gemv.transa = 1; + arg_gemv.m = arg -> n; + arg_gemv.n = arg -> k; + arg_gemv.a = (int8_t *) arg -> b; + arg_gemv.lda = arg -> ldb; + arg_gemv.b = (uint8_t *) arg -> a; + arg_gemv.swap = 1; + if (arg -> transa == 0) { // A non transpose + arg_gemv.ldb = arg -> lda; + } + else { // A transpose + arg_gemv.ldb = 1; + } + gemv_threading_driver(&arg_gemv); + return 1; + } + } + } + + return 0; +} + + +int gemv_kernel_driver(blas_t *arg) { + + dim_t m = arg -> m; + dim_t n = arg -> n; + uint8_t *a = (uint8_t *) arg -> a; + dim_t lda = arg -> lda; + int8_t *b = (int8_t *) arg -> b; + float beta = *(arg -> beta); + + if (arg -> swap) { + arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c); + } + else { + arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a, + arg -> lda, arg -> b, *(arg -> beta), arg -> c); + } + + return 0; +} + +int gemv_threading_driver(blas_t *arg) { + + dim_t nthr_m, nthr_n = 1; + dim_t MB, NB, UM = 16, UN = 64; + dim_t BLOCKM = 192, BLOCKN = 3072; + int status; + dim_t i; + + dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + uint8_t *new_x = NULL; + int32_t *tmp_y = NULL, *new_y = NULL; + + dim_t m = arg -> m, n = arg -> n; + + blas_t arg_seq = *arg; + float zero = 0.0f; + + nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr); + MB = m / nthr_m; + MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM; + nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1; + nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr); + + while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) { + nthr_n++; + } + + NB = n / nthr_n; + NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN; + nthr_n = (((n / NB) * NB) == n) ? 
n / NB : n / NB + 1; + nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m); + + nthr = nthr_m * nthr_n; + + if (arg -> ldb != 1) { + new_x = (uint8_t *)malloc(n, 64); + if (new_x == NULL) + return 1; + for (i = 0; i < n; i++) { + new_x[i] = (arg -> b)[i * arg -> ldb]; + } + arg_seq.b = new_x; + arg_seq.ldb = 1; + } + else new_x = (uint8_t *) arg -> b; + + if (arg -> ldc != 1) { + new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64); + if (new_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + // GEMV computation + if (nthr == 1) { + + if (arg -> ldc != 1) { + if (*(arg -> beta) != 0.0f) { + for (i = 0; i < m; i++) { + new_y[i] = arg -> c[i * arg -> ldc]; + } + } + } + + status = gemv_kernel_driver(&arg_seq); + + if (arg -> ldc != 1) { + for (i = 0; i < m; i++) { + arg -> c[i * arg -> ldc] = new_y[i]; + } + } + + if (arg -> ldb != 1) { + free(new_x); + } + if (arg -> ldc != 1) { + free(new_y); + } + return status; + } + + if (nthr_n > 1) { + tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE); + if (tmp_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + parallel_nd((int) nthr, [&](const dim_t ithr) { + + dim_t m_from, m_to, myM; + dim_t n_from, n_to, myN; + + dim_t n_id, m_id; + dim_t loc_incy = 1; + int32_t *loc_y; + + blas_t arg_loc = arg_seq; + int j; + + m_id = ithr / nthr_n; + n_id = ithr % nthr_n; + + m_from = MB * m_id; + m_to = MB * (m_id + 1); + if ((m_to > m) || (m_id == nthr_m - 1)) + m_to = m; + + myM = m_to - m_from; + + n_from = NB * n_id; + n_to = NB * (n_id + 1); + if ((n_to > n) || (n_id == nthr_n - 1)) + n_to = n; + + myN = n_to - n_from; + + if (n_id != 0) { + arg_loc.beta = &zero; + loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from; + } + else { + if (arg -> ldc == 1) { + loc_y = arg_seq.c + m_from; + } + else { + // need to copy the block of c in new_y + loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t)); + if (*(arg -> beta) != 0.0f) { + for (j = 0; j < myM; j++) { + loc_y[j] = arg -> c[(m_from + j) * arg -> ldc]; + } + } + } + } + + arg_loc.m = myM; + arg_loc.n = myN; + arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from; + arg_loc.b = arg_seq.b + n_from; + arg_loc.c = loc_y; + arg_loc.ldc = loc_incy; + + gemv_kernel_driver(&arg_loc); + + if ((n_id == 0) && (arg -> ldc != 1)) { + for (j = 0; j < myM; j++) { + arg -> c[(m_from + j) * arg -> ldc] = loc_y[j]; + } + } + + }); + + if (nthr_n > 1) { + parallel_nd((int) nthr_m, [&](const dim_t ithr) { + + dim_t j, j_from, j_to, ii; + int32_t acc; + + j_from = MB * ithr; + j_to = MB * (ithr + 1); + if ((j_to > m) || (ithr == nthr - 1)) + j_to = m; + + for (j = j_from; j < j_to; j++) { + acc = 0; + for (ii = 0; ii < nthr_n - 1; ii++) { + acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j]; + } + (arg -> c)[j * arg -> ldc] += acc; + } + }); + free(tmp_y); + } + + if (arg -> ldb != 1) { + free(new_x); + } + + if (arg -> ldc != 1) { + free(new_y); + } + + return 0; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp new file mode 100644 index 000000000000..c57a8c1d1212 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp @@ -0,0 +1,411 @@ +/******************************************************************************* + * Copyright 2019 
Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" + +#ifdef _WIN32 +#define is_windows 1 +#else +#define is_windows 0 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b, + Xbyak::Zmm a, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, + int use_vnni) { + + if (use_vnni) { + if (swap) + vpdpbusd(acc, a, b); + else + vpdpbusd(acc, b, a); + } + + else { + if (swap) + vpmaddubsw(tmp, a, b); + else + vpmaddubsw(tmp, b, a); + vpmaddwd(tmp, tmp, one); + vpaddd(acc, tmp, acc); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx, + int b_idx, int nreg_acc, + Xbyak::Reg64 A, Xbyak::Reg64 lda, + Xbyak::Reg64 X, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, int use_vnni, + int use_mask, Xbyak::Opmask mask_n) { + + int i; + int nreg_A = nreg_acc / 2 + (nreg_acc % 2); + + // load X + j + if (use_mask) + vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]); + else + vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]); + + xor_(r14, r14); + // load values of A + for (i = 0; i < nreg_A; i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A; i++) { + // vnni (acc, b, a, tmp, one, swap, use_vnni) + vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A, + Xbyak::Zmm B, Xbyak::Zmm C, + Xbyak::Zmm D) { + + vshufi32x4(dest, A, C, 0x44); + vshufi32x4(A, A, C, 0xEE); + vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3 + + vshufi32x4(dest, B, D, 0x44); + vshufi32x4(B, B, D, 0xEE); + vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3 + + vshufi32x4(A, C, D, 0x88); + vshufi32x4(B, C, D, 0xDD); + vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi + +} + +void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y, + int start_a_idx, int start_acc_idx, + Xbyak::Xmm beta, int use_mask, + Xbyak::Opmask mask_m) { + + int l, i, k, j, last_it; + Xbyak::Label store_label; + + l = 0; + for (k = 0; k < nreg_acc; k += 8) { + for (i = 0, j = k; i < 8; i += 4, j += 2) { + if (j < nreg_acc) { + // shuffle per block of 4 registers + shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest + Xbyak::Zmm(start_acc_idx + j), // A = acc0 + 
Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1 + Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4 + Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5 + + // extract low and high from dest and hadd + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0); + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1); + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l + 1), + Xbyak::Ymm(start_a_idx + l + 2)); + } + l++; + } + + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l - 2), + Xbyak::Ymm(start_a_idx + l - 1)); + + l++; + } + + // eventually add with C and store new value + vxorps(Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx)); + vucomiss(beta, Xbyak::Ymm(start_a_idx)); + je(store_label, T_NEAR); + + // beta = 1 + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + // load Y and add + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]); + else + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]); + + vpaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + k / 8)); + } + + // store + aligned_label(store_label); + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m); + else + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l)); + } + +} + +template +T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) { + + Xbyak::Opmask mask_n = k1, mask_m = k2; + Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label; + Xbyak::Label n_tail_label, update_c_label, end_label; + constexpr unsigned int n_labels = (1 << unroll_m) - 1; + Xbyak::Label m_tail_label_case[n_labels]; + Xbyak::Label n_loop_label_case[n_labels]; + Xbyak::Label n_tail_label_case[n_labels]; + Xbyak::Label update_c_label_case[n_labels]; + + int i, ii; + + Xbyak::Zmm one, tmp; + Xbyak::Reg64 n = abi_param2, m = abi_param1; + Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3; + Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4; + Xbyak::Reg64 X = is_windows ? rdi : r8; + Xbyak::Xmm beta = xmm1; + Xbyak::Reg64 Y = is_windows ? 
rsi : r9; + + bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value; + + // Windows: lda, X, beta and Y are read from the stack (see below) + + int zmm_idx = 1; + int nreg_acc = 1 << unroll_m; + int nreg_A = 1 << (unroll_m - 1); + int nreg_A_acc = nreg_acc + nreg_A; + + if (!use_vnni) { + // reserve zmm registers for the vector of ones and a temporary + tmp = Xbyak::Zmm(0); + one = Xbyak::Zmm(zmm_idx + 1); + zmm_idx += 2; // one + tmp + } + else { + beta = xmm0; + } + + preamble(); + + if (is_windows) { + mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]); + mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]); + movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]); + mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]); + } + + if (use_vnni && !is_windows) { + movaps(beta, xmm1); + } + + mov(rax, (1 << unroll_n) - 1); + kmovq(k3, rax); + + and_(rax, n); // rax contains n & ((1 << unroll_n) - 1) + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_n, rbx); + // mask_n set (AVX512 only), can use rax and rbx again + + // set mask_m for update of the C matrix + // load/store on the C matrix use Ymm so tail according to Ymm size + mov(rax, 7); // 8 * 32 = 256 Ymm size + and_(rax, m); // rax contains m & 7 + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_m, rbx); + // mask_m set (AVX512 only), can use rax and rbx again + + // setup register of ones when VNNI instructions not available + if (!use_vnni) { + vmovdqu16(one, ptr[rip + one_label]); + } + + // M loop + // base pointer for A rax contains a + i * lda + // Loop stop when rax >= a + (m & mask_um) * lda = rbx + // loop increment r10 = um * lda + // rbp = Y + i + mov(rax, A); // i = 0 + mov(rbx, m); + and_(rbx, mask_um); + imul(rbx, lda); + add(rbx, A); + mov(r10, lda); + sal(r10, unroll_m); + mov(rbp, Y); + + // N loop + // base pointer for X r11 contains x + j + // Loop stop when r11 >= x + n & mask_un = r12 + // loop increment un + // r13 = rax + j = A + i * lda + j + mov(r12, n); + and_(r12, mask_un); + add(r12, X); + + // M loop + aligned_label(m_loop_label); + cmp(rax, rbx); + jge(m_tail_label, T_NEAR); + + // enter M loop + for(i = 0; i < nreg_acc; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label); + cmp(r11, r12); + jge(n_tail_label, T_NEAR); + + // enter N loop + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // advance the X (r11) and A (r13) pointers by un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label, T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label); + + ktestq(mask_n, k3); + je(update_c_label, T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label); + + update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m); + + // increment rax with um * lda + add(rax, r10); + add(rbp, 1 << (unroll_m + 2)); + jmp(m_loop_label, T_NEAR); + // end M loop + + // M tail + aligned_label(m_tail_label); + + // r10 will contain m_tail = m % (1 << unroll_m) = m & ((1 << unroll_m) - 1) + mov(r10, m); + and_(r10, (1 << unroll_m) - 1); + for (ii = 1; ii < 1 << unroll_m; ii++) { + aligned_label(m_tail_label_case[ii-1]); + cmp(r10, ii); + if (ii == (1 << unroll_m) - 1) + jne(end_label, T_NEAR); + else + jne(m_tail_label_case[ii], T_NEAR); + + // m_tail = ii, use ii accumulators + + for(i = 0; i < ii;
i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label_case[ii - 1]); + cmp(r11, r12); + jge(n_tail_label_case[ii - 1], T_NEAR); + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // advance the X (r11) and A (r13) pointers by un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label_case[ii - 1], T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label_case[ii - 1]); + ktestq(mask_n, k3); + je(update_c_label_case[ii - 1], T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label_case[ii - 1]); + update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m); + + if (ii < ((1 << unroll_m) - 1)) + jmp(end_label, T_NEAR); + } + + aligned_label(end_label); + + postamble(); + + if (!use_vnni) { + aligned_label(one_label); + for (i = 0; i < size_vec_reg/8; i++) + dq(0x0001000100010001); + } + + return (T) getCode(); +} + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int); + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp new file mode 100644 index 000000000000..9ea23a5f56af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp @@ -0,0 +1,64 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ *******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemv_s8u8s32_kern : jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern); + + // assumes unroll_{m,n} are powers of 2 + static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m + const int mask_um = 0xFFFFFFF0; + static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n + const int mask_un = 0xFFFFFFC0; + const int size_vec_reg = 64; // bytes + + void aligned_label(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } + + void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int); + void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64, + Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask); + void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm); + void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask); + +public: + jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {}; + + // m, n, alpha, a, lda, x, beta, y + typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + template <typename T> + T generate(int use_vnni); + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp new file mode 100644 index 000000000000..544cd2ff25cd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp @@ -0,0 +1,819 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License.
+*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l170; +Xbyak::Label l1f0; +Xbyak::Label l20; +Xbyak::Label l224; +Xbyak::Label l234; +Xbyak::Label l240; +Xbyak::Label l254; +Xbyak::Label l32c; +Xbyak::Label l34; +Xbyak::Label l388; +Xbyak::Label l3b0; +Xbyak::Label l3c0; +Xbyak::Label l3cc; +Xbyak::Label l3dc; +Xbyak::Label l454; +Xbyak::Label l48c; +Xbyak::Label l4a8; +Xbyak::Label l4b8; +Xbyak::Label l4c4; +Xbyak::Label l4d8; +Xbyak::Label l570; +Xbyak::Label l5c4; +Xbyak::Label l5f0; +Xbyak::Label l60c; +Xbyak::Label l61c; +Xbyak::Label l628; +Xbyak::Label l638; +Xbyak::Label l6b0; +Xbyak::Label l6f4; +Xbyak::Label l720; +Xbyak::Label l73c; +Xbyak::Label l74c; +Xbyak::Label l758; +Xbyak::Label l76c; +Xbyak::Label l804; +Xbyak::Label l858; +Xbyak::Label l88c; +Xbyak::Label l8a4; +Xbyak::Label l8b2; +Xbyak::Label l8bc; +Xbyak::Label l8cc; +Xbyak::Label l944; +Xbyak::Label l98c; +Xbyak::Label l9b0; +Xbyak::Label l9c8; +Xbyak::Label l9d8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l234, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + mov(I, M); + sar(I, 0x2); + jle(l170, T_NEAR); + align(4); + +L(l34); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + movdqu(xmm0, xword[A1-0x60]); + movdqu(xmm1, xword[A1+LDA*1-0x60]); + movdqu(xmm2, xword[A1+LDA*2-0x60]); + movdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + 
movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x30], xmm2); + sub(B, -192); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(l170); + test(M, 0x2); + jle(l1f0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + movdqu(xmm4, xword[A1-0x70]); + movdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + movdqa(xmm6, xmm0); + punpcklbw(xmm0, xmm3); + punpckhbw(xmm6, xmm3); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm6); + movdqa(xmm6, xmm1); + punpcklbw(xmm1, xmm4); + punpckhbw(xmm6, xmm4); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm6); + movdqa(xmm6, xmm2); + punpcklbw(xmm2, xmm5); + punpckhbw(xmm6, xmm5); + movdqu(xword[B-0x40], xmm2); + movdqu(xword[B-0x30], xmm6); + sub(B, -96); + align(4); + +L(l1f0); + test(M, 0x1); + jle(l224, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l224); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(l234); + cmp(N, 0x20); + jl(l3c0, T_NEAR); + align(4); + +L(l240); + mov(A1, A); + add(A, 0x20); + mov(I, M); + sar(I, 0x2); + jle(l32c, T_NEAR); + align(4); + +L(l254); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l254, T_NEAR); + align(4); + +L(l32c); + test(M, 0x2); + jle(l388, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l388); + test(M, 0x1); + jle(l3b0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3b0); + sub(N, 0x20); + cmp(N, 0x20); + jge(l240, T_NEAR); + align(4); + 
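+ // The blocks from l3c0 onward handle the remaining, narrower column panels in the + // same fashion: panels of 16, 8, 4, 2 and finally 1 column are packed by gathering + // several rows per iteration into xmm registers before storing the packed panel to B.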
+L(l3c0); + cmp(N, 0x10); + jl(l4b8, T_NEAR); + align(4); + +L(l3cc); + mov(A1, A); + add(A, 0x10); + mov(I, M); + sar(I, 0x2); + jle(l454, T_NEAR); + align(4); + +L(l3dc); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l3dc, T_NEAR); + align(4); + +L(l454); + test(M, 0x2); + jle(l48c, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(l48c); + test(M, 0x1); + jle(l4a8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l4a8); + sub(N, 0x10); + cmp(N, 0x10); + jge(l3cc, T_NEAR); + align(4); + +L(l4b8); + cmp(N, 0x8); + jl(l61c, T_NEAR); + align(4); + +L(l4c4); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(l570, T_NEAR); + align(4); + +L(l4d8); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l4d8, T_NEAR); + align(4); + +L(l570); + test(M, 0x4); + jle(l5c4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l5c4); + test(M, 0x2); + jle(l5f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l5f0); + test(M, 0x1); + jle(l60c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l60c); + sub(N, 0x8); + cmp(N, 0x8); + jge(l4c4, T_NEAR); + align(4); + +L(l61c); + cmp(N, 0x4); + jl(l74c, T_NEAR); + align(4); + +L(l628); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l6b0, T_NEAR); + align(4); + +L(l638); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, 
dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l638, T_NEAR); + align(4); + +L(l6b0); + test(M, 0x4); + jle(l6f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l6f4); + test(M, 0x2); + jle(l720, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l720); + test(M, 0x1); + jle(l73c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l73c); + sub(N, 0x4); + cmp(N, 0x4); + jge(l628, T_NEAR); + align(4); + +L(l74c); + cmp(N, 0x2); + jl(l8b2, T_NEAR); + align(4); + +L(l758); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l804, T_NEAR); + align(4); + +L(l76c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l76c, T_NEAR); + align(4); + +L(l804); + test(M, 0x4); + jle(l858, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l858); + test(M, 0x2); + jle(l88c, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l88c); + test(M, 0x1); + jle(l8a4, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l8a4); + sub(N, 0x2); + cmp(N, 0x2); + jge(l758, T_NEAR); + align(4); + +L(l8b2); + cmp(N, 0x1); + jl(l9d8, T_NEAR); + align(4); + +L(l8bc); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l944, T_NEAR); + align(4); + +L(l8cc); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + 
pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l8cc, T_NEAR); + align(4); + +L(l944); + test(M, 0x4); + jle(l98c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l98c); + test(M, 0x2); + jle(l9b0, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l9b0); + test(M, 0x1); + jle(l9c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l9c8); + sub(N, 0x1); + cmp(N, 0x1); + jge(l8bc, T_NEAR); + align(4); + +L(l9d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp new file mode 100644 index 000000000000..1c11fc6cef13 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp @@ -0,0 +1,2209 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1014; +Xbyak::Label l1390; +Xbyak::Label l159c; +Xbyak::Label l173c; +Xbyak::Label l18e4; +Xbyak::Label l1a7c; +Xbyak::Label l1a8c; +Xbyak::Label l1a98; +Xbyak::Label l1ab4; +Xbyak::Label l1c64; +Xbyak::Label l1d74; +Xbyak::Label l1e50; +Xbyak::Label l1f2c; +Xbyak::Label l1ffc; +Xbyak::Label l20; +Xbyak::Label l200c; +Xbyak::Label l2018; +Xbyak::Label l2034; +Xbyak::Label l2110; +Xbyak::Label l21a0; +Xbyak::Label l2210; +Xbyak::Label l2284; +Xbyak::Label l22f0; +Xbyak::Label l2300; +Xbyak::Label l230c; +Xbyak::Label l2324; +Xbyak::Label l2398; +Xbyak::Label l23e8; +Xbyak::Label l242c; +Xbyak::Label l2474; +Xbyak::Label l24b4; +Xbyak::Label l24c4; +Xbyak::Label l24d0; +Xbyak::Label l24e8; +Xbyak::Label l2520; +Xbyak::Label l254c; +Xbyak::Label l2578; +Xbyak::Label l25a8; +Xbyak::Label l25c8; +Xbyak::Label l25d6; +Xbyak::Label l25e0; +Xbyak::Label l25f0; +Xbyak::Label l260c; +Xbyak::Label l262c; +Xbyak::Label l264c; +Xbyak::Label l2668; +Xbyak::Label l2680; +Xbyak::Label l2690; +Xbyak::Label l44; +Xbyak::Label l58c; +Xbyak::Label l8b0; +Xbyak::Label lb14; +Xbyak::Label ld84; +Xbyak::Label lfdc; +Xbyak::Label lfec; +Xbyak::Label lff8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lfec, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l58c, T_NEAR); + align(4); + +L(l44); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0x100], xmm4); + movdqu(xword[B+0x1c0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + 
movdqu(xword[B+0x110], xmm4); + movdqu(xword[B+0x1d0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0x120], xmm4); + movdqu(xword[B+0x1e0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0x130], xmm4); + movdqu(xword[B+0x1f0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movdqu(xword[B+0x140], xmm4); + movdqu(xword[B+0x200], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movdqu(xword[B+0x150], xmm4); + movdqu(xword[B+0x210], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movdqu(xword[B+0x160], xmm4); + movdqu(xword[B+0x220], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + 
movdqu(xword[B+0x170], xmm4); + movdqu(xword[B+0x230], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movdqu(xword[B+0x180], xmm4); + movdqu(xword[B+0x240], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movdqu(xword[B+0x190], xmm4); + movdqu(xword[B+0x250], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movdqu(xword[B+0x1a0], xmm4); + movdqu(xword[B+0x260], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + movdqu(xword[B+0x1b0], xmm4); + movdqu(xword[B+0x270], xmm3); + sub(A1, -16); + sub(B, -768); + dec(I); + jg(l44, T_NEAR); + align(4); + +L(l58c); + test(M, 0x8); + jle(l8b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + 
movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + sub(A1, -8); + sub(B, -384); + align(4); + +L(l8b0); + test(M, 0x4); + jle(lb14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + 
movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x30], xmm0); + sub(A1, -4); + sub(B, -192); + align(4); + +L(lb14); + test(M, 0x2); + jle(ld84, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, 
word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(ld84); + test(M, 0x1); + jle(lfdc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 
0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lfdc); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(lfec); + cmp(N, 0x20); + jl(l1a8c, T_NEAR); + align(4); + +L(lff8); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1390, T_NEAR); + align(4); + +L(l1014); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movdqu(xword[B+0x80], xmm4); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + 
punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x90], xmm4); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movdqu(xword[B+0xa0], xmm4); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movdqu(xword[B+0xb0], xmm4); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0xc0], xmm4); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0xd0], xmm4); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0xe0], xmm4); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + 
punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0xf0], xmm4); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(l1014, T_NEAR); + align(4); + +L(l1390); + test(M, 0x8); + jle(l159c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l159c); + test(M, 0x4); + jle(l173c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); 
+ punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l173c); + test(M, 0x2); + jle(l18e4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, 
word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l18e4); + test(M, 0x1); + jle(l1a7c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l1a7c); + sub(N, 0x20); + cmp(N, 0x20); + jge(lff8, T_NEAR); + align(4); + +L(l1a8c); + cmp(N, 0x10); + jl(l200c, T_NEAR); + align(4); + +L(l1a98); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1c64, T_NEAR); + align(4); + +L(l1ab4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, 
xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movdqu(xword[B], xmm4); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B+0x10], xmm4); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + movdqu(xword[B+0x30], xmm4); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l1ab4, T_NEAR); + align(4); + +L(l1c64); + test(M, 0x8); + jle(l1d74, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l1d74); + 
test(M, 0x4); + jle(l1e50, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l1e50); + test(M, 0x2); + jle(l1f2c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l1f2c); + test(M, 0x1); + jle(l1ffc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l1ffc); + sub(N, 0x10); + cmp(N, 0x10); + jge(l1a98, T_NEAR); + align(4); + +L(l200c); + cmp(N, 0x8); + jl(l2300, T_NEAR); + align(4); + +L(l2018); 
+ mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2110, T_NEAR); + align(4); + +L(l2034); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2034, T_NEAR); + align(4); + +L(l2110); + test(M, 0x8); + jle(l21a0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l21a0); + test(M, 0x4); + jle(l2210, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l2210); + test(M, 0x2); + jle(l2284, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2284); + test(M, 0x1); + jle(l22f0, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, 
byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l22f0); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2018, T_NEAR); + align(4); + +L(l2300); + cmp(N, 0x4); + jl(l24c4, T_NEAR); + align(4); + +L(l230c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2398, T_NEAR); + align(4); + +L(l2324); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l2324, T_NEAR); + align(4); + +L(l2398); + test(M, 0x8); + jle(l23e8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l23e8); + test(M, 0x4); + jle(l242c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l242c); + test(M, 0x2); + jle(l2474, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2474); + test(M, 0x1); + jle(l24b4, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l24b4); + sub(N, 0x4); + cmp(N, 0x4); + jge(l230c, T_NEAR); + align(4); + +L(l24c4); + cmp(N, 0x2); + jl(l25d6, T_NEAR); + align(4); + +L(l24d0); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2520, T_NEAR); + align(4); + +L(l24e8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l24e8, T_NEAR); + align(4); + +L(l2520); + test(M, 0x8); + jle(l254c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + 
+L(l254c); + test(M, 0x4); + jle(l2578, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2578); + test(M, 0x2); + jle(l25a8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l25a8); + test(M, 0x1); + jle(l25c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l25c8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l24d0, T_NEAR); + align(4); + +L(l25d6); + cmp(N, 0x1); + jl(l2690, T_NEAR); + align(4); + +L(l25e0); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l260c, T_NEAR); + align(4); + +L(l25f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l25f0, T_NEAR); + align(4); + +L(l260c); + test(M, 0x8); + jle(l262c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l262c); + test(M, 0x4); + jle(l264c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l264c); + test(M, 0x2); + jle(l2668, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l2668); + test(M, 0x1); + jle(l2680, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l2680); + sub(N, 0x1); + cmp(N, 0x1); + jge(l25e0, T_NEAR); + align(4); + +L(l2690); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp new file mode 100644 index 000000000000..56c36ee14a6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp @@ -0,0 +1,564 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l118; +Xbyak::Label l1a8; +Xbyak::Label l20; +Xbyak::Label l218; +Xbyak::Label l28c; +Xbyak::Label l2f8; +Xbyak::Label l308; +Xbyak::Label l314; +Xbyak::Label l32c; +Xbyak::Label l3a0; +Xbyak::Label l3c; +Xbyak::Label l3f0; +Xbyak::Label l434; +Xbyak::Label l47c; +Xbyak::Label l4bc; +Xbyak::Label l4cc; +Xbyak::Label l4d8; +Xbyak::Label l4f0; +Xbyak::Label l528; +Xbyak::Label l554; +Xbyak::Label l580; +Xbyak::Label l5b0; +Xbyak::Label l5d0; +Xbyak::Label l5de; +Xbyak::Label l5e8; +Xbyak::Label l5f8; +Xbyak::Label l614; +Xbyak::Label l634; +Xbyak::Label l654; +Xbyak::Label l670; +Xbyak::Label l688; +Xbyak::Label l698; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l308, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l118, T_NEAR); + align(4); + +L(l3c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l3c, T_NEAR); + align(4); + +L(l118); + test(M, 0x8); + jle(l1a8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + 
movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l1a8); + test(M, 0x4); + jle(l218, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l218); + test(M, 0x2); + jle(l28c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l28c); + test(M, 0x1); + jle(l2f8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2f8); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l308); + cmp(N, 0x4); + jl(l4cc, T_NEAR); + align(4); + +L(l314); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l3a0, T_NEAR); + align(4); + +L(l32c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l32c, T_NEAR); + align(4); + +L(l3a0); + test(M, 0x8); + jle(l3f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3f0); + test(M, 0x4); + jle(l434, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, 
dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l434); + test(M, 0x2); + jle(l47c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l47c); + test(M, 0x1); + jle(l4bc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4bc); + sub(N, 0x4); + cmp(N, 0x4); + jge(l314, T_NEAR); + align(4); + +L(l4cc); + cmp(N, 0x2); + jl(l5de, T_NEAR); + align(4); + +L(l4d8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l528, T_NEAR); + align(4); + +L(l4f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l4f0, T_NEAR); + align(4); + +L(l528); + test(M, 0x8); + jle(l554, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l554); + test(M, 0x4); + jle(l580, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l580); + test(M, 0x2); + jle(l5b0, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l5b0); + test(M, 0x1); + jle(l5d0, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l5d0); + sub(N, 0x2); + cmp(N, 0x2); + jge(l4d8, T_NEAR); + align(4); + +L(l5de); + cmp(N, 0x1); + jl(l698, T_NEAR); + align(4); + +L(l5e8); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l614, T_NEAR); + align(4); + +L(l5f8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l5f8, T_NEAR); + align(4); + +L(l614); + test(M, 0x8); + jle(l634, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l634); + test(M, 0x4); + jle(l654, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l654); + test(M, 0x2); + jle(l670, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l670); + test(M, 0x1); + jle(l688, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l688); + sub(N, 0x1); + cmp(N, 0x1); + jge(l5e8, T_NEAR); + align(4); + +L(l698); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B 
+#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp new file mode 100644 index 000000000000..53e99d94de18 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp @@ -0,0 +1,501 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l120; +Xbyak::Label l14c; +Xbyak::Label l168; +Xbyak::Label l178; +Xbyak::Label l184; +Xbyak::Label l194; +Xbyak::Label l20; +Xbyak::Label l20c; +Xbyak::Label l250; +Xbyak::Label l27c; +Xbyak::Label l298; +Xbyak::Label l2a8; +Xbyak::Label l2b4; +Xbyak::Label l2c8; +Xbyak::Label l34; +Xbyak::Label l360; +Xbyak::Label l3b4; +Xbyak::Label l3e8; +Xbyak::Label l400; +Xbyak::Label l40e; +Xbyak::Label l418; +Xbyak::Label l428; +Xbyak::Label l4a0; +Xbyak::Label l4e8; +Xbyak::Label l50c; +Xbyak::Label l524; +Xbyak::Label l534; +Xbyak::Label lcc; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l178, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(lcc, T_NEAR); + align(4); + +L(l34); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + 
jg(l34, T_NEAR); + align(4); + +L(lcc); + test(M, 0x4); + jle(l120, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l120); + test(M, 0x2); + jle(l14c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l14c); + test(M, 0x1); + jle(l168, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l168); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l178); + cmp(N, 0x4); + jl(l2a8, T_NEAR); + align(4); + +L(l184); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l20c, T_NEAR); + align(4); + +L(l194); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l194, T_NEAR); + align(4); + +L(l20c); + test(M, 0x4); + jle(l250, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l250); + test(M, 0x2); + jle(l27c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l27c); + test(M, 0x1); + jle(l298, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l298); + sub(N, 0x4); + cmp(N, 0x4); + jge(l184, T_NEAR); + align(4); + +L(l2a8); + cmp(N, 0x2); + jl(l40e, T_NEAR); + align(4); + +L(l2b4); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l360, T_NEAR); + align(4); + +L(l2c8); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l2c8, T_NEAR); + align(4); + +L(l360); + test(M, 0x4); + 
jle(l3b4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3b4); + test(M, 0x2); + jle(l3e8, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3e8); + test(M, 0x1); + jle(l400, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l400); + sub(N, 0x2); + cmp(N, 0x2); + jge(l2b4, T_NEAR); + align(4); + +L(l40e); + cmp(N, 0x1); + jl(l534, T_NEAR); + align(4); + +L(l418); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l4a0, T_NEAR); + align(4); + +L(l428); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l428, T_NEAR); + align(4); + +L(l4a0); + test(M, 0x4); + jle(l4e8, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4e8); + test(M, 0x2); + jle(l50c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l50c); + test(M, 0x1); + jle(l524, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l524); + sub(N, 0x1); + cmp(N, 0x1); + jge(l418, T_NEAR); + align(4); + +L(l534); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp new file mode 100644 index 000000000000..49a312fc8826 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp @@ -0,0 +1,1283 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1024; +Xbyak::Label l1090; +Xbyak::Label l10d4; +Xbyak::Label l10fc; +Xbyak::Label l111a; +Xbyak::Label l1124; +Xbyak::Label l113c; +Xbyak::Label l11d4; +Xbyak::Label l1234; +Xbyak::Label l1278; +Xbyak::Label l129c; +Xbyak::Label l12bc; +Xbyak::Label l20; +Xbyak::Label l2a0; +Xbyak::Label l3c0; +Xbyak::Label l438; +Xbyak::Label l480; +Xbyak::Label l48c; +Xbyak::Label l4c8; +Xbyak::Label l5c; +Xbyak::Label l6a8; +Xbyak::Label l7b4; +Xbyak::Label l850; +Xbyak::Label l89c; +Xbyak::Label l8a8; +Xbyak::Label l8d0; +Xbyak::Label l9d0; +Xbyak::Label la64; +Xbyak::Label lab8; +Xbyak::Label lae8; +Xbyak::Label laf4; +Xbyak::Label lb14; +Xbyak::Label lc30; +Xbyak::Label lcc8; +Xbyak::Label ld1c; +Xbyak::Label ld54; +Xbyak::Label ld78; +Xbyak::Label ld84; +Xbyak::Label ld9c; +Xbyak::Label le58; +Xbyak::Label lebc; +Xbyak::Label lef8; +Xbyak::Label lf1c; +Xbyak::Label lf3c; +Xbyak::Label lf48; +Xbyak::Label lf60; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l480, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x2); + jle(l2a0, T_NEAR); + align(4); + +L(l5c); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1+LDA*1-0x80]); + vmovdqu(xmm2, xword[A1+LDA*2-0x80]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x80]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, 
ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x60], xmm2); + vmovdqu(xword[B-0x50], xmm3); + vmovdqu(xmm0, xword[A1-0x70]); + vmovdqu(xmm1, xword[A1+LDA*1-0x70]); + vmovdqu(xmm2, xword[A1+LDA*2-0x70]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x70]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x20], xmm2); + vmovdqu(xword[B-0x10], xmm3); + vmovdqu(xmm0, xword[A1-0x60]); + vmovdqu(xmm1, xword[A1+LDA*1-0x60]); + vmovdqu(xmm2, xword[A1+LDA*2-0x60]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B+0x20], xmm2); + vmovdqu(xword[B+0x30], xmm3); + sub(B, -192); + dec(I); + jg(l5c, T_NEAR); + align(4); + +L(l2a0); + test(M, 0x2); + jle(l3c0, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vmovdqu(xmm6, xword[A1-0x80]); + vmovdqu(xmm4, xword[A1-0x70]); + vmovdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + vpunpcklbw(xmm3, xmm0, xmm6); + vpunpckhbw(xmm0, xmm0, xmm6); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm8, ymm8, ymm7); + vmovdqu(xword[B-0x80], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x70], xmm0); + vpunpcklbw(xmm3, xmm1, 
xmm4); + vpunpckhbw(xmm0, xmm1, xmm4); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm10, ymm10, ymm7); + vmovdqu(xword[B-0x60], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x50], xmm0); + vpunpcklbw(xmm3, xmm2, xmm5); + vpunpckhbw(xmm0, xmm2, xmm5); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm12, ymm12, ymm7); + vmovdqu(xword[B-0x40], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x30], xmm0); + sub(B, -96); + align(4); + +L(l3c0); + test(M, 0x1); + jle(l438, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + vpmovsxbd(ymm7, xmm1); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbd(ymm7, xmm2); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l438); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(l480); + cmp(N, 0x20); + jl(l89c, T_NEAR); + align(4); + +L(l48c); + mov(A1, A); + add(A, 0x20); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x2); + jle(l6a8, T_NEAR); + align(4); + +L(l4c8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + 
movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l4c8, T_NEAR); + align(4); + +L(l6a8); + test(M, 0x2); + jle(l7b4, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + pmovsxbw(xmm5, xmm1); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l850, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbd(xmm5, xmm1); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm1, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm1, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm1, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l850); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + 
movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(l48c, T_NEAR); + align(4); + +L(l89c); + cmp(N, 0x10); + jl(lae8, T_NEAR); + align(4); + +L(l8a8); + mov(A1, A); + add(A, 0x10); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x2); + jle(l9d0, T_NEAR); + align(4); + +L(l8d0); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l8d0, T_NEAR); + align(4); + +L(l9d0); + test(M, 0x2); + jle(la64, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + pmovsxbw(xmm5, xmm2); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(la64); + test(M, 0x1); + jle(lab8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lab8); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l8a8, T_NEAR); + align(4); + +L(lae8); + cmp(N, 0x8); + jl(ld78, T_NEAR); + align(4); + +L(laf4); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(lc30, T_NEAR); + align(4); + +L(lb14); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + 
punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(lb14, T_NEAR); + align(4); + +L(lc30); + test(M, 0x4); + jle(lcc8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(lcc8); + test(M, 0x2); + jle(ld1c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(ld1c); + test(M, 0x1); + jle(ld54, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(ld54); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(laf4, T_NEAR); + align(4); + +L(ld78); + cmp(N, 0x4); + jl(lf3c, T_NEAR); + align(4); + +L(ld84); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(le58, T_NEAR); + align(4); + +L(ld9c); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + 
movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(ld9c, T_NEAR); + align(4); + +L(le58); + test(M, 0x4); + jle(lebc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lebc); + test(M, 0x2); + jle(lef8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(lef8); + test(M, 0x1); + jle(lf1c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(ld84, T_NEAR); + align(4); + +L(lf3c); + cmp(N, 0x2); + jl(l111a, T_NEAR); + align(4); + +L(lf48); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l1024, T_NEAR); + align(4); + +L(lf60); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(lf60, T_NEAR); + align(4); + +L(l1024); + test(M, 0x4); + jle(l1090, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l1090); + test(M, 0x2); + jle(l10d4, T_NEAR); + mov(ax, word[A1-0x80]); + 
add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l10d4); + test(M, 0x1); + jle(l10fc, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l10fc); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(lf48, T_NEAR); + align(4); + +L(l111a); + cmp(N, 0x1); + jl(l12bc, T_NEAR); + align(4); + +L(l1124); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l11d4, T_NEAR); + align(4); + +L(l113c); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l113c, T_NEAR); + align(4); + +L(l11d4); + test(M, 0x4); + jle(l1234, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l1234); + test(M, 0x2); + jle(l1278, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l1278); + test(M, 0x1); + jle(l129c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l129c); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l1124, T_NEAR); + align(4); + +L(l12bc); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp new file mode 100644 index 000000000000..a4f4ff09c685 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp @@ -0,0 +1,3163 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* 
Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1750; +Xbyak::Label l1b6c; +Xbyak::Label l1e14; +Xbyak::Label l20; +Xbyak::Label l2068; +Xbyak::Label l226c; +Xbyak::Label l22b8; +Xbyak::Label l22c4; +Xbyak::Label l22f4; +Xbyak::Label l26b4; +Xbyak::Label l28cc; +Xbyak::Label l2a2c; +Xbyak::Label l2b5c; +Xbyak::Label l2c64; +Xbyak::Label l2c94; +Xbyak::Label l2ca0; +Xbyak::Label l2cc8; +Xbyak::Label l2eac; +Xbyak::Label l2fc0; +Xbyak::Label l3078; +Xbyak::Label l3118; +Xbyak::Label l319c; +Xbyak::Label l31c0; +Xbyak::Label l31cc; +Xbyak::Label l31ec; +Xbyak::Label l32e4; +Xbyak::Label l3378; +Xbyak::Label l33dc; +Xbyak::Label l3434; +Xbyak::Label l347c; +Xbyak::Label l349c; +Xbyak::Label l34a8; +Xbyak::Label l34c8; +Xbyak::Label l3558; +Xbyak::Label l35b0; +Xbyak::Label l35f4; +Xbyak::Label l3638; +Xbyak::Label l366c; +Xbyak::Label l368a; +Xbyak::Label l3694; +Xbyak::Label l36a8; +Xbyak::Label l36ec; +Xbyak::Label l3728; +Xbyak::Label l3760; +Xbyak::Label l3794; +Xbyak::Label l37b8; +Xbyak::Label l37d8; +Xbyak::Label l5cc; +Xbyak::Label l6c; +Xbyak::Label l968; +Xbyak::Label lc80; +Xbyak::Label lf1c; +Xbyak::Label lf64; +Xbyak::Label lf70; +Xbyak::Label lfb4; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lf64, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x3); + jle(l5cc, T_NEAR); + align(4); + +L(l6c); + vmovq(xmm0, qword[A1-0x80]); + vmovq(xmm1, qword[A1+LDA*1-0x80]); + vmovq(xmm2, qword[A1+LDA*2-0x80]); + vmovq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + 
vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B+0x40], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x70], xmm2); + vmovdqu(xword[B+0x50], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x60], xmm0); + vmovdqu(xword[B+0x60], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x50], xmm2); + vmovdqu(xword[B+0x70], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B+0x80], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x30], xmm2); + vmovdqu(xword[B+0x90], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + 
vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x20], xmm0); + vmovdqu(xword[B+0xa0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x10], xmm2); + vmovdqu(xword[B+0xb0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0xc0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x10], xmm2); + vmovdqu(xword[B+0xd0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x20], xmm0); + vmovdqu(xword[B+0xe0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x30], xmm2); + vmovdqu(xword[B+0xf0], xmm3); + 
vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -8); + sub(B, -384); + dec(I); + jg(l6c, T_NEAR); + align(4); + +L(l5cc); + test(M, 0x4); + jle(l968, T_NEAR); + vmovd(xmm0, dword[A1-0x80]); + vmovd(xmm1, dword[A1+LDA*1-0x80]); + vmovd(xmm2, dword[A1+LDA*2-0x80]); + vmovd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x80], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x60], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x50], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x40], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, 
dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B+0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -4); + sub(B, -192); + align(4); + +L(l968); + test(M, 0x2); + jle(lc80, T_NEAR); + mov(ax, word[A1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, 
word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, 
ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(lc80); + test(M, 0x1); + jle(lf1c, T_NEAR); + mov(al, byte[A1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, 
byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(lf64); + cmp(N, 0x20); + jl(l22b8, T_NEAR); + align(4); + +L(lf70); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x4); + jle(l1750, T_NEAR); + align(4); + +L(lfb4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x80], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x90], xmm4); + 
pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0xa0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0xb0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); 
+ pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0xc0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0xd0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0xe0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, 
xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0xf0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(lfb4, T_NEAR); + align(4); + +L(l1750); + test(M, 0x8); + jle(l1b6c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, 
xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l1b6c); + test(M, 0x4); + jle(l1e14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, 
xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l1e14); + test(M, 0x2); + jle(l2068, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 
0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l2068); + test(M, 0x1); + jle(l226c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, 
ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l226c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(lf70, T_NEAR); + align(4); + +L(l22b8); + cmp(N, 0x10); + jl(l2c94, T_NEAR); + align(4); + +L(l22c4); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x4); + jle(l26b4, T_NEAR); + align(4); + +L(l22f4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + 
pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l22f4, T_NEAR); + align(4); + +L(l26b4); + test(M, 0x8); + jle(l28cc, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + 
punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l28cc); + test(M, 0x4); + jle(l2a2c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + 
punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l2a2c); + test(M, 0x2); + jle(l2b5c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l2b5c); + test(M, 0x1); + jle(l2c64, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + 
pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2c64); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l22c4, T_NEAR); + align(4); + +L(l2c94); + cmp(N, 0x8); + jl(l31c0, T_NEAR); + align(4); + +L(l2ca0); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l2eac, T_NEAR); + align(4); + +L(l2cc8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2cc8, T_NEAR); + align(4); + +L(l2eac); + test(M, 0x8); + jle(l2fc0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + 
phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l2fc0); + test(M, 0x4); + jle(l3078, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3078); + test(M, 0x2); + jle(l3118, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3118); + test(M, 0x1); + jle(l319c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l319c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2ca0, T_NEAR); + 
align(4); + +L(l31c0); + cmp(N, 0x4); + jl(l349c, T_NEAR); + align(4); + +L(l31cc); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l32e4, T_NEAR); + align(4); + +L(l31ec); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l31ec, T_NEAR); + align(4); + +L(l32e4); + test(M, 0x8); + jle(l3378, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3378); + test(M, 0x4); + jle(l33dc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l33dc); + test(M, 0x2); + jle(l3434, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3434); + test(M, 0x1); + jle(l347c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, 
xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l347c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l31cc, T_NEAR); + align(4); + +L(l349c); + cmp(N, 0x2); + jl(l368a, T_NEAR); + align(4); + +L(l34a8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l3558, T_NEAR); + align(4); + +L(l34c8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l34c8, T_NEAR); + align(4); + +L(l3558); + test(M, 0x8); + jle(l35b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l35b0); + test(M, 0x4); + jle(l35f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l35f4); + test(M, 0x2); + jle(l3638, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3638); + test(M, 0x1); + jle(l366c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l366c); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l34a8, T_NEAR); + align(4); + +L(l368a); + cmp(N, 0x1); + jl(l37d8, T_NEAR); + align(4); + +L(l3694); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l36ec, T_NEAR); + align(4); + +L(l36a8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l36a8, T_NEAR); + align(4); + +L(l36ec); + test(M, 0x8); + jle(l3728, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3728); + 
test(M, 0x4); + jle(l3760, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3760); + test(M, 0x2); + jle(l3794, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l3794); + test(M, 0x1); + jle(l37b8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l37b8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l3694, T_NEAR); + align(4); + +L(l37d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp new file mode 100644 index 000000000000..c7f1393c9d85 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp @@ -0,0 +1,821 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l20; +Xbyak::Label l22c; +Xbyak::Label l340; +Xbyak::Label l3f8; +Xbyak::Label l48; +Xbyak::Label l498; +Xbyak::Label l51c; +Xbyak::Label l540; +Xbyak::Label l54c; +Xbyak::Label l56c; +Xbyak::Label l664; +Xbyak::Label l6f8; +Xbyak::Label l75c; +Xbyak::Label l7b4; +Xbyak::Label l7fc; +Xbyak::Label l81c; +Xbyak::Label l828; +Xbyak::Label l848; +Xbyak::Label l8d8; +Xbyak::Label l930; +Xbyak::Label l974; +Xbyak::Label l9b8; +Xbyak::Label l9ec; +Xbyak::Label la0a; +Xbyak::Label la14; +Xbyak::Label la28; +Xbyak::Label la6c; +Xbyak::Label laa8; +Xbyak::Label lae0; +Xbyak::Label lb14; +Xbyak::Label lb38; +Xbyak::Label lb58; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l540, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l22c, T_NEAR); + align(4); + +L(l48); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + 
punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l48, T_NEAR); + align(4); + +L(l22c); + test(M, 0x8); + jle(l340, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l340); + test(M, 0x4); + jle(l3f8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3f8); + test(M, 0x2); + jle(l498, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + 
mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l498); + test(M, 0x1); + jle(l51c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l51c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l540); + cmp(N, 0x4); + jl(l81c, T_NEAR); + align(4); + +L(l54c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l664, T_NEAR); + align(4); + +L(l56c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l56c, T_NEAR); + align(4); + +L(l664); + test(M, 0x8); + jle(l6f8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, 
xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l6f8); + test(M, 0x4); + jle(l75c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l75c); + test(M, 0x2); + jle(l7b4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l7fc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l7fc); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l54c, T_NEAR); + align(4); + +L(l81c); + cmp(N, 0x2); + jl(la0a, T_NEAR); + align(4); + +L(l828); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l8d8, T_NEAR); + align(4); + +L(l848); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l848, T_NEAR); + align(4); + +L(l8d8); + test(M, 0x8); + jle(l930, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l930); + test(M, 0x4); + jle(l974, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l974); + test(M, 0x2); + jle(l9b8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + 
movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l9b8); + test(M, 0x1); + jle(l9ec, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l9ec); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l828, T_NEAR); + align(4); + +L(la0a); + cmp(N, 0x1); + jl(lb58, T_NEAR); + align(4); + +L(la14); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(la6c, T_NEAR); + align(4); + +L(la28); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(la28, T_NEAR); + align(4); + +L(la6c); + test(M, 0x8); + jle(laa8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(laa8); + test(M, 0x4); + jle(lae0, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lae0); + test(M, 0x2); + jle(lb14, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(lb14); + test(M, 0x1); + jle(lb38, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(lb38); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(la14, T_NEAR); + align(4); + +L(lb58); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp new file mode 100644 index 000000000000..afe4f1713ef4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp @@ -0,0 +1,647 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l15c; +Xbyak::Label l1f4; +Xbyak::Label l20; +Xbyak::Label l248; +Xbyak::Label l280; +Xbyak::Label l2a4; +Xbyak::Label l2b0; +Xbyak::Label l2c8; +Xbyak::Label l384; +Xbyak::Label l3e8; +Xbyak::Label l40; +Xbyak::Label l424; +Xbyak::Label l448; +Xbyak::Label l468; +Xbyak::Label l474; +Xbyak::Label l48c; +Xbyak::Label l550; +Xbyak::Label l5bc; +Xbyak::Label l600; +Xbyak::Label l628; +Xbyak::Label l646; +Xbyak::Label l650; +Xbyak::Label l668; +Xbyak::Label l700; +Xbyak::Label l760; +Xbyak::Label l7a4; +Xbyak::Label l7c8; +Xbyak::Label l7e8; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l2a4, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(l15c, T_NEAR); + align(4); + +L(l40); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l40, T_NEAR); + align(4); + +L(l15c); + test(M, 0x4); + jle(l1f4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); 
+ punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l1f4); + test(M, 0x2); + jle(l248, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l248); + test(M, 0x1); + jle(l280, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l280); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l2a4); + cmp(N, 0x4); + jl(l468, T_NEAR); + align(4); + +L(l2b0); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(l384, T_NEAR); + align(4); + +L(l2c8); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l2c8, T_NEAR); + align(4); + +L(l384); + test(M, 0x4); + jle(l3e8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3e8); + test(M, 0x2); + jle(l424, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l424); + test(M, 0x1); + jle(l448, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + 
align(4); + +L(l448); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l2b0, T_NEAR); + align(4); + +L(l468); + cmp(N, 0x2); + jl(l646, T_NEAR); + align(4); + +L(l474); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l550, T_NEAR); + align(4); + +L(l48c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l48c, T_NEAR); + align(4); + +L(l550); + test(M, 0x4); + jle(l5bc, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l5bc); + test(M, 0x2); + jle(l600, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l600); + test(M, 0x1); + jle(l628, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l628); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l474, T_NEAR); + align(4); + +L(l646); + cmp(N, 0x1); + jl(l7e8, T_NEAR); + align(4); + +L(l650); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l700, T_NEAR); + align(4); + +L(l668); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + 
movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l668, T_NEAR); + align(4); + +L(l700); + test(M, 0x4); + jle(l760, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l760); + test(M, 0x2); + jle(l7a4, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l7a4); + test(M, 0x1); + jle(l7c8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l7c8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l650, T_NEAR); + align(4); + +L(l7e8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp new file mode 100644 index 000000000000..4fc11afcbc05 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp @@ -0,0 +1,116 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "../f32/ref_gemm_f32.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; + size_t sizeA = AisN ? 
lda * k : lda * m; + size_t sizeB = BisN ? ldb * n : ldb * k; + size_t sizeC = ldc * n; + + double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); + double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); + double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); + + if (utils::any_null(dA, dB, dC)) { + free(dA); + free(dB); + free(dC); + return mkldnn_out_of_memory; + } + + auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; }; + auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; }; + + auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; }; + auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; }; + + const int a_rows = AisN ? m : k; + const int a_cols = AisN ? k : m; + mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) { + da_setter(i, j, + static_cast(ia_accessor(i, j)) + static_cast(ao[0])); + }); + + const int b_rows = BisN ? k : n; + const int b_cols = BisN ? n : k; + mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) { + db_setter(i, j, + static_cast(ib_accessor(i, j)) + static_cast(bo[0])); + }); + double one = 1.0, zero = 0.0; + ref_gemm(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, + dC, LDC, nullptr); + + auto i2d = [=] (int32_t v) { return static_cast(v); }; + auto f2d = [=] (float v) { return static_cast(v); }; + + mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) { + double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); + double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) + + f2d(*alpha) * dC[i + j * ldc] + coffset; + C[i + j * ldc] = math::out_round(math::saturate(val)); + }); + + free(dA); + free(dB); + free(dC); + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm_s8x8s32( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +template mkldnn_status_t ref_gemm_s8x8s32( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp new file mode 100644 index 000000000000..6c0370ae99da --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef REF_GEMM_S8X8S32_HPP +#define REF_GEMM_S8X8S32_HPP + +#include + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp new file mode 100644 index 000000000000..de1035f3b20b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" +#include "nstl.hpp" +#include "math_utils.hpp" + +#include "../gemm.hpp" +#include "jit_avx512_core_gemm_s8u8s32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void compensation_init(const char *offsetC, int32_t *compensation, int len, + const int32_t *oc) { + bool OCisC = (*offsetC == 'C' || *offsetC == 'c'); + bool OCisF = (*offsetC == 'F' || *offsetC == 'f'); + + if (OCisF && (*oc) != 0) { + for (int i = 0; i < len; i++) + compensation[i] = *oc; + } else if (OCisC) { + for (int i = 0; i < len; i++) + compensation[i] = oc[i]; + } else { + parallel_nd(len, [=](int i) { compensation[i] = 0; }); + } +} + +void compensation_compute(bool transa, int m, int k, float alpha, + const int8_t *a, int lda, int32_t *compensation) { + if (!transa) { + const int L2_cache_size = get_cache_size(2, true); + const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1); + const int npanels = k / blocking_factor; + const bool has_tile = k % blocking_factor > 0; + + parallel_nd(npanels, m, [&](int j, int i) { + int32_t val = 0; + for (int jb = 0; jb < blocking_factor; jb++) { + val += a[(i + (ptrdiff_t)j * blocking_factor * lda) + + (ptrdiff_t)jb * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + + if (has_tile) { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = npanels * blocking_factor; j < k; j++) { + val += a[i + (ptrdiff_t)j * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + } + } else { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = 0; j < k; j++) { + val += a[j + (ptrdiff_t)i * lda]; + } + if (alpha != 1.0f) { + val 
= math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + compensation[i] += val; + }); + } +} + +void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8, + const int8_t *b_s8, int ldb_s8) { + const int b_cols = transb ? k : n; + + parallel_nd(b_cols, [=](int j) { + const int b_rows = transb ? n : k; + + uint8_t *pb_u8 = b_u8 + j * ldb_u8; + const int8_t *pb_s8 = b_s8 + j * ldb_s8; + + for (int i = 0; i < b_rows; i++) { + (*pb_u8) = (*pb_s8) + 128; + pb_u8++; + pb_s8++; + } + }); +} + +/** + * gemm_s8s8s32 operation is defined as follows: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation + * + * where + * - compensation is a vector of length m that contains computed compensation + * that may contain C_offset if applicable. The compensation is applied inside + * gemm_s8u8s32 as a C_offset + * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128 + * + * What is the compensation: + * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset = + * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset + * compensation = -alpha * op(A) * B_shift + * Since B_shift is a matrix, every element of which is equal to 128 then + * - if op(A) = A: compensation contains sum of the elements in each row + * scaled by -128 * alpha + * - if op(A) = A**T: compensation contains sum of the elements in each column + * scaled by -128 * alpha + * + * The rest of parameters is described in mkldnn.h + */ +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) { + if (*oa != 0 || *ob != 0) return mkldnn_unimplemented; + + int M = *m, N = *n, K = *k; + bool transa = (*transA == 'T' || *transA == 't'); + bool transb = (*transB == 'T' || *transB == 't'); + int ld = transb ? N : K; + + uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64); + int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64); + + if (utils::any_null(b_u8, compensation)) { + free(b_u8); + free(compensation); + return mkldnn_out_of_memory; + } + + compensation_init(offsetC, compensation, M, oc); + compensation_compute(transa, M, K, *alpha, a, *lda, compensation); + copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb); + + gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8, + &ld, ob, beta, c, ldc, compensation); + + if ((*offsetC == 'R' || *offsetC == 'r')) + parallel_nd(M, N, + [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; }); + + free(b_u8); + free(compensation); + + return mkldnn_success; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp new file mode 100644 index 000000000000..03a3d2f7e0e6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
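A quick way to see why the compensation above restores the exact s8s8 product: for alpha = 1 and no offsets, alpha * op(A) * B_shift is just 128 times the row sums of op(A), so subtracting that amount per output row cancels the shift exactly. The following standalone sketch (plain C++ with hypothetical tiny sizes, row-major loops rather than the library's column-major BLAS convention) checks the identity numerically; it is illustrative only and not part of the vendored sources:

// Illustrative check of the B_shift / compensation identity (alpha = 1,
// beta = 0, no a/b offsets). Hypothetical standalone example.
#include <cassert>
#include <cstdint>

int main() {
    const int m = 2, n = 2, k = 3;
    const int8_t A[m][k] = { { -3, 7, 1 }, { 4, -8, 2 } };
    const int8_t B[k][n] = { { 5, -1 }, { -6, 2 }, { 3, 9 } };

    int32_t C_ref[m][n] = {}, C_shifted[m][n] = {};
    int32_t compensation[m] = {};

    // Reference signed product C_ref = A * B.
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            for (int p = 0; p < k; ++p)
                C_ref[i][j] += (int32_t)A[i][p] * (int32_t)B[p][j];

    // compensation[i] = -128 * (sum of row i of A), the alpha == 1 case.
    for (int i = 0; i < m; ++i)
        for (int p = 0; p < k; ++p)
            compensation[i] += -128 * (int32_t)A[i][p];

    // GEMM on the shifted (unsigned) B, then apply the per-row compensation.
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j) {
            for (int p = 0; p < k; ++p)
                C_shifted[i][j] += (int32_t)A[i][p] * ((int32_t)B[p][j] + 128);
            C_shifted[i][j] += compensation[i];
        }

    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            assert(C_ref[i][j] == C_shifted[i][j]);
    return 0;
}

The same cancellation holds for arbitrary alpha, since both the shifted product and the compensation scale linearly in alpha.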
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_GEMM_S8S8S32_HPP +#define SIMPLE_GEMM_S8S8S32_HPP + +#include +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); +} +} +} + +#endif // SIMPLE_GEMM_S8S8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp new file mode 100644 index 000000000000..604a728b4765 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp @@ -0,0 +1,307 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "gemm_convolution.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + const int K = jcp.ic * jcp.ks; + const int N = jcp.oc; + + if (jcp.im2col_sz && jcp.id != 1) + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow; + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, + nb_oh, owb, nb_ow); + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const data_t *_src = src + (n * jcp.ngroups + g) * src_step; + const data_t *_weights = weights + g * weights_g_size; + data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, oh, h_step, ow, w_step); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od); + } + + const data_t one = 1.0; + + const int m = h_step * w_step; + const int LDA = jcp.im2col_sz ? m : M; + data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow; + + extended_sgemm("N", "N", &m, &N, &K, &one, + jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K, + &this->beta_, _dst, &M); + + data_t *d = _dst; + if (eltwise_) { + // fast branch for ReLU case + if (eltwise_->alg_ == alg_kind::eltwise_relu) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_; + } + }); + } else { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? 
bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + d_[oS] = eltwise_->compute_scalar(d_[oS]); + } + }); + } + } else if (jcp.with_bias) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = bias[g * jcp.oc + oc]; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + } + }); + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh, + owb, nb_ow); + } + }); +} + +void gemm_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int m = jcp.os; + const int K = jcp.oc; + const int N = jcp.ic * jcp.ks; + const int LDC = jcp.im2col_sz ? m : M; + + const size_t work_amount = (size_t)jcp.ngroups * jcp.mb; + + if (jcp.id > 1) { + const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step); + parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; }); + } + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{0}, n{0}; + size_t start = 0, end = 0; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); + for (size_t iwork = start; iwork < end; ++iwork) { + + data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step; + const data_t *_weights = weights + g * weights_g_size; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g) + *dst_step + od * m; + + const data_t zero = 0.0, one = 1.0; + extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M, + _weights, &N, &zero, + jcp.im2col_sz ? _col:_diff_src + od * m, &LDC); + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::col2im(jcp, _col, + _diff_src); + else + jit_gemm_convolution_utils::col2im_3d(jcp, _col, + _diff_src, od); + } + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb); + } + }); +} + +void gemm_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + auto wei_reduction = scratchpad(ctx).get(key_conv_wei_reduction); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int K = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * K; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int k = jcp.os; + const int N = jcp.oc; + const int M = jcp.ic * jcp.ks; + const int LDA = jcp.im2col_sz ? 
k : K; + + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + int ithr_g, nthr_g, ithr_mb, nthr_mb; + size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0}; + + const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; + jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups, + mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb); + + assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); + const int need_reduction = nthr_mb != 1; + + if (ithr_g != -1 && ithr_mb != -1) { + balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); + balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); + + assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); + + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + data_t *weights_reduce_base = wei_reduction + + ithr_g * nthr_mb * weights_g_size; + data_t *weights_reduce = weights_reduce_base + + ithr_mb * weights_g_size; + + for (size_t g = g_start; g < g_end; ++g) { + data_t *_diff_weights = need_reduction + ? weights_reduce : (diff_weights + g * weights_g_size); + for (size_t mb = mb_start; mb < mb_end; ++mb) { + const data_t *_src = src + (mb*jcp.ngroups+g)*src_step; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + + (mb*jcp.ngroups+g)*dst_step + od * k; + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, 0, jcp.oh, 0, jcp.ow); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, + _col, od); + } + + const data_t zero = 0.0, one = 1.0; + extended_sgemm( + "T", "N", &M, &N, &k, &one, + jcp.im2col_sz ? _col : _src + od * k, + &LDA, _diff_dst, &K, + mb == mb_start && od == 0 ? &zero : &one, + _diff_weights, &M); + } + } + } + if (need_reduction) { + mkldnn_thr_barrier(); + data_t *weights_base = diff_weights + g_start * weights_g_size; + jit_gemm_convolution_utils::bwd_weights_reduction_par( + ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base); + } + } else + if (need_reduction) { mkldnn_thr_barrier(); } + }); + + if (jcp.with_bias) { + parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) { + data_t db = 0; + size_t offset_ = (size_t)g * dst_step + (size_t)oc * K; + for (int mb = 0; mb < jcp.mb; ++mb) + { + size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step; + for (int od = 0; od < jcp.od; ++od) + for (int oh = 0; oh < jcp.oh; ++oh) + PRAGMA_OMP_SIMD(reduction(+:db)) + for (int ow = 0; ow < jcp.ow; ++ow) { + db += diff_dst[offset]; + offset++; + } + } + diff_bias[g*jcp.oc+oc] = db; + }); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp new file mode 100644 index 000000000000..302e46369a3f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp @@ -0,0 +1,250 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
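The forward pass above lowers convolution to one GEMM per spatial block: the im2col buffer acts as an (ic*kh*kw) x (oh*ow) operand, the weights as an oc x (ic*kh*kw) operand, and extended_sgemm produces the oc x (oh*ow) output slab that the bias/eltwise post-ops then touch in place. A stripped-down sketch of that formulation follows, assuming one image, one group, unit stride, no padding or dilation, and row-major storage; the real code is column-major, blocked over oh/ow, and works out of per-thread col scratch, so this is illustrative only:

// Naive im2col + GEMM convolution, hypothetical standalone example.
#include <cstdio>
#include <vector>

int main() {
    const int ic = 2, ih = 4, iw = 4;        // input
    const int oc = 3, kh = 3, kw = 3;        // weights
    const int oh = ih - kh + 1, ow = iw - kw + 1;
    const int K = ic * kh * kw;              // GEMM inner dimension
    const int M = oh * ow;                   // spatial output size

    std::vector<float> im(ic * ih * iw, 1.f);
    std::vector<float> wei(oc * K, 0.5f);
    std::vector<float> col(K * M);
    std::vector<float> dst(oc * M, 0.f);

    // col[(ic,kh,kw)][(oh,ow)] <- im[ic][oh+kh][ow+kw]
    for (int c = 0; c < ic; ++c)
        for (int r = 0; r < kh; ++r)
            for (int s = 0; s < kw; ++s)
                for (int y = 0; y < oh; ++y)
                    for (int x = 0; x < ow; ++x)
                        col[(((c * kh + r) * kw + s) * oh + y) * ow + x]
                                = im[(c * ih + (y + r)) * iw + (x + s)];

    // dst[oc][M] = wei[oc][K] * col[K][M]
    for (int o = 0; o < oc; ++o)
        for (int p = 0; p < K; ++p)
            for (int s = 0; s < M; ++s)
                dst[o * M + s] += wei[o * K + p] * col[p * M + s];

    std::printf("dst[0][0] = %g\n", dst[0]); // 18 taps * 1.0 * 0.5 = 9
    return 0;
}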
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP +#define CPU_JIT_GEMM_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "gemm_convolution_utils.hpp" +#include "gemm/gemm.hpp" +#include "ref_eltwise.hpp" + +#include "cpu_convolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct gemm_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + + bool post_ops_ok() const { + auto const &po = attr()->post_ops_; + auto is_eltwise = [&](int idx) + { return po.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; + + switch (po.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + return false; + } + }; + + gemm_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , eltwise_(nullptr) + { + const auto &post_ops = pd()->attr()->post_ops_; + const data_t one = 1.0, zero = 0.0; + beta_ = post_ops.find(primitive_kind::sum) >= 0 ? 
one : zero; + + const int entry_idx = post_ops.find(primitive_kind::eltwise); + if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t( + post_ops.entry_[entry_idx].eltwise); + } + + ~gemm_convolution_fwd_t() { delete eltwise_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + data_t beta_; + + ref_eltwise_scalar_fwd_t* eltwise_; +}; + +struct gemm_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? 
utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct gemm_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*diff_weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), diff_weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp new file mode 100644 index 000000000000..f133b1e62b36 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp @@ -0,0 +1,771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "cpu_isa_traits.hpp" + +#include "gemm_convolution_utils.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; +using namespace prop_kind; +using namespace data_type; + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od) +{ + const size_t OHW = jcp.oh * jcp.ow; + const size_t im_step = jcp.ih * jcp.iw * jcp.id; + const size_t col_step = jcp.ks * OHW; + + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict im_loc = im + ic * im_step; + float *__restrict col_loc = col + ic * col_step; + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW; + if (id < 0 || id >= jcp.id) { + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + + col_[col_idx] = 0; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } else { + const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw; + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + const size_t im_idx = ih * jcp.iw + iw; + + col_[col_idx] = im_[im_idx]; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } + id += (1 + jcp.dilate_d); + } + }); +} + +/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb) { + const size_t im_step = jcp.is; + const size_t col_step = jcp.ks * hb * wb; + if (jcp.stride_w == 1) { + // Generated code is more optimized for stride_w == 1 + // because innermost loop is by width + auto ker = [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ + = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else { + const size_t im_idx = ih * jcp.iw + iw; + col_[ow] = im_[im_idx]; + } + } + } + }; + + if (jcp.outer_threading) { + for (int ic = 0; ic < jcp.ic; ic++) + 
for (int kh = 0; kh < jcp.kh; kh++) + for (int kw = 0; kw < jcp.kw; kw++) + for (int oh = 0; oh < hb; oh++) + ker(ic, kh, kw, oh); + } + else { + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker); + } + } else if (jcp.ic == 1) { + parallel_nd(jcp.kh, hb, [&](int kh, int oh) { + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + col[col_idx] = 0; + } + } + else + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col[col_idx] = 0; + else + col[col_idx] = im[im_idx]; + } + } + }); + } else { + + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, + [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ = col + ic * col_step + + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else + col_[ow] = im_[im_idx]; + } + } + }); + } +} + +inline int limit(int low, int upper, int value) { + return nstl::max(low, nstl::min(upper, value)); +} + +/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */ +template +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws, + int wb) { + uint8_t shift = jcp.signed_input ? 
128 : 0; + const int dh = 1 + jcp.dilate_h; + const int dw = 1 + jcp.dilate_w; + const int sh = jcp.stride_h; + const int sw = jcp.stride_w; + const int im_iw_stride = jcp.ic * jcp.ngroups; + const int im_ih_stride = jcp.iw * im_iw_stride; + const int tp = jcp.t_pad; + const int lp = jcp.l_pad; + + if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { + /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ + const int hp = hs - tp; + const int wp = ws - lp; + const int ih_start = limit(0, jcp.ih, hp); + const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh); + const int iw_start = limit(0, jcp.iw, wp); + const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw); + + const int ihb = ih_end - ih_start; + const int iwb = iw_end - iw_start; + + const int imtr_ic_stride = ihb * iwb; + const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start; + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift; + for (int ih = ih_start; ih < ih_end; ih++) { + const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride; + const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb; + for (int iw = iw_start; iw < iw_end; iw++) + imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride]; + } + } + + const int col_ic_str = hb * wb; + const int col_kw_stride = jcp.ic * col_ic_str; + const int col_kh_stride = jcp.kw * col_kw_stride; + + const int oh_init = ih_start - hp; + const int ow_init = iw_start - wp; + for (int kh = 0; kh < jcp.kh; kh++) { + const ptrdiff_t col_idx_kh = kh * col_kh_stride; + const int oh_kh = oh_init - kh; + const int oh_start = limit(0, hb, oh_kh); + const int oh_end = limit(0, hb, oh_kh + ihb); + for (int kw = 0; kw < jcp.kw; kw++) { + const ptrdiff_t col_idx_kw + = col_idx_kh + kw * jcp.ic * col_ic_str; + const int ow_kw = ow_init - kw; + const int imtr_shift = oh_kh * iwb + ow_kw; + const int ow_start = limit(0, wb, ow_kw); + const int ow_end = limit(0, wb, ow_kw + iwb); + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; + const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; + for (int oh = 0; oh < oh_start; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_start; oh < oh_end; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; + for (int ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = shift; + for (int ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow] + shift; + for (int ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_end; oh < hb; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + } + } + } + } else { + parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, + [&](int kh, int kw, int ic, int oh) { + const int hp = tp - kh * dh; + const int ih = (oh + hs) * sh - hp; + const ptrdiff_t col_idx_base + = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; + if (ih < 0 || ih >= jcp.ih) + for (int ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = shift; + else { + const int wp = lp - kw * dw; + const int ow_start = limit(0, wb, div_up(wp, sw) - ws); + const int ow_end + = limit(0, wb, div_up(jcp.iw + wp, sw) - ws); + for (int ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = shift; + const int iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * 
im_ih_stride + ic; + for (int ow = ow_start; ow < ow_end; ow++) { + const int iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx] + shift; + } + for (int ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = shift; + } + }); + } +} + +template void im2col_u8(const jit_gemm_conv_conf_t &jcp, + const int8_t *__restrict im, int8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); +template void im2col_u8(const jit_gemm_conv_conf_t &jcp, + const uint8_t *__restrict im, uint8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); + +/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */ +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im) +{ + parallel(0, [&](const int ithr, const int nthr) { + int h_nthr = nstl::min(jcp.ih, nthr); + int w_nthr = nstl::min(jcp.iw, nthr / h_nthr); + int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0; + if (ithr < h_nthr * w_nthr) { + h_ithr = ithr / w_nthr; + w_ithr = ithr % w_nthr; + balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e); + balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e); + } else { + h_ithr = w_ithr = -ithr; + h_s = h_e = w_s = w_e = -1; + } + + for (int ih = h_s; ih < h_e; ++ih) { + for (int iw = w_s; iw < w_e; ++iw) { + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0; + } + } + } + + // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh] + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h + - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < h_s || ih >= h_e) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w + - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < w_s || iw >= w_e) continue; + + const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh + + kh) * jcp.kw + kw) * jcp.ic; + const size_t im_idx + = (ih * jcp.iw + iw) * jcp.ic; + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[im_idx + ic] += col[col_idx + ic]; + } + } + } + } + } + }); +} + +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od) +{ + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os; + float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id; + + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + if (id < 0 || id >= jcp.id) { + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + continue; + } + + float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw; + + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + }} + }} + + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + } + }); +} + +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) { + const size_t col_step = jcp.ks * jcp.os; + const size_t im_step = jcp.ih * jcp.iw; + const int iS = jcp.ih * 
jcp.iw; + + parallel_nd(jcp.ic, [&](int ic) { + float *__restrict im_ = im + ic * im_step; + const float *__restrict col_ = col + ic * col_step; + PRAGMA_OMP_SIMD() + for (int is = 0; is < iS; ++is) im_[is] = 0.; + + for (int kh = 0; kh < jcp.kh; ++kh) { + for (int oh = 0; oh < jcp.oh; ++oh) { + const int ih = + oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < jcp.ow; ++ow) { + const int iw = + ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + } + } + } + } + }); +} + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads) { + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + const int is_1d = ndims == 3; + const int is_3d = ndims == 5; + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = is_3d ? src_d.dims()[2] : 1; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.od = is_3d ? dst_d.dims()[2] : 1; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = is_3d ? cd.padding[0][0] : 0; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_d = is_3d ? cd.strides[0] : 1; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_d = is_3d ? cd.dilates[0] : 0; + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef + || cd.diff_bias_desc.format_kind != format_kind::undef; + + jcp.is = jcp.ih * jcp.iw; + jcp.os = jcp.oh * jcp.ow; + jcp.ks = jcp.kh * jcp.kw * jcp.kd; + + jcp.signed_input = src_d.data_type() == data_type::s8; + + jcp.im2col_sz = !everyone_is(true, + jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id, + jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1, + jcp.ks == 1, !jcp.signed_input) + ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0; + + jcp.outer_threading = false; + + bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8) + && weights_d.data_type() == s8; + + const int vlen = mayiuse(avx512_common) + ? cpu_isa_traits::vlen + : mayiuse(avx) + ? cpu_isa_traits::vlen + : mayiuse(sse42) ? cpu_isa_traits::vlen : 4; + const int simd_w = vlen / (is_int8_conv ? 1 : 4); + + const bool is_bwd_d = jcp.prop_kind == backward_data; + const bool is_bwd_w = jcp.prop_kind == backward_weights; + const bool is_fwd = !is_bwd_d && !is_bwd_w; + jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; + jcp.ow_block = is_fwd ? 
jcp.ow : jcp.iw; + + using namespace memory_tracking::names; + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + + // TODO: maybe mitigate blocking restriction + const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int L2 = get_cache_size(2, true) + / (is_int8_conv ? sizeof(int8_t) : sizeof(float)); + bool is_blocking_applicable = true + && is_fwd && jcp.im2col_sz + && jcp.id == 1 && jcp.od == 1 + && jcp.dilate_h == 0 && jcp.dilate_w == 0 + && !is_depthwise + && wei_size < L2/2; + if (is_blocking_applicable) { + // looking for oh and ow blocking + int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; + const int ic = jcp.ic; + const int oc = jcp.oc; + const int iw = jcp.iw; + const int ow = jcp.ow; + const int oh = jcp.oh; + const int os = oh * ow; + + // 1. cache requirement + int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); + if (is_int8_conv) { + // Heuristic rule: gemm needed a lot of memory for internal usage + row_size *= 5; + // memory for accumulators + row_size += oc * ow * sizeof(uint32_t); + // memory for transposition + row_size += ic * iw; + } + + h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); + if (h_block == 1) { + int col_size = ic * jcp.ks + 2 * (ic + oc); + if (is_int8_conv) { + col_size *= 5; + col_size += oc * sizeof(uint32_t); + col_size += ic; + } + w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); + } + + // 2. threading requirement + if (h_block != oh) + h_block = nstl::max(1, rnd_dn(h_block, 4)); + if (w_block != ow) + w_block = nstl::max(1, rnd_dn(w_block, simd_w)); + + float thr_eff = 0.f; + float thr_eff_treshold = 0.9f; + if (w_block == ow) { + do { + int nb_h = div_up(oh, h_block); + size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; + float disb = (float)oh / rnd_up(oh, h_block); + thr_eff = (float)work / rnd_up(work, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff >= thr_eff_treshold) + break; + h_block = rnd_dn(h_block - 4, 4); + } while (h_block > 0); + } + if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block + { + h_block = 1; + int nb_h = oh; + do { + int nb_w = div_up(ow, w_block); + size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; + float disb = (float)ow / rnd_up(ow, w_block); + thr_eff = (float)work_amount / rnd_up(work_amount, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff > thr_eff_treshold) + break; + w_block = rnd_dn(w_block - simd_w, simd_w); + } while (w_block > 0); + } + h_block = nstl::max(1, h_block); + w_block = nstl::max(1, w_block); + const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) { + jcp.oh_block = h_block; + jcp.ow_block = w_block; + jcp.outer_threading = true; + } + // updating jcp.im2col_sz + if (jcp.oh_block != 1) + jcp.ow_block = ow; + jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; + } + // For threading selection in bwd_d we do: + // 1. Rough estimation of efficiency for inner and outer threading. + // 2. Gemm size estimation in assumption that it does not work + // so effectively for small sizes. + // 64K - this is heuristic gemm size per thread threshold. 
+ const int gemm_thrld = 64 * 1024; + + if (is_int8_conv) { + if (is_fwd) { + if (!jcp.outer_threading) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); + scratchpad.book(key_conv_gemm_imtr, + sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_d) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_w) { + assert(!"unimplemented prop_kind"); + return status::unimplemented; + } + } else { + if (is_fwd) { + if (!jcp.outer_threading) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work_amount + = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w); + const float inner_thr_eff = (float)inner_work_amount + / rnd_up(inner_work_amount, max_threads); + jcp.outer_threading = jcp.os / max_threads < 512 + && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + } else if (is_bwd_d) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff = (float)inner_work + / rnd_up(inner_work, max_threads); + jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) + && (jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } else if (is_bwd_w) + jcp.outer_threading = jcp.os / max_threads < 256 + && (jcp.mb != 1 || jcp.ngroups > 2); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(float) * jcp.nthr * jcp.im2col_sz); + + if (is_bwd_w) { + jcp.need_wei_reduction = mkldnn_thr_syncable() + ? 
jcp.mb != 1 && jcp.nthr != 1 : false; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size()); + } + } + + return status::success; +} + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g, + int &nthr_g, int &ithr_mb, int &nthr_mb) { + nthr_g = nstl::min(ngroups, nthr); + nthr_mb = nstl::min(mb, nthr / nthr_g); + if (ithr / nthr_mb >= ngroups) { + ithr_g = ithr_mb = -1; + } else { + ithr_g = ithr / nthr_mb; + ithr_mb = ithr % nthr_mb; + } +} + +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights) { + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + size_t weights_start{0}, weights_end{0}; + balance211(weights_g_size, nthr, ithr, weights_start, weights_end); + + for (int i = 0; i < nthr; ++i) { + const float *ws_i = weights_reduce_ws + i * weights_g_size; + for (size_t s = weights_start; s < weights_end; ++s) + weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s]; + } +} + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp new file mode 100644 index 000000000000..e006789344ea --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp @@ -0,0 +1,66 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP +#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_engine.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od); +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb); +template +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T* __restrict imtr, uint8_t *__restrict col, + int hs, int hb, int ws, int wb); + +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im); +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od); +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im); + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads); + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, + int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb); +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights); + +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp new file mode 100644 index 000000000000..2872122f0d9a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp @@ -0,0 +1,156 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" + +#include "gemm_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::primitive_kind; + +template +void gemm_inner_product_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = !memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights, + wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias); + + if (do_relu) { + float nslope = post_ops.entry_[0].eltwise.alpha; + parallel_nd(MB, OC, [&](int mb, int oc) { + size_t dst_off = mb * OC + oc; + if (dst[dst_off] < 0) + dst[dst_off] *= nslope; + }); + } +} + +template +void gemm_inner_product_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights, + wei_tr ? 
&OC : &IC, diff_dst, &OC, &beta, diff_src, &IC); +} + +template +void gemm_inner_product_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + diff_dst += diff_dst_d.offset0(); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->diff_weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + if (wei_tr) + extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC, + &beta, diff_weights, &OC); + else + extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC, + &beta, diff_weights, &IC); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + constexpr int blksize = 8; + const int OC_blocks = OC / blksize; + const int rem_OC = OC % blksize; + parallel(0, [&](const int ithr, const int nthr) { + int oc_st{0}, oc_e{0}; + balance211(OC_blocks, nthr, ithr, oc_st, oc_e); + oc_st = oc_st * blksize; + oc_e = oc_e * blksize; + + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] = diff_dst[oc]; + } + + for (int mb = 1; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + + if (rem_OC != 0 && ithr == nthr-1) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) + diff_bias[oc] = diff_dst[oc]; + for (int mb = 1; mb < MB; ++mb) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + } + }); + } +} + +template struct gemm_inner_product_fwd_t; +template struct gemm_inner_product_bwd_data_t; +template struct gemm_inner_product_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp new file mode 100644 index 000000000000..acf0a49b9a8c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp @@ -0,0 +1,157 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_GEMM_INNER_PRODUCT_HPP +#define CPU_GEMM_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct gemm_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type, + src_md()->data_type, + weights_md()->data_type, + dst_md()->data_type, + with_bias() ? weights_md(1)->data_type : data_type) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct gemm_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_src_md()->data_type, + weights_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(diff_src_md(), weights_md(), + diff_dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + src_md()->data_type, + diff_weights_md()->data_type, + diff_dst_md()->data_type, + with_bias() ? diff_weights_md(1)->data_type : data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(src_md(), diff_weights_md(), + diff_dst_md()); + + return ok ? 
status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp new file mode 100644 index 000000000000..fed7e4d69335 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -0,0 +1,740 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "math_utils.hpp" + +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace mkldnn::impl::memory_tracking::names; + +template +void _gemm_x8s8s32x_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, + scratchpad); + }); +} + +template +_gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::pp_ker_t( + const pd_t *pd) + : ker_(nullptr) + , jcp_(pd->jcp_) + , OC_(pd->jcp_.oc) + , OS_(pd->jcp_.os) + , bias_data_type_(data_type::undef) + , bias_data_type_size_(0) + , scale_idx_mult_(0) + , do_bias_(false) + , do_relu_(false) + , do_sum_(false) +{ + using namespace types; + + const auto dst_md = memory_desc_wrapper(pd->dst_md()); + dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1); + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + + int entry_idx = -1; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + entry_idx = idx; + break; + } + } + do_relu_ = entry_idx >= 0; + + 
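+ // Added commentary (not part of the upstream code): the remaining setup
+ // records which post-processing steps the generated kernel must emit
+ // (signed-input rescaling, an optional sum post-op, bias) and then picks
+ // vlen_ as the largest value that divides OC_ while fitting in one
+ // 512-bit register of floats (16 lanes). For a hypothetical OC_ of 24
+ // this gives vlen_ = 12, so the per-row tail stays a whole multiple of
+ // the vector step.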
do_signed_scaling_ = jcp_.signed_input; + + do_sum_ = post_ops.contain(primitive_kind::sum, 0); + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + const size_t vlen_start + = cpu_isa_traits::vlen / sizeof(float); + + for (size_t i = vlen_start; i > 0; i--) { + if (OC_ % i == 0) { + vlen_ = i; + break; + } + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs + return; + else + generate(); +} + +template +void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask_short = r10; + Reg64 reg_rem_mask_vlen = r11; + Opmask kreg_rem_mask_short = k1; + Opmask kreg_rem_mask_vlen = k3; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = vlen_; + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + Zmm vreg_sum_scale = Zmm(3); + Zmm vreg_signed_scale = Zmm(4); + + size_t def_unroll = 4; + size_t max_unroll = 12; + size_t zmm_step = 2; + if (do_sum_) { + max_unroll = 8; + zmm_step = 3; + } + + auto vreg_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 0); + }; + auto vreg_bias = [&](int idx) { + return Zmm(5 + idx * zmm_step + 1); + }; + auto vreg_prev_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 2); + }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); + vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); + +#undef PARAM_OFF + + mov(reg_rem_mask_vlen, 1); + shl(reg_rem_mask_vlen, vlen); + sub(reg_rem_mask_vlen, 1); + kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen); + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply sum (if any), + // bias (if any), scaling, and relu (if any); + // then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask_short; + else + vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen; + vmovups(vreg_scale_, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask_short; + else + vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_signed_scaling_) + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * 
bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask_short; + else + vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + + if (do_sum_) + { + auto vreg_prev_dst_ = vreg_prev_dst(idx); + if (apply_mask) + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short; + else + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen; + + switch (dst_type) { + case data_type::f32: + case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break; + case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break; + case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break; + default: assert(!"unsupported data type"); + } + if (dst_type != data_type::f32) + vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx)); + + vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); + } + + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + if (dst_type == data_type::u8) + vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indexed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t)); + }; + + // <--------- OC ---------------> + // + // ^ ................+..............+-------------+....................... + // | . : not accessed |Prologue loop| . + // | . +--------------+-------------+ . + // . | | . + // O . | Main loop (unrolled) | . + // S . | | . + // . +--------------+-------------+ . + // | . | Epilogue loop|not accessed : . 
+ // v ................+--------------+.............+....................... + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask_short, 1); + // cl == reg_tmp because reg_tmp <= vlen here + shl(reg_rem_mask_short, cl); + sub(reg_rem_mask_short, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } + else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask_short, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask_short, 1); + shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen + sub(reg_rem_mask_short, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode(); +} + +template +void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::operator () + (dst_data_t *dst, const acc_data_t *acc, const char *bias, + const float *scales, float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + size_t os_offset = start / OC_; + args.acc = acc + start; + args.dst = dst + os_offset * dst_os_stride_ + oc_offset; + args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset); + args.nslope = nslope; + args.sum_scale = sum_scale; + args.signed_scale = signed_scale; + 
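+ // Added note (illustrative, not from upstream): start/end index the
+ // flattened (os, oc) space, and oc_offset tells the kernel how far into
+ // an output row the work begins. For a hypothetical OC_ = 64 and
+ // start = 100, oc_offset is 36, so the prologue loop finishes the
+ // remaining 28 channels of that pixel before the unrolled main loop
+ // takes over whole rows.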
args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } + else { + // Fallback + const size_t first_oc = start % OC_; + const size_t last_oc = (end - 1) % OC_; + const size_t first_os = start / OC_; + const size_t last_os = (end - 1) / OC_; + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * dst_os_stride_ + oc; + + float d = (float)(acc[acc_off]); + if (jcp_.signed_input) + d *= signed_scale; + + if (do_bias_) + d += get_bias(bias, g * jcp_.oc + oc, + bias_data_type_); + + d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; + if (do_sum_) + d += sum_scale * dst[dst_off]; + if (do_relu_ && d < 0) + d *= nslope; + dst[dst_off] = qz_a1b0()(d); + } + } + } +}; + +template +void _gemm_x8s8s32x_convolution_fwd_t:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base, + const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const { + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto src_md = memory_desc_wrapper(pd()->src_md()); + const size_t src_mb_stride = src_md.blk_off(1); + const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto dst_md = memory_desc_wrapper(pd()->dst_md()); + const size_t dst_mb_stride = dst_md.blk_off(1); + const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_sum = post_ops.contain(primitive_kind::sum, 0); + const float sum_scale = do_sum ? 
post_ops.entry_[0].sum.scale : 0; + + float nslope = 0; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + nslope = e.eltwise.alpha; + break; + } + } + + auto col = scratchpad.get(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + src_data_t *__restrict imtr = scratchpad.get(key_conv_gemm_imtr) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; + + const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc; + const int32_t *_wei_comp = (const int32_t *)(wei_base + offset); + + int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, + nb_oh, owb, nb_ow); + + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const src_data_t *__restrict src = src_base + n * src_mb_stride + + g * src_g_stride; + const wei_data_t *__restrict wei = wei_base + g * wei_g_stride; + dst_data_t *__restrict dst = + dst_base + n * dst_mb_stride + g * dst_g_stride; + const int32_t *wei_comp = _wei_comp + g * jcp.oc; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::im2col_u8( + jcp, src, imtr, col, oh, h_step, ow, w_step); + + const int M = jcp.oc; + const int K = jcp.ks * jcp.ic; + const int N = h_step * w_step; + const int LDA = M * jcp.ngroups; + const int LDB = jcp.im2col_sz ? N : K; + const char *BT = jcp.im2col_sz ? "T" : "N"; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", + &M, &N, &K, &onef, wei, &LDA, &off_a, + jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b, + &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); + + auto wei_adj_scale = + (wei_md.extra().flags | memory_extra_flags::scale_adjust) + ? 
wei_md.extra().scale_adjust : 1.f; + + parallel(0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)N * jcp.oc, nthr, ithr, start, end); + (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_, + acc, bia_base, scales, nslope, sum_scale, + 1.f / wei_adj_scale, g, start, end); + }); + + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, + owb, nb_ow); + } +} + +template +void _gemm_u8s8s32x_convolution_bwd_data_t:: +execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base, + bia_base, diff_src_base, scratchpad); + }); +} + +template +void _gemm_u8s8s32x_convolution_bwd_data_t:: +execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const +{ + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md()); + const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1); + const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md()); + const size_t diff_src_mb_stride = diff_src_md.blk_off(1); + const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic; + const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1); + + /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */ + const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1); + const float *scales = pd()->attr()->output_scales_.scales_; + const size_t work_amount = jcp.ngroups * jcp.mb; + + auto col = scratchpad.get(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + + int n{0}, g{0}; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + + for (size_t iwork = start; iwork < end; ++iwork) { + const diff_dst_data_t *diff_dst = diff_dst_base + + n * diff_dst_mb_stride + g * diff_dst_g_stride; + const wei_data_t *wei = wei_base + g * wei_g_stride; + diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride + + g * diff_src_g_stride; + + const int M = jcp.ks * jcp.ic; + const int N = jcp.os; + const int K = jcp.oc; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + const int LD = K * jcp.ngroups; + + gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, + wei, &LD, &off_a, diff_dst, &LD, &off_b, + &zerof, jcp.im2col_sz ? 
col : acc, &M, &off_c); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::col2im_s32(jcp, col, acc); + + parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) { + float d = (float)acc[is * jcp.ic + ic]; + if (jcp.with_bias) + d += get_bias(bia_base, g * jcp.ic + ic, + pd()->desc()->bias_desc.data_type); + d *= scales[(g * jcp.ic + ic) * scale_idx_mult]; + const size_t diff_src_off = is * diff_src_os_stride + ic; + diff_src[diff_src_off] = + qz_a1b0()(d); + }); + nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + } +} + +using namespace data_type; + +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; + +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; + +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp new file mode 100644 index 000000000000..9e77b890d55a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP +#define GEMM_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" +#include "gemm_convolution_utils.hpp" + +#include "gemm/gemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_x8s8s32x_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, s8, data_type::undef, dst_type, + s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common( + dat_tag(), format_tag::any, dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio); + + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 
1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + + bool post_ops_ok() const { + using namespace mkldnn::impl::primitive_kind; + auto const &po = attr()->post_ops_; + auto is_relu = [&](int idx) { + return po.entry_[idx].is_relu(true, false); }; + + switch (po.len_) { + case 0: return true; + case 1: return is_relu(0) || po.contain(sum, 0); + case 2: return po.contain(sum, 0) && is_relu(1); + default: return false; + } + return false; + } + }; + + _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true), pp_ker_(nullptr) + { pp_ker_ = new pp_ker_t(pd()); } + ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_ker_t : jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _gemm_x8s8s32x_convolution_fwd_t::pp_kernel); + pp_ker_t(const pd_t *pd); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, + float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end); + + size_t dst_os_stride_; + + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + float sum_scale; + float signed_scale; + size_t len; + size_t oc_offset; + }; + void(*ker_)(const ker_args *args); + + const jit_gemm_conv_conf_t &jcp_; + size_t OC_; + size_t OS_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + bool do_sum_; + bool do_signed_scaling_; + size_t vlen_; + }; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src_base, const wei_data_t *wei_base, + const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const; + + int nthr_; + pp_ker_t *pp_ker_; + +}; + +template +struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t{ + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_u8s8s32x_convolution_bwd_data_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(dst_type, s8, data_type::undef, u8, s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && attr()->post_ops_.has_default_values() + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && 
memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(), diff_dst_md(), + mkldnn_get_max_threads()); + } + + virtual bool support_bias() const override { return true; } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + format_tag_t wei_tag() const { + return with_groups() ? format_tag::hwigo : format_tag::hwio; + } + }; + + _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + void execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp new file mode 100644 index 000000000000..1e435a233afe --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -0,0 +1,453 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace math; +using namespace format_tag; +using namespace memory_tracking::names; + +template +gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::pp_kernel_t( + const pd_t *pd, bool dst_is_acc) + : ker_(nullptr), OC_(pd->OC()) + , bias_data_type_(data_type::undef), bias_data_type_size_(0) + , scale_idx_mult_(0), do_bias_(false), do_relu_(false) +{ + using namespace types; + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + do_relu_ = post_ops.len_ == 1; + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs since they do not have optimized + // x8s8s32 GEMM anyways. The configuration variables above are used by + // the fallback code. + return; + else + generate(); +} + +template +void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask = r10; + Opmask kreg_rem_mask = k1; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = cpu_isa_traits::vlen / sizeof(float); + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + + auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); }; + auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); +#undef PARAM_OFF + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply bias (if any), scaling, + // and relu (if any); then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask; + vmovups(vreg_scale, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask; + + switch (bias_data_type_) { + case data_type::s8: + 
vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type == data_type::u8) + vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indixed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + }; + + // <-------------------- OC -------------------------------> + // + // ^ +....................+----------------------------------+ + // | : not accessed | Prologue loop | + // | +--------------------+----------------------------------+ + // | | + // M | Main loop (unrolled) | + // B | | + // +--------------------------------+----------------------+ + // | | Epilogue loop | not accessed : + // v +--------------------------------+......................+ + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?) 
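+ // Added note (not from upstream): taking the tail path when reg_tmp ==
+ // vlen still appears correct, since the shift below builds (1 << 16) - 1,
+ // a full 16-lane mask; the "(?)" above reads as a minor performance
+ // question rather than a correctness one.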
+ L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here + sub(reg_rem_mask, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t def_unroll = 4; + size_t max_unroll = 13; + + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?) + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16 + sub(reg_rem_mask, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode(); +} + +template +void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::operator ()( + dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + args.dst = dst + start; + args.acc = acc + start; + args.bias = bias + oc_offset * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * oc_offset; + args.nslope = nslope; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } else { + // Fallback + size_t oc = start % OC_; + for (size_t i = start; i < end; i++) { + float d = (float)acc[i]; + float b = get_bias(bias, oc, bias_data_type_); + d = d + b; + d *= scales[oc * scale_idx_mult_]; + if (do_relu_ && d < 0) + d *= nslope; + dst[i] = qz_a1b0()(d); + oc = (oc == OC_ - 1) ? 
0 : oc + 1; + } + } +}; + +template +void gemm_x8s8s32x_inner_product_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), oiw, oihw, oidhw, oi); + + const int M = OC; + const int N = MB; + const int K = pd()->IC_total_padded(); + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + acc_data_t *acc = pd()->dst_is_acc_ + ? (acc_data_t *)dst + : scratchpad(ctx).template get(key_iprod_int_dat_in_acc_dt); + + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights, + wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c); + + if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_ + || pd()->with_bias()) { + const bool force_sequential = MB * OC < 2000; + parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)OC * MB, nthr, ithr, start, end); + (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end); + }); + } +} + +using namespace data_type; + +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp new file mode 100644 index 000000000000..ac6a5c8f85c7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp @@ -0,0 +1,166 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP +#define GEMM_X8S8S32X_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" +#include "jit_generator.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(src_type == data_type::u8 + ? IGEMM_S8U8S32_IMPL_STR + : IGEMM_S8S8S32_IMPL_STR, + gemm_x8s8s32x_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == src_type + && dst_md()->data_type == dst_type + && weights_md()->data_type == s8 + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + if (!ok) return status::unimplemented; + + dst_is_acc_ = utils::one_of(dst_type, s32, f32); + + init_scratchpad(); + + return status::success; + } + + bool dst_is_acc_; + + protected: + status_t set_default_params() { + using namespace format_tag; + if (src_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md_, + utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc))); + } + if (dst_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md_, nc)); + if (weights_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md_, + utils::pick(ndims() - 2, io, wio, hwio, dhwio))); + } + return inner_product_fwd_pd_t::set_default_params(); + } + + private: + void init_scratchpad() { + if (!dst_is_acc_) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book( + memory_tracking::names::key_iprod_int_dat_in_acc_dt, + sizeof(acc_data_t) * MB() * OC()); + } + } + }; + + gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); } + ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; } + + typedef typename prec_traits::type data_t; + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_kernel_t: jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + gemm_x8s8s32x_inner_product_fwd_t::pp_kernel); + pp_kernel_t(const pd_t *pd, bool dst_is_acc); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end); + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + 
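+            // len: number of contiguous output elements to process (end - start).
+            // oc_offset: start % OC_, i.e. the starting position within the
+            // per-output-channel bias/scales arrays.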
size_t len; + size_t oc_offset; + }; + void (*ker_)(const ker_args *args); + + size_t OC_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + }; + + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + pp_kernel_t *pp_kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp new file mode 100644 index 000000000000..6fa251d465f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp @@ -0,0 +1,674 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop, bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto vreg_load = [=](int i) { + return Ymm(ur * load_loop_blk + i); + }; + + auto vreg_accum = [=](int i, int j) { + return Ymm(j * load_loop_blk 
+ i); + }; + + auto bias_ptr = [=](int i) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) + { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt]; + }; + + auto output_ptr = [=](int i, int j) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done, init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) + vmovups(vreg_accum(i, j), bias_ptr(i)); + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r = vreg_accum(i, j); + vxorps(r, r, r); + } + + L(init_done); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(vreg_load(i), load_ptr(0, i)); + vbroadcastss(vreg_bcast, bcast_ptr(0, 0)); + }; + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r = vreg_accum(i, j); + vaddps(r, r, output_ptr(i, j)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + vmovups(output_ptr(i, j), vreg_accum(i, j)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + if (mayiuse(avx2)) + vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(vtmp, vreg_bcast, vreg_load(i)); + vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp); + } + if (j == ur - 1 && !(last_block + && u == jcp.reduce_loop_unroll - 1)) + vmovups(vreg_load(i), load_ptr(u + 1, i)); + } + if (j < ur - 1) + 
vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1)); + } + if (!last_block || u < jcp.reduce_loop_unroll - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0)); + } + }; + + Label reduce_loop, reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i) { return Ymm(i); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r = diff_bias_reg(i); + vxorps(r, r, r); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(diff_bias_reg(i), diff_bias_ptr(i)); + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) + vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i)); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) + vmovups(diff_bias_ptr(i), diff_bias_reg(i)); + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_avx2_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch 
(jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, 8); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 
1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + + const int simd_w = 8; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.ih == jcp.oh && jcp.iw == jcp.ow + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + // TODO: remove this restriction + // optimized 1x1 bwd_w does not support Intel AVX + if (jcp.prop_kind == backward_weights && !mayiuse(avx2)) + return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.ur = mayiuse(avx2) ? 
4 : 3; // Intel AVX support + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else 
if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp new file mode 100644 index 000000000000..bfdeb2b18d68 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp @@ -0,0 +1,110 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP +#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_conv_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32) + + jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_avx2_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using ymm_t = const Xbyak::Ymm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t imm_addr64 = reduce_loop_iter; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + ymm_t vreg_bcast = ymm_t(15); + ymm_t vtmp = ymm_t(14); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp new file mode 100644 index 000000000000..f116ac90564b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp @@ -0,0 +1,545 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx2_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +/* convolution forward */ + +void jit_avx2_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int ndims = dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step = step(jcp.nb_load_blocking, + jcp.nb_load - ocb, jcp.nb_load_blocking_max); + + const int _ocb = g * nb_oc + ocb; + p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + + p.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking >= nb_ic + ? 
FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + nb_ic_blocking * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + + if (ocb == 0) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +/* convolution backward wtr data */ + +void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + const int ndims = diff_dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? 
remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int load_step = 0; + for (int icb = 0; icb < jcp.nb_load; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = start; iwork < end; iwork += bcast_step) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb = 0; ocb < jcp.nb_reduce; + ocb += jcp.nb_reduce_blocking) { + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, + ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + nb_oc_blocking * jcp.oc_block); + + kernel_->jit_ker(&p); + } + + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + }; + + parallel(0, ker); +} + +/* convolution backward wtr weights */ + +jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t( + const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) +{ + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + reducer_weights_ = + new cpu_reducer_2d_t(pd()->reducer_wei_conf_); + reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); + init_rtus_driver(this); +} + +void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + + data_t *diff_bias = pd()->wants_padded_bias() + ? 
scratchpad.get(key_conv_padded_bias) : diff_bias_in; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + const int ndims = diff_dst_d.ndims(); + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + const int bcast_work = div_up(nb_ic, nb_ic_blocking); + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + const int load_work = div_up(nb_oc, nb_oc_blocking); + + const int sp_dim = jcp.reduce_dim; + const int mb_sp_work = jcp.mb * sp_dim; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image, + data_t *store_to, size_t store_to_ld, const data_t *diff_dst, + const data_t *src, int ithr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + p.output_stride = store_to_ld * sizeof(float); + const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block; + + int oc_b_step = 0; + for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) { + oc_b_step = step(12, nb_oc_blocking - oc_b, 18); + p.load_dim = oc_b_step * jcp.oc_block; + + int ic_b_step = 0; + for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) { + ic_b_step = step(12, nb_ic_blocking - ic_b, 18); + p.bcast_dim = ic_b_step * jcp.ic_block; + rp.icb = p.bcast_dim / jcp.ic_block; + + p.output_data = store_to + oc_b * store_to_ld + + ic_b * jcp.ic_block * jcp.oc_block; + + /* spatial reduction */ + int sp_step = 0; + for (int sp = sp_start; sp < sp_end; sp += sp_step) { + sp_step = step(sp_step_def, sp_end - sp, 192); + p.reduce_dim = sp_step; + rp.os = p.reduce_dim; + + p.first_last_flag = sp == sp_start && first_image + ? 
FLAG_REDUCE_FIRST : 0; + + p.load_data = diff_dst + + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + (ic_b * jcp.is + sp) * jcp.ic_block; + if (ndims == 3) + rp.src = src + + iw * src_d.blocking_desc().strides[2]; + else + rp.src = src + + ih * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + + if (oc_b == 0) + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_njobs = rw->balancer().ithr_njobs(ithr); + if (w_njobs == 0) return; + + /* setup: independent work (oc, ic) */ + const int w_job_start = rw->balancer().ithr_job_off(ithr); + int g{0}, load_i{0}, bcast_i{0}; + nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work, + bcast_i, bcast_work); + + /* setup: reduction work (mb, sp) */ + int mb_sp_start{0}, mb_sp_end{0}; + balance211(mb_sp_work, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end); + int img_start{0}, sp_start{0}; + nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim); + + /* independent work */ + for (int iwork = 0; iwork < w_njobs; ++iwork) { + const int oc_b = nb_oc_blocking * load_i; + const int ic_b = nb_ic_blocking * bcast_i; + + const int _ic_b = g * nb_ic + ic_b; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + size_t store_to_ld; + + if (rw->balancer().nthr_per_group_ == 1) { + const size_t off = pd()->with_groups() + ? 
diff_weights_d.blk_off(g, oc_b, ic_b) + : diff_weights_d.blk_off(oc_b, ic_b); + store_to = &diff_weights[off]; + store_to_ld = jcp.ic * jcp.oc_block; + } else { + const size_t off = iwork * rw->balancer().job_size_; + store_to = + rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off; + store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block; + } + + /* reduction work */ + int img = img_start; + int sp = sp_start; + int sp_step = 0; + for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) + { + sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp); + + const bool first_image = img == img_start; + oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to, + store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)], + &src[src_d.blk_off(img, _ic_b)], ithr); + + sp = 0; + img += 1; + } + + nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i, + bcast_work); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = + rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp new file mode 100644 index 000000000000..976224217345 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t { + // TODO: (Roma) Code duplication duplication! Remove with templates + // (maybe...)! + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx2_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx2_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + cpu_reducer_2d_t::conf_t reducer_wei_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const int ic_block = jcp_.bcast_block; + const int nb_ic = jcp_.nb_bcast; + const int nb_ic_blocking = jcp_.nb_bcast_blocking; + const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); + + const int oc_block = jcp_.load_block; + const int nb_oc = jcp_.nb_load; + const int nb_oc_blocking = jcp_.nb_load_blocking; + const int load_work = utils::div_up(nb_oc, nb_oc_blocking); + + const int job_size + = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; + const int njobs_x = bcast_work; + const int njobs_y = jcp_.ngroups * load_work; + + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = max_threads * job_size * 8; + + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + oc_block, jcp_.ngroups * jcp_.oc / oc_block, + jcp_.mb, max_buffer_size)); + } + + reducer_wei_conf_.init( + reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, + jcp_.mb * jcp_.nb_reduce, max_buffer_size), + job_size / nb_oc_blocking, nb_oc_blocking, ic_block, + nb_ic * ic_block * oc_block, nb_oc); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx2_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete rtus_driver_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + cpu_reducer_2d_t *reducer_weights_; + cpu_reducer_t *reducer_bias_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp new file mode 100644 index 000000000000..e24770a2dac9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp @@ -0,0 +1,1501 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw + + (ki*dilate_w + jj*stride_w - pad_l)); + else + inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w + - pad_l)*ic_blk + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + } +} + +void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw + + (jj * stride_w - pad_l)); + else + inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk + + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = + ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel AVX support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), 
ytmp); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int od = jcp.od; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : ic_blk; + const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w; + + Label init_done, init_first; + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + vmovups(Ymm(ur_w * ii + jj), + make_safe_addr(reg_output, offt, reg_long_offt)); + } + } + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vmovups(Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj)); + } + + L(init_done); + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + Label skip_kh_loop, skip_kd_loop, kd_loop; + if (jcp.ndims == 5) { + push(reg_output); + push(oi_iter); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_input); + + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + cmp(reg_ki, 0); + je(skip_kd_loop, T_NEAR); + } + L(kd_loop); + mov(kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_input, aux_reg_inp_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks, + oc_blocks_tag); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult); + add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + 
pop(reg_output); + } + + Label regular_store; + + if (jcp.with_eltwise) { + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off + = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + Ymm reg_out = Ymm(ur_w * ii + jj); + vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out); + } + } +} + +inline void jit_avx2_conv_fwd_kernel_f32::solve_common( + int oc_blocks, char oc_blocks_tag) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, + 'l', oc_blocks, oc_blocks_tag); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, + 'l', oc_blocks, oc_blocks_tag); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, + 'm', oc_blocks, oc_blocks_tag); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, + 'r', oc_blocks, oc_blocks_tag); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, + 't', oc_blocks, oc_blocks_tag); // "tail" +} + +void jit_avx2_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + if (jcp.nb_oc > jcp.nb_oc_blocking) { + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? 
tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + L(exit); + } else if (jcp.nb_oc == jcp.nb_oc_blocking) { + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + } else { + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 :dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 :cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int simd_w = 8; + const bool flat = jcp.ic < simd_w; + const bool mimo = !flat; + + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off = + one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c); + if (!args_ok) return status::unimplemented; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively + // Thus, we can only assign 14 or 15 YMMs for data storage + const int num_avail_regs = mayiuse(avx2) ? 
15 : 14; + if (!mayiuse(avx2)) { + if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) { + // current register assignment requires more YMMs than available + // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad + if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1) + jcp.ur_w -= 1; + else + for (int b = 3; b > 1; b--) + if (jcp.nb_oc % b == 0) { + jcp.nb_oc_blocking = b; + break; + } + } + } + + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +void jit_avx2_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow, + int r_overflow) +{ + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + Label kd_loop, skip_kd_loop; + Label oc_loop, skip_oc_loop; + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(ur_w * ii + jj)); + } + + if (one_of(jcp.ndims, 3, 4)) { + cmp(reg_channel_work, 0); + jle(skip_oc_loop, T_NEAR); + xor_(reg_channel, reg_channel); + + mov(aux_reg_ddst_oc_loop, reg_ddst); + mov(aux_reg_kernel_oc_loop, reg_kernel); + + L(oc_loop); + mov(aux_reg_ddst, aux_reg_ddst_oc_loop); + mov(aux_reg_kernel, aux_reg_kernel_oc_loop); + } + + if (jcp.ndims == 5) { + assert(jcp.nb_oc_blocking == 1); + push(oi_iter); + + mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_ddst); + mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]); + + L(kd_loop); + mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_ddst, aux_reg_dst_d); + 
mov(aux_reg_kernel, aux_reg_ker_d); + } + + Label kh_loop, skip_kh_loop; + cmp(kj, 0); + jle(skip_kh_loop, T_NEAR); + L(kh_loop); { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); // 0; + int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) { + + for (int jj = jj_start ; jj < jj_end; jj += stride_w) { + int aux_output_offset + = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2; + vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w), + ptr[aux_reg_ddst + + sizeof(float) * aux_output_offset]); + } + + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset + = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block + + ki * jcp.ic_block * jcp.oc_block + + ofm2 * jcp.ic_block; + vmovups(ymm15, + ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15); + } + } + } + add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block + * ic_block); + sub(aux_reg_ddst, sizeof(float) * ow * oc_block); + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + L(skip_kh_loop); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, + sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + } + + if (one_of(jcp.ndims, 3, 4)) { + int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow + * jcp.oc_block; + int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw + * jcp.ic * jcp.oc_block; + + add(aux_reg_ddst_oc_loop, ddst_oc_shift); + add(aux_reg_kernel_oc_loop, kernel_oc_shift); + + inc(reg_channel); + cmp(reg_channel, reg_channel_work); + jl(oc_loop, T_NEAR); + + L(skip_oc_loop); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + } + + Label no_update_label; + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int ii = 0; ii < nb_ic_block; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(Ymm(15), + make_safe_addr(reg_dsrc, offt, reg_long_offt)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(15)); + + } + } + L(no_update_label); + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt), + Ymm(ur_w * ii + jj)); + } +} + +void jit_avx2_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]); + + int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block; + int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block; + + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + int r_overflow = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad)) / jcp.stride_w); + int r_overflow1 = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) + n_oi--; + + if (jcp.ur_w == jcp.iw) { + compute_loop(jcp.ur_w, 
l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(jcp.ur_w, l_overflow, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } else { + xor_(oi_iter, oi_iter); + if (l_overflow > 0) { + compute_loop(jcp.ur_w, l_overflow, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + } + + if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop; + L(ow_loop); { + compute_loop(jcp.ur_w, 0, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR); + } + } + + if (r_overflow1 > 0 ) { + compute_loop(jcp.ur_w, 0, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + } + + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } + + this->postamble(); +} + +status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + + int ndims = diff_src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int simd_w = 8; + + /* derivatives */ + jcp.idp = jcp.id + 2 * jcp.f_pad; + jcp.ihp = jcp.ih + 2 * jcp.t_pad; + jcp.iwp = jcp.iw + 2 * jcp.l_pad; + jcp.ohp = jcp.oh; /* do we really need */ + jcp.owp = jcp.ow; /* padded output ??? */ + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + /* gemm-based convolution performs better in these cases */ + if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1) + return status::unimplemented; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + if (jcp.oc % jcp.oc_block) return status::unimplemented; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.nb_ic_blocking = 1; + jcp.nb_oc_blocking = 1; + jcp.ur_w = 1; + + if(one_of(ndims, 3, 4) && jcp.ow < 40) + jcp.nb_oc_blocking = jcp.ow < 15 ? 
4 : 2; + + if (ndims == 3) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + + bool args_ok = true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i, + gOIdhw8o8i, OIdhw8o8i) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && jcp.stride_w == jcp.stride_h + && jcp.stride_d == 1 + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.ic % simd_w == 0 + && jcp.oc % simd_w == 0 + && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) return status::unimplemented; + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad; + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + + const int max_regs = 15; /* Maximun number of registers available for + result accumulation and delta dst data. + One additional register is reserved for weights + data. */ + + /* Find the best blocking with maximum number of fma instructions + per ur_w * nb_ic_blocking compute loops. Number of required registers + is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs. 
+ ur_w must be divisible by stride_w */ + if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers + distribution exceeds max_regs */ + return status::unimplemented; + + int best_nfmas = 0; + for (int b = 1; b <= 4; b++) + { + if (jcp.nb_ic % b != 0) + continue; + + for (int u = jcp.stride_w; + u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w; + u += jcp.stride_w) + { + int ur_w = nstl::min(u, jcp.iw); + /* maximum 1 step with l_overflow so far */ + if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) + continue; + int nfmas = utils::div_up(ur_w, jcp.stride_w) * b; + if (nfmas > best_nfmas + || (nfmas == best_nfmas && jcp.ur_w < ur_w)) { + jcp.ur_w = ur_w; + jcp.nb_ic_blocking = b; + best_nfmas = nfmas; + } + } + } + if (best_nfmas == 0) /* can't find appropriate blocking */ + return status::unimplemented; + + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + /* maximum 1 ur_w block with r_overflow so far */ + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + return status::success; +} + +void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +void jit_avx2_conv_bwd_weights_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + compute_oh_loop_common(); + this->postamble(); +} + +status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + const int simd_w = 8; + + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id + - jcp.f_pad); + if (ndims == 5) + if (jcp.f_pad != 0 || back_pad != 0) + return status::unimplemented; + + const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1); + const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1); + const bool boundaries_ok = true + && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad + && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad; + if (!boundaries_ok) + return status::unimplemented; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && IMPLICATION(mimo, jcp.ic % simd_w == 0) + && jcp.oc % simd_w == 0 + && jcp.kw < 14 + && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */ + && jcp.kh <= jcp.ih /* [bwd_w:r2] */ + && jcp.kd <= jcp.f_pad + jcp.id + && jcp.kd <= jcp.id + && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */ + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + return status::success; +} + +void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_loop; + mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0 + L(kd_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 
1 : jcp.ic_block; + sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult); + sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block + * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + mov(kj, reg_kh); + Label kh_comeback_loop; + L(kh_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(reg_input, sizeof(float) * jcp.iw * inp_mult); + sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset, + int kernel_offset, int output_offset) +{ + const int kw = jcp.kw; + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + vmovups(Ymm(kw * ic_block_step + 0), + yword[reg_output + + sizeof(float) * i_ur * oc_block + output_offset]); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw; + if (i_iw - pad_l < 0 + || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r) + continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t i_off = (size_t)input_offset + sizeof(float)*( + one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? (i_iw - pad_l) + i_ic + * ((size_t)jcp.id * jcp.ih * jcp.iw) + : (i_iw - pad_l) * ic_block + i_ic); + vbroadcastss(Ymm(kw * ic_block_step + 1), + make_safe_addr(reg_input, i_off, reg_long_offt)); + vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic), + Ymm(kw * ic_block_step + 0), + Ymm(kw * ic_block_step + 1)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(yword[reg_kernel + off], + Ymm(i_kw * ic_block_step + i_ic)); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() +{ + int ic_block_step; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block; + } else { + ic_block_step = jcp.kw > 7 ? 1 + : jcp.kw > 3 ? 2 + : jcp.kw > 1 ? 4 : 8; + } + + const int max_ur_w = jcp.ow > 56 ? 14 : 28; + + if (jcp.ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + od_step_comeback_pointers(); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } else { + oh_step_comeback_pointers(); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 
1 : jcp.ic_block; + Label kd_loop; + + const int r_pad + = nstl::max(0, + (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, + 0, 0); + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, ic_block); + jl(ic_block_loop, T_NEAR); + } + if(one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + const int stride_w = jcp.stride_w; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad = jcp.r_pad; + + int ur_w = nstl::min(jcp.ow, max_ur_w); + int ur_w_trips = jcp.ow / ur_w; + int ur_w_tail = jcp.ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 
1 : ic_block; + + int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + if (jcp.l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, + jcp.l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) + * (ur_w * stride_w - jcp.l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + Label ow_block_loop; + L(ow_block_loop); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_loop, T_NEAR); + } + } + + if (ur_w_tail > 0) + compute_ic_block_step(ur_w_tail, + 0, r_pad, ic_block_step, 0, 0, 0); + + sub(reg_input, sizeof(float) * input_comeback); + sub(reg_output, sizeof(float) * output_comeback); + + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_loop, T_NEAR); + } + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() +{ + const int icoc_block = jcp.ic_block * jcp.oc_block; + const int t_pad = jcp.t_pad; + const int stride_h = jcp.stride_h; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + int b_pad = jcp.b_pad; + + Label oh_tpad_loop, oh_loop, oh_loop_end; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + if (t_pad > 0) { + assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */ + mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); + add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block); + + L(oh_tpad_loop); { + compute_oh_step_disp(); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + add(reg_kh, stride_h); + + /* the overlap between input and kernel may not reach kernel size. 
+ * so far we do not support that (until we put constant here) */ + const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */ + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_loop, T_NEAR); + } + + if (t_pad % stride_h != 0) { + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block); + add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult); + } + } + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_loop, T_NEAR); + + mov(reg_kh, jcp.kh); + L(oh_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_loop, T_NEAR); + } + L(oh_loop_end); + if (b_pad > 0) { + Label oh_bpad_loop, oh_bpad_loop_end; + cmp(reg_oj, jcp.oh); + jge(oh_bpad_loop_end, T_NEAR); + + mov(reg_kh, jcp.ih + t_pad); + sub(reg_kh, reg_ih_count); + L(oh_bpad_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_loop_end, T_NEAR); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_loop, T_NEAR); + } + L(oh_bpad_loop_end); + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp new file mode 100644 index 000000000000..412c50c9ee2d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP +#define JIT_AVX2_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_conv_fwd_kernel_f32: public jit_generator { + jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_avx2_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32) + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t aux_reg_inp_d = r11; + reg64_t aux_reg_ker_d = abi_not_param1; + + reg64_t reg_ki = rsi; + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = r15; + reg64_t reg_long_offt = r15; + Xbyak::Reg32 reg_ci_flag = r13d; + + Xbyak::Ymm ytmp = Xbyak::Ymm(14); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void solve_common(int oc_blocks, char oc_blocks_label); + + void generate(); +}; + +struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) + + jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t reg_dsrc = rsi; + reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only + reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5 + case only */ + + reg64_t aux_reg_dst_d = r12; // used in ndims == 5 
case only + reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only + + reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only + reg64_t kj = r11; + reg64_t oi_iter = r12; + reg64_t reg_kh = r14; + reg64_t reg_channel = r13; // used in ndims < 5 case only + reg64_t reg_channel_work = r9; // used in ndims < 5 case only + reg64_t reg_long_offt = r15; + + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32) + + jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_tmp = r11; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t ki = r14; + reg64_t reg_long_offt = r11; + + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset); + inline void compute_oh_step_disp(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_loop_common(); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp new file mode 100644 index 000000000000..13f61e84fef2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp @@ -0,0 +1,410 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx2_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, d, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (pd()->ndims() == 4) \ + ? (f).blk_off(n, c, h, w) \ + : (f).blk_off(n, c, d, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \ + (pd()->ndims() == 3) \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : (pd()->ndims() == 4) \ + ? wht_blk_off_(f, g, oc, ic, kh, kw) \ + : wht_blk_off_(f, g, oc, ic, kd, kh, kw) + +void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od + * jcp.oh; + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const int dj = od * jcp.stride_d; + const int d_t_overflow = nstl::max(0, jcp.f_pad - dj); + const int d_b_overflow = nstl::max(jcp.id, dj + + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + + const int id = nstl::max(dj - jcp.f_pad + + div_up(d_t_overflow, + (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0); + + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, id, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 
0 : icb, wd, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + + const int kd_padding = jcp.kd + - div_up(d_t_overflow, (jcp.dilate_d + 1)) + - div_up(d_b_overflow, (jcp.dilate_d + 1)); + par_conv.kd_padding = nstl::max(0, kd_padding); + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + } + icbb += icb_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +void jit_avx2_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + int icb_work = jcp.nb_ic / jcp.nb_ic_blocking; + int ih_block_size = jcp.ih; + int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks; + if (work_amount < (size_t)2 * mkldnn_get_max_threads()) { + ih_block_size = 1; + num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + work_amount *= num_ih_blocks; + } + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + size_t n{0}, g{0}, icbb{0}, ihb{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, + ihb, num_ih_blocks); + for (size_t iwork = start; iwork < end; ++iwork) { + for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking) + for (int id = 0; id < jcp.id; ++id) { + auto par_conv = jit_conv_call_s(); + + const int idp = jcp.id + 2 * jcp.f_pad; + const int d_t_overflow = nstl::max(0, + jcp.kd - 1 - id - jcp.f_pad); + const int back_pad = idp - jcp.id - jcp.f_pad; + const int d_b_overflow = nstl::max(0, + jcp.kd - 1 - (jcp.id - 1 - id) - back_pad); + const int od = id + jcp.f_pad - d_b_overflow; + + int ih_start = ihb * ih_block_size; + int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size); + for (int ih = ih_start; ih < ih_end; ++ih) { + + const int i_t_overflow = nstl::max(0, (jcp.kh - 1 + - ih - jcp.t_pad) / jcp.stride_h); + const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih + + ih - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ih) % jcp.stride_h); + int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h; + + par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow; + par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow - i_b_overflow; + 
par_conv.kw_padding = 0; + + const int k_lo = overflow_kh_lo + + i_b_overflow * jcp.stride_h; + const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h; + + par_conv.src = &diff_src[src_blk_off(diff_src_d, n, + /*jcp.ic == 3 ? 0 :*/ + g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)]; + par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, + n, g * jcp.nb_oc + oc, od, oh, 0)]; + par_conv.filt = &weights[wht_blk_off(weights_d, g, oc, + jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb, + d_b_overflow, k_lo, 0)]; + + par_conv.src_prf = nullptr; + par_conv.dst_prf = nullptr; + par_conv.filt_prf = nullptr; + par_conv.channel = oc; + par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc, + jcp.nb_oc_blocking); + + kernel_->jit_ker(&par_conv); + } + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, + num_ih_blocks); + } + }; + + parallel(0, ker); +} + +void jit_avx2_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + auto ker = [&](int ithr, int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_job_start = rw->balancer().ithr_job_off(ithr); + const int w_njobs = rw->balancer().ithr_njobs(ithr); + + if (w_njobs == 0) return; + + /* reduction dimension */ + int img_od_start{0}, img_od_end{0}, img{0}, od_s{0}; + balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), img_od_start, img_od_end); + + int img_start = img_od_start, img_end = img_od_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + /* jobs */ + int g_start{0}, ocb_start{0}, icb_start{0}; + nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc, icb_start, jcp.nb_ic); + + while (img_start < img_end) { + int g = g_start, ocb = ocb_start, icb = icb_start; + + const int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? 
jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int idp = jcp.id + jcp.f_pad + jcp.back_pad; + + if (id_s < idp - jcp.back_pad - jcp.kd + 1) + for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + /* TODO: put dw <-- 0 in kernel */ + if (img == img_first) + array_set(rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_, 0, + rw->balancer().job_size_); + + for (int od = od_s; od < od_e; ++od) { + const int id = od * jcp.stride_d; + if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break; + + auto par_conv = jit_conv_call_s(); + par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)]; + par_conv.dst = + &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)]; + par_conv.filt = rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_; + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb, + jcp.nb_ic); + } + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) + d_bias[o] = 0.; + + for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp new file mode 100644 index 000000000000..bb65bce79cde --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP +#define CPU_JIT_AVX2_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_, + *desc(), src_md(), weights_md(), dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() < 8; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_avx2_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_fwd_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx2_convolution_bwd_data_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_data_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf( + jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + cpu_reducer_t::conf_t reducer_wei_conf_; + + protected: + bool set_default_formats() { + using namespace format_tag; + const bool flat = IC() == 3; + + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + + private: + void init_balancers() { + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = 1<<21; /* just a heuristic */ + + if(with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + + reducer_wei_conf_.init(reduce_balancer_t(max_threads, + jcp_.kd * jcp_.kh * jcp_.kw + * jcp_.ic_block * jcp_.oc_block, + jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc, + jcp_.mb * jcp_.od, max_buffer_size)); + } + }; + + jit_avx2_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , reducer_weights_(nullptr) + , reducer_bias_(nullptr) + { + kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_); + reducer_bias_ = + new cpu_reducer_t(pd()->reducer_bia_conf_); + reducer_weights_ = + new cpu_reducer_t(pd()->reducer_wei_conf_); + } + + ~jit_avx2_convolution_bwd_weights_t() { + delete kernel_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_weights_kernel_f32 *kernel_; + cpu_reducer_t *reducer_weights_, *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp new file mode 100644 index 000000000000..635b83b2bf5f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp @@ -0,0 +1,1255 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" +#include "cpu_barrier.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_common_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt)); + + if (jcp.ver == ver_4fma) + { + Label bcast_loop; + Label bcast_loop_wraparound; + Label bcast_loop_out; + Label bcast_loop_ur_full; + + cmp(bcast_loop_iter, jcp.ur); + jle(bcast_loop_wraparound, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jg(bcast_loop, T_NEAR); + } + + L(bcast_loop_wraparound); + if (jcp.ur_tail) { + je(bcast_loop_ur_full, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + jmp(bcast_loop_out, T_NEAR); + } + L(bcast_loop_ur_full); + reduce_loop(load_loop_blk, jcp.ur, 0, true); + L(bcast_loop_out); + } + else + { + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } + } +} + +void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load, int i_fma) { + return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step) + + jcp.fma_step * i_load + i_fma); + 
}; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_out * jcp.oc_block * i_load); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + int offt; + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + offt = (i_reduce == jcp.reduce_loop_unroll) + ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll + : i_ur * jcp.reduce_loop_unroll + i_reduce; + } else { + if (jcp.transpose_src) { + const int reduce_group = i_reduce / 4; + const int reduce_shift = i_reduce % 4; + offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift; + } + else + offt = i_reduce * jcp.ic_block + i_ur; + } + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int offt; + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) + return EVEX_compress_addr(aux_reg_output_data, + (i_load * jcp.bcast_dim + i_ur) * jcp.load_block + * jcp.typesize_out); + else + return ptr[aux_reg_output_data + + (i_load + ? reg_output_stride * i_load + : 0) // TODO: Xbyak should allow 0 scale + + jcp.typesize_out * jcp.load_block * i_ur]; + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_sum) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + for (int i_ur = 0; i_ur < ur; ++i_ur) { + mic_prefetcht1(output_ptr(i_load, i_ur)); + } + } + } + + if (jcp.with_bias + && one_of(jcp.prop_kind, forward_training, forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero, T_NEAR); + + for (int i_load = 0; i_load < load_loop_blk; i_load++) + for (int i_ur = 0; i_ur < ur; ++i_ur) + vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load)); + jmp(init_done, T_NEAR); + } + + L(init_zero); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + L(init_done); + }; + + auto store = [=]() { + Label store_noadd; + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto r = vreg_accum(i_load, i_ur); + vaddps(r, r, output_ptr(i_load, i_ur)); + } + + L(store_noadd); + if (jcp.with_eltwise) { + Label store_noeltwise; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_noeltwise, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_noeltwise); + } + + auto store_output = [=](bool output_is_aligned) { + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + if (output_is_aligned && jcp.use_vmovntps) + vmovntps(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + else + vmovups(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + }; + + Label unaligned_store, end_store; + test(aux_reg_output_data, cpu_isa_traits::vlen - 1); + 
jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + }; + + auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load, + bool last_block, bool wraparound, int reduce_step) + { + bool pf_ker_l1 = true; + bool pf_ker_l2 = wraparound; + int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk; + int i_op = (i_reduce / reduce_step) * ur * load_loop_blk + + i_ur * load_loop_blk + i_load; + + int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0; + int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0; + int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur; + + int pf_inp_ops = n_ops / 2; // # of operations during which to pf input + int pf_inp_trigger; + if (jcp.prop_kind == backward_weights) + pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block); + else + pf_inp_trigger = nstl::max(1, pf_inp_ops / ur); + + int n_other_pf = + load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1); + int n_other_pf_ops = n_ops - pf_inp_ops; + int other_pf_trigger + = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0; + + if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) { + // input prefetches have the highest priority b/c the + // first iteration of the kernel block touches all the + // cache lines + int i_pf = i_op / pf_inp_trigger; + auto pf_reg = wraparound && last_block + ? reg_bcast_data + : (last_block ? aux1_reg_bcast_data + : aux_reg_bcast_data); + int offt = i_pf; + if (jcp.prop_kind == backward_weights) { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.is : jcp.reduce_block); + offt *= jcp.bcast_block; + } else { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.ur : jcp.bcast_dim); + offt *= jcp.reduce_block; + } + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_op >= pf_inp_ops && n_other_pf) { + // remaining prefetches are spread among the rest of the + // operations; prefetches for output take priority + // TODO: spread L2 prefetches among L1 prefetches + i_op -= pf_inp_ops; + if (i_op % other_pf_trigger == 0) { + int i_pf = i_op / (load_loop_blk * other_pf_trigger); + if (i_pf < n_pf_ker_l2) { + int offt = (i_pf + (i_load + 1) * jcp.reduce_dim) + * jcp.load_block; + mic_prefetcht1(ptr[aux_reg_load_data + + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) { + i_pf -= n_pf_ker_l2; + auto pf_reg = last_block ? reg_load_data + : aux_reg_load_data; + int offt = (i_pf + i_load * jcp.reduce_dim + + (last_block + ? (wraparound ? 
jcp.reduce_dim : 0) + : jcp.reduce_block)) + * jcp.load_block; + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) { + i_pf -= n_pf_ker_l1 + n_pf_ker_l2; + int offt = i_pf * jcp.load_block; + mic_prefetcht0(ptr[aux_reg_output_data + + offt * jcp.typesize_out]); + } + } + } + }; + + auto fma_block = [=](bool last_block) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + + int reduce_step = jcp.fma_step; + + for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + // if transposed input data used and if spatial size is + // not divided by transpose step (4) then for last reduce step + // we should load only needed load_registers data + // and clear remaining + if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block + && i_reduce == jcp.reduce_loop_unroll - reduce_step) { + Label load_all; + Label load_finish; + test(reg_reduce_pos_flag, FLAG_SP_LAST); + jz(load_all, T_NEAR); + + const int n_loads = jcp.is % jcp.fma_step; + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + if (i_fma < n_loads) + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + else + vpxord(vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma)); + } + jmp(load_finish); + + L(load_all); + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + L(load_finish); + } else { + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + } + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false)); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + if (jcp.ver == ver_4fma) + v4fmaddps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, false)); + else if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), vreg_bcast); + else + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, true)); + prefetch_callback(ur, i_reduce, i_ur, i_load, + last_block, wraparound, reduce_step); + } + } + } + }; + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx512_common_1x1_conv_kernel::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + sub(rsp, stack_space_needed); + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), 
reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (one_of(jcp.prop_kind, forward_training, forward_inference)) + mov(reg_relu_ns, reinterpret_cast(&jcp.eltwise.alpha)); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_weights: + for (int i_load = 0; i_load < load_loop_blk; i_load++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 }; + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 }; + + const int size_ur_cases_fma + = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + sizeof(ur_cases_fma_expl_bcast) : + sizeof(ur_cases_fma_embd_bcast); + const int size_ur_cases_4fma = sizeof(ur_cases_4fma); + + const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + ur_cases_fma_expl_bcast : + ur_cases_fma_embd_bcast; + const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma; + const int num_ur_cases = + (jcp.ver == ver_4fma ? 
size_ur_cases_4fma : size_ur_cases_fma) + / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_common)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 
1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, + format_kind::undef, cd.diff_bias_desc.format_kind) + != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c); + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + jcp.transpose_src = false; + + if (everyone_is(data_type::f32, src_d.data_type(), + weights_d.data_type(), dst_d.data_type())) + { + const int is_bwd_d = jcp.prop_kind == backward_data; + format_tag_t wei_tag = with_groups + ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i, + gOIhw16i16o, gIOhw16o16i) + : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, + OIhw16i16o, IOhw16o16i); + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) && + ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4 + == 0) { + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops) + && !reduce_src + /* Heuristic condition for relation of src size to oc. Otherwise + the src transposition overhead exceed the benefit from 4fma + */ + && ((jcp.is * jcp.ic) / jcp.oc <= 2048) + && mkldnn_thr_syncable() + ) + { + jcp.transpose_src = true; + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else { + jcp.ver = (mayiuse(avx512_core)) ? 
ver_avx512_core : ver_fma; + jcp.fma_step = 1; + } + jcp.typesize_in = sizeof(prec_traits::type); + jcp.typesize_out = sizeof(prec_traits::type); + } else { + return status::unimplemented; + } + + /* once all the formats are set, check the padding consistency */ + args_ok = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + const int SMALL_SPATIAL = 10; + const int BIG_SPATIAL = 28; + const int BIG_REDUCE_DIM = 1024; + const int BIG_LOAD_DIM = 256; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + int reduce_blocking_max{ 0 }; + + jcp.load_grp_count = 1; + + const int L1_capacity = get_cache_size(1, true) / sizeof(float); + const int L2_size = get_cache_size(2, true) / sizeof(float); + const int L2_capacity = (L2_size * 3) / 4; + + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + } else { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.ic_block; + + jcp.bcast_dim = jcp.os; + } + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + // adjusting registry blocking + int max_regs, min_regs, size_treshold, ur_step; + const int spatial + = (one_of(jcp.prop_kind, forward_training, forward_inference)) ? + jcp.oh : + jcp.ih; + if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) { + max_regs = 9; + min_regs = 6; + size_treshold = 14; + ur_step = 1; + jcp.expl_bcast = true; + + if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM + && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { + max_regs = 6; + min_regs = 5; + } + } else { + max_regs = jcp.ver == ver_4fma ? 28 : 30; + min_regs = 9; + size_treshold = jcp.ver == ver_4fma ? 28 : 14; + ur_step = jcp.ver == ver_4fma ? 
4 : 1; + jcp.expl_bcast = false; + jcp.use_vmovntps = true; + } + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i -= ur_step) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.bcast_block = jcp.ur; + + jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_iter_step = jcp.load_block; + + if (jcp.prop_kind == backward_data) + jcp.loop_order = loop_lbr; + else + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + int nb_load = div_up(jcp.load_dim, jcp.load_block); + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast) { + if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL + && spatial < BIG_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 80); + else if (spatial > SMALL_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 512); + else + reduce_blocking = nstl::min(jcp.reduce_dim, 256); + + if ((jcp.mb > 28 && spatial >= 28) + || (jcp.mb > 112 && spatial >= 17)) + jcp.use_vmovntps = true; + else + jcp.use_vmovntps = false; + } else { + + reduce_blocking = nb_reduce; + if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + else if (spatial > SMALL_SPATIAL + && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 8; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + // Check input data cache aliasing. + // For other ISA constants may be updated. + // 64 * 1024 is chosen due to 1MB L2 16-way cache. + // 7 is empirical value. It is about half of 16. + // So we leave about half of the set for other data - weights, dst + int way_size = (64 * 1024) / jcp.typesize_in; + int max_hits = 7; + if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) { + int nrb = reduce_blocking / simd_w; + int sp = jcp.bcast_dim; + int wl = way_size / simd_w; + for (int start_off = 0; start_off < jcp.ur; start_off++) { + for (int off = start_off, hits = 0; off < sp * nrb; off += wl) { + if (off % sp >= jcp.ur || ++hits < max_hits) + continue; + int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp); + reduce_blocking + = nstl::min(reduce_blocking, max_r_blocking); + break; + } + } + } + + if (reduce_blocking < jcp.reduce_dim) { + jcp.use_vmovntps = false; + if (jcp.prop_kind == backward_data) + jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; + else + jcp.loop_order = reduce_src ? 
loop_rbl : loop_rlb; + } + load_blocking = jcp.load_dim; + + int load_size = jcp.load_dim * jcp.reduce_dim; + int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; + + if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads + && nb_load * nb_bcast > nthreads) { + // Some heuristic here + float calc_koef = 0.01, best_cost = FLT_MAX; + int n_lgc = nthreads; + float ratio = (float)load_size / (float)bcast_size; + int best_lgc = ratio > 1 ? n_lgc : 1; + auto calc_job_cost = [&](int lb, int tg, float mem_k) { + int bb_size = jcp.mb * div_up(nb_bcast, tg); + float calc_size = (float)(bb_size * jcp.ur) + * (lb * jcp.load_block) * jcp.reduce_dim; + float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) + * jcp.reduce_dim; + return calc_koef * calc_size + mem_k * mem_size; + }; + for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { + lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; + int min_lb = nb_load / lgc; + int max_lb = div_up(nb_load, lgc); + int min_tg = nthreads / lgc; + int max_tg = div_up(nthreads, lgc); + // Some heuristic here + float mem_koef = (max_tg == 1) ? 1.f : 1.3f; + float job_cost = 0.; + if (nthreads % lgc < nb_load % lgc) { + job_cost = calc_job_cost(max_lb, min_tg, mem_koef); + } else { + auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); + auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); + job_cost = nstl::max(job_cost1, job_cost2); + } + + if (job_cost < best_cost) { + best_lgc = lgc; + best_cost = job_cost; + } + } + jcp.load_grp_count = best_lgc; + load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } else { + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + } + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64 + && load_size >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); + load_blocking = jcp.load_block; + } + + if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim + && jcp.oh * jcp.ow > 64 + && IMPLICATION(reduce_src, jcp.load_dim < 1024)) { + /* Looking for best loading dimension blocking + * to get the best thread and data read/write efficiency + * by finding the optimal 'load_chunk' value + * Example: + * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512 + * the 'best' load_chunk value should be 1 + * TODO: remove heuristic constants in above condition + * TODO: check this blocking for other ISA + */ + float best_eff = -1.f; + int best_lgc = 1; + + for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { + int lgc = div_up(nb_load, load_chunk); + if (lgc > nthreads) + continue; + int thr_per_grp = div_up(nthreads, lgc); + int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp) + * jcp.bcast_block; + int load_per_thr = load_chunk * simd_w; + float data_norm = (bcast_per_thr + load_per_thr) / 2.f; + float data_eff = (bcast_per_thr * load_per_thr) + / (data_norm * data_norm); + float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc) + / div_up(nthreads, lgc); + float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) + / rnd_up(jcp.mb * nb_bcast, thr_per_grp); + float thr_eff = thr_eff_over_grp * thr_eff_in_grp; + float load_eff = (float)nb_load / rnd_up(nb_load, lgc); + float overall_eff = data_eff + thr_eff + load_eff; + if 
(overall_eff > best_eff) { + best_eff = overall_eff; + best_lgc = lgc; + } + } + jcp.load_grp_count = best_lgc; + load_blocking + = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) + * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + } else if (jcp.prop_kind == backward_weights) { + + jcp.use_vmovntps = false; + if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma) + jcp.use_vmovntps = true; + + if (jcp.transpose_src) + jcp.reduce_dim = jcp.tr_is; + else + jcp.reduce_dim = jcp.is; + + if (jcp.ver == ver_4fma) { + // reduce_block should be divided by fma_step + jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4); + } else { + jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); + if (jcp.reduce_dim % jcp.reduce_block != 0) + jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); + if (jcp.reduce_block > 256) { + jcp.reduce_block = 1; + } + + } + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) { + // if reduce_block is big then generated JIT code may be big + // for small values of ur because reduce_loop_unroll = reduce_block + jcp.ur = jcp.bcast_block / 2; + jcp.expl_bcast = true; + } else { + jcp.ur = jcp.bcast_block; + jcp.expl_bcast = false; + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in; + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = + jcp.oc_block * jcp.ic_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = + jcp.oc_block * jcp.ur * jcp.typesize_out; + jcp.bcast_loop_bcast_step = + jcp.ic_block * jcp.reduce_dim * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in; + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + balance(jcp, nthreads); + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + load_blocking = best_divider(load_blocking, 16, load_blocking, false); + load_blocking *= jcp.load_block; + + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + int min_bcast_blocking = 5; + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + bcast_blocking = best_divider( + bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + // for reduction balance + if (jcp.ver == ver_avx512_core) { + int max_reduce_blocking + = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim); + int min_reduce_blocking = 
nstl::min( + L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); + reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking, + max_reduce_blocking, true); + reduce_blocking + = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block), + jcp.reduce_block); + } else { + int max_reduce_blocking = L2_capacity + / ((bcast_blocking + load_blocking) * jcp.reduce_block); + max_reduce_blocking = nstl::min(max_reduce_blocking, + (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block); + + int num_jobs = div_up(jcp.load_dim, load_blocking) + * div_up(jcp.bcast_dim, bcast_blocking); + int threads_per_job = nstl::max(1, nthreads / num_jobs); + reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block); + reduce_blocking = div_up(reduce_blocking, threads_per_job); + + reduce_blocking = best_divider(reduce_blocking, + max_reduce_blocking - 2, max_reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + if (jcp.ver == ver_4fma) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + } + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx512_common_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.with_bias + && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); + + if (jcp.prop_kind == backward_weights) { + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic; + scratchpad.book(key_conv_wei_reduction, + jcp.typesize_out * wei_size * (jcp.nthr_mb - 1)); + } + + if (jcp.transpose_src) { + const size_t tr_src_size = + (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is; + scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size); + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * jcp.nthr); + } +} + +void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp, + int nthreads) +{ + // initialize jcp reduction threading properties + jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; + if (nthreads < jcp.ngroups) { + /* simplification... 
fortunately it doesn't hurt much */ + return; + } + const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + const int nb_load = div_up(jcp.load_dim, jcp.load_block); + const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + jcp.nthr_g = jcp.ngroups; + const int nthr = nthreads / jcp.nthr_g; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level + * optimizer tries to minimize memory consumption. few notes: (n1) + * unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + int bcast_koeff = 1; + int load_koeff = 1; + int output_koeff = 12; + if (jcp.transpose_src) { + bcast_koeff = 5; + load_koeff = 1; + output_koeff = 8; + } + return 0 + + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block + / jcp.stride_h / jcp.stride_w /* (n1) */ + + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block + + (size_t)output_koeff /* (n2) */ + * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block + * jcp.oc_block; + }; + + int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; + auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); + for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); + for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); + auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + jcp.nthr_mb = nthr_mb; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) + jcp.nthr_mb = nstl::min(jcp.mb, nthreads); + + jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; + assert(jcp.nthr <= nthreads); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp new file mode 100644 index 000000000000..d2ae01794375 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp @@ -0,0 +1,108 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP +#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_1x1_conv_kernel : public jit_generator { + jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_common_1x1_conv_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel) + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + + reg64_t reg_bcast_data = r8; + reg64_t reg_load_data = r10; + reg64_t reg_output_data = r9; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = aux_reg_load_data; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reg_load_loop_work = rsi; + reg64_t reg_reduce_loop_work = r11; + reg64_t bcast_loop_iter = rdx; + reg64_t reduce_loop_iter = abi_param1; + reg64_t reg_reduce_pos_flag = rax; + reg64_t reg_output_stride = r13; + reg64_t reg_bias_data = r12; + reg64_t reg_relu_ns = r13; + reg64_t reg_bcast_loop_work = aux1_reg_bcast_data; + + Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + int bcast_loop_work_offt = 0; + int stack_space_needed = 16; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + static void balance(jit_1x1_conv_conf_t &jcp, int nthreads); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp new file mode 100644 index 000000000000..54d58c8a3917 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp @@ -0,0 +1,816 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_common_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + + +namespace { +template +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const int grp_count = nstl::min(nx_divider, nthr); + const int grp_size_big = nthr / grp_count + 1; + const int grp_size_small = nthr / grp_count; + const int n_grp_big = nthr % grp_count; + const int threads_in_big_groups = n_grp_big * grp_size_big; + + const int ithr_bound_distance = ithr - threads_in_big_groups; + T grp, grp_ithr, grp_nthr; + if (ithr_bound_distance < 0) { // ithr in first groups + grp = ithr / grp_size_big; + grp_ithr = ithr % grp_size_big; + grp_nthr = grp_size_big; + } else { // ithr in last groups + grp = n_grp_big + ithr_bound_distance / grp_size_small; + grp_ithr = ithr_bound_distance % grp_size_small; + grp_nthr = grp_size_small; + } + + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} +/* convolution forward */ + +template +void jit_avx512_common_1x1_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const auto &jcp = kernel_->jcp; + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template +void jit_avx512_common_1x1_convolution_fwd_t:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + + const int ndims = src_d.ndims(); + const int stride_h = (ndims == 3) ? 
1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + }; + + auto init_reduce = [&](int icb) + { + const int nb_ic_blocking_step = + nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, + int ih, int iw) + { + + const int _ocb = g * nb_oc + ocb; + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + p.bias_data = &bias[_ocb * jcp.oc_block]; + p.load_data = &weights[pd()->with_groups() + ? 
weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + + +template struct jit_avx512_common_1x1_convolution_fwd_t; +/* convolution backward wtr data */ + +template +void jit_avx512_common_1x1_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).template get( + key_conv_rtus_space); + + const int ndims = diff_src_d.ndims(); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 
0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + parallel(0, [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, icb_start, icb_end, jcp.load_grp_count); + + bool reduce_outer = (jcp.loop_order == loop_rbl + || jcp.loop_order == loop_rlb); + int nboc_outer = reduce_outer ? nb_oc : 1; + int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1; + + int nboc_inner = reduce_outer ? 1 : nb_oc; + int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking; + + for (int ocb_outer = 0; ocb_outer < nboc_outer; + ocb_outer += ocb_outer_step) { + size_t cur_ocb_outer = + nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer; + + int load_step = 0; + for (int icb = icb_start; icb < icb_end; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, + icb_end * jcp.ic_block, load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = bcast_start; iwork < bcast_end; + iwork += bcast_step) + { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb_inner = 0; ocb_inner < nboc_inner; + ocb_inner += ocb_inner_step) { + int cur_ocb_inner = + nstl::min(ocb_inner + ocb_inner_step, nboc_inner) - + ocb_inner; + + int ocb = reduce_outer ? ocb_outer : ocb_inner; + int nb_oc_blocking_step = reduce_outer + ? cur_ocb_outer : cur_ocb_inner; + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, + jcp.oc, nb_oc_blocking_step * jcp.oc_block); + + kernel_->jit_ker(&p); + } + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + } + }); +} + +template struct jit_avx512_common_1x1_convolution_bwd_data_t; + +/* convolution backward wtr weights */ + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? 
(d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +jit_avx512_common_1x1_convolution_bwd_weights_t :: + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) + , trans_kernel_(nullptr), rtus_driver_(nullptr) +{ + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + acc_ker_ = new cpu_accumulator_1d_t(); + reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); + init_rtus_driver(this); + + const auto &jcp = kernel_->jcp; + + if (jcp.transpose_src) { + auto tp = jit_transpose4x16_src_t(); + tp.src_pf0_distance = 4; + tp.tr_src_pf0_distance = 0; + tp.src_pf1 = true; + tp.tr_src_pf1 = false; + trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp); + } +} + +void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + const auto scratchpad = this->scratchpad(ctx); + + auto rtus_space = scratchpad.get(key_conv_rtus_space); + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + auto wei_reduction = scratchpad.get(key_conv_wei_reduction); + + /* prepare src transposition barriers */ + auto tr_src = scratchpad.get(key_conv_tr_src); + auto tr_src_bctx = scratchpad.get( + key_conv_tr_src_bctx); + if (jcp.transpose_src) { + for (int i = 0; i < jcp.nthr; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + + const int ndims = src_d.ndims(); + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic; + + simple_barrier::ctx_t reduction_barrier; + simple_barrier::ctx_init(&reduction_barrier); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + + const int sp_nb = jcp.nb_reduce; + const int mb_sp_work = jcp.mb * sp_nb; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? 
remaining : default_step; + }; + + // TODO: use memory descriptor with the same fmt as src + // (or use a macro :)) + auto tr_src_off = [&](int img, int icb, int is) { + const size_t tr_chn_size = jcp.tr_is * jcp.ic_block; + const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups; + return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block; + }; + + auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size, + int g_start, int g_work, int ic_b_start, int ic_b_work, + int ithr, int nthr, int first_ic_b) + { + const int work_amount = g_work * ic_b_work; + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int g{ 0 }, ic_b{ 0 }; + nd_iterator_init(start, g, g_work, ic_b, ic_b_work); + g += g_start; + const int ic_b_tr = g * nb_ic + first_ic_b + ic_b; + ic_b += ic_b_start; + + const int _ic = g * nb_ic + ic_b; + + const int is = sp_b_start * jcp.reduce_block; + const int ih = is / jcp.iw; + const int iw = is % jcp.iw; + + const int src1_off = data_blk_off(src_d, img, _ic, ih, iw); + data_t *src1 = (data_t *)&src[src1_off]; + data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.is * jcp.ic_block; + const int tr_src_stride = jcp.tr_is * jcp.ic_block; + + const int my_work = end - start; + for (int iwork = 0; iwork < my_work; iwork++) { + auto par_trans = jit_src_transpose_s(); + assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4); + par_trans.size = sp_size; + par_trans.src = src1; + par_trans.tr_src = tr_src1; + par_trans.src_prf = src1 + 64 * 16; + par_trans.tr_src_prf = tr_src1 + 80 * 16; + trans_kernel_->jit_ker(&par_trans); + + src1 += src_stride; + tr_src1 += tr_src_stride; + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1)); + + const int ithr_ic_b = ithr % jcp.nthr_ic_b; + const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; + const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; + const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / + jcp.nthr_g; + + const int ithr_but_oc + = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b; + + /* reduction dimension */ + int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 }; + if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) { + // it's preferable to parallelize by mb if possible + int img_start{ 0 }, img_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end); + mb_sp_b_start = img_start * sp_nb; + mb_sp_b_end = img_end * sp_nb; + } + else { + balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, + mb_sp_b_end); + } + + /* independent dimensions */ + int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 }; + int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 }; + + balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); + balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, + oc_b_end); + balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, + ic_b_end); + + const int g_work = g_end - g_start; + const int oc_b_work = oc_b_end - oc_b_start; + const int ic_b_work = ic_b_end - ic_b_start; + + data_t *diff_wei = ithr_mb == 0 + ? 
diff_weights : wei_reduction + (ithr_mb - 1) * wei_size; + + int sp_b_step = 0; + for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; + mb_sp_b += sp_b_step) { + int img{ 0 }, sp_b{ 0 }; + nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); + sp_b_step = step(jcp.nb_reduce_blocking, + nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), + jcp.nb_reduce_blocking_max); + + for (int g = g_start; g < g_end; ++g) { + int load_step = 0; + int bcast_step = 0; + for (int ic_b = ic_b_start; ic_b < ic_b_end; + ic_b += bcast_step) { + bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, + jcp.nb_bcast_blocking_max); + if (jcp.transpose_src) { + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + const int sp_size + = nstl::min(sp_b_step * jcp.reduce_block, + jcp.is - sp_b * jcp.reduce_block); + uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b, + bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start); + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + } + + for (int oc_b = oc_b_start; oc_b < oc_b_end; + oc_b += load_step) { + load_step = step(nb_oc_blocking, oc_b_end - oc_b, + jcp.nb_load_blocking_max); + const int _ic_b = g * nb_ic + ic_b; + const int _ic_b_tr = g * nb_ic + ic_b_start; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + store_to = diff_wei + off; + + const data_t *diff_src = jcp.transpose_src ? + &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] : + &src[src_d.blk_off(img, _ic_b)]; + + int sp_b_end = sp_b + sp_b_step; + const data_t *pdiff_dst + = &diff_dst[diff_dst_d.blk_off(img, _oc_b)]; + const data_t *local_src = diff_src; + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + p.output_stride + = jcp.ic * jcp.oc_block * jcp.typesize_out; + + p.load_dim = load_step * jcp.oc_block; + + p.bcast_dim = bcast_step * jcp.ic_block; + rp.icb = bcast_step; + p.output_data = store_to; + + p.reduce_dim = sp_b_step * jcp.reduce_block; + rp.os = p.reduce_dim; + + p.first_last_flag = 0 + | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0) + | (sp_b_end == sp_nb ? 
FLAG_SP_LAST : 0); + + int sp = sp_b * jcp.reduce_block; + p.load_data = pdiff_dst + sp * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + sp * jcp.ic_block; + + if (ndims == 3) + rp.src = local_src + iw + * src_d.blocking_desc().strides[2]; + else + rp.src = local_src + ih + * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = local_src + sp * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + } + + /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ + if (jcp.nthr_mb > 1) { + simple_barrier::barrier(&reduction_barrier, jcp.nthr); + const int work = g_work * oc_b_work * ic_b_work; + int start{ 0 }, end{ 0 }; + balance211(work, jcp.nthr_mb, ithr_mb, start, end); + if (start == end) + return; + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + int w = start; + int sub_g_start{ 0 }, sub_oc_b_start{ 0 }, + sub_ic_b_start{ 0 }; + nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, + oc_b_work, sub_ic_b_start, ic_b_work); + while (w < end) { + const int g = g_start + sub_g_start; + const int oc_b = oc_b_start + sub_oc_b_start; + const int ic_b = ic_b_start + sub_ic_b_start; + + const int acc_size + = nstl::min(end - w, ic_b_work - sub_ic_b_start) + * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + data_t *d = diff_weights + off; + data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, g_work, + sub_oc_b_start, oc_b_work, sub_ic_b_start, + ic_b_work); + } + } + } + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) + return; + + /* reduction dimension */ + int img_start{ 0 }, img_end{ 0 }; + + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{ 0 }, ocb_start{ 0 }; + nd_iterator_init( + b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_load + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + ker(ithr, jcp.nthr); + if (pd()->with_bias()) + ker_bias(ithr, jcp.nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding); + } +} + +} +} +} diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp new file mode 100644 index 000000000000..2e9fda76d628 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx512_common_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_fwd_t); + + status_t init() { + using namespace utils; + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(), + mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = + new 
jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_common_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_fwd_f32_t + = jit_avx512_common_1x1_convolution_fwd_t; + +template +struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_common_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return 
status::success; + } + + private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_bwd_data_f32_t + = jit_avx512_common_1x1_convolution_bwd_data_t; + +struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t +{ + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, + jcp_.mb, max_buffer_size)); + } + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx512_common_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + delete reducer_bias_; + delete rtus_driver_; + delete trans_kernel_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + + private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; + jit_transpose4x16_src *trans_kernel_; + rtus_driver_t 
*rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp new file mode 100644 index 000000000000..235fb02fefde --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp @@ -0,0 +1,4539 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" + +#include "jit_avx512_common_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) +#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + +constexpr auto small_spatial = 14; +unsigned int L1_cache_size = get_cache_size(1, true); + +inline void pick_loop_order(jit_conv_conf_t &jcp) { + using namespace prop_kind; + assert(one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)); + auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; + auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; + + // ow-threading is currently implemented for forward only + // TODO: single code for fwd and bwd after ow-thr for bwd + // meaningless switch was removed + if (jcp.prop_kind == backward_data) { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cgn : loop_gnc; + } else { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? 
loop_cwgn : loop_gncw; + } +} + +inline bool is_1stconv(const jit_conv_conf_t &jcp) { + if (mayiuse(avx512_core)) + return (jcp.ic < 16 && jcp.ngroups == 1); + else + return one_of(jcp.ic, 1, 3); +} + +inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { + return (jcp.nb_ow > 1); +} + +inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) { + return (jcp.ver == ver_4fma && is_ow_threading_on(jcp)); +} + +} + +template +void _jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + if (!is_owb_prefetching(jcp)) { + size_t aux_output_offset = get_output_offset(j, k); + mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) +{ + Label no_update_label, store_label, eltwise_label; + + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + if (jcp.with_bias) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + } + + if (!jcp.with_sum) { + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + } + + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = get_output_offset(j, k); + vaddps(vmm, + make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); + } + + if (!jcp.with_sum) { + jmp(eltwise_label, T_NEAR); + } else { + cmp(reg_channel, 0); + jne(eltwise_label, T_NEAR); + } + + L(no_update_label); + if (jcp.with_bias) { + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + int bias_offset = jcp.typesize_out * k * jcp.oc_block; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset)); + } + mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); + } + } + + L(eltwise_label); + if (jcp.with_eltwise) { + cmp(reg_channel, jcp.nb_ic - 1); + jl(store_label, T_NEAR); + + if (ur_w == jcp.ur_w) { + eltwise_injector_->compute_vector_range(0, + jcp.nb_oc_blocking * jcp.ur_w); + } else { + for (int k = 0; k < jcp.nb_oc_blocking; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); + } + } + + L(store_label); + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = (size_t)typesize * + ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; + vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, + reg_out_long_offt), vmm); + if (!is_owb_prefetching(jcp)) + mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ + assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0); + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + Label kh_label, kd_label; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in + * ((size_t)(kw + ur_w * stride_w - pad_l) + + (size_t)ic_block * iw * ih * jcp.id); + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > 
INT_MAX) push(reg_inp_prf); + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + + L(kd_label); + } + mov(reg_kj, reg_kh); + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + L(kh_label); + for (int ki = 0; ki < kw; ki += 4) { + for (int ic = 0; ic < ic_block; ic++) { + for (int i = 0; i < 4; i++) { + int aux_ker_offset + = jcp.typesize_in + * ((ki + i) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block); + if (ki + i < kw) + vmovups(vmm_ker(i), + EVEX_compress_addr(aux_reg_ker, aux_ker_offset)); + else + vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i)); + } + + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + + for (int j = j_start, prf_count=0; j < j_end; j++) { + size_t aux_input_offset = (size_t)jcp.typesize_in + * ((size_t)(ki + j * stride_w + - pad_l) + (size_t)ic * iw * ih * jcp.id); + v4fmaddps(vmm_out(j, 0), vmm_ker(0), + EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, + reg_long_offt)); + if (ki + prf_count < kw && prf_count < 4 + && ((ki < 2 && j % 4) || j % 2)) { + int aux_ker_offset = jcp.typesize_in + * ((ki + prf_count) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_ker_offset)); + prf_count++; + } + if (ki == 0 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, + aux_input_offset, reg_long_offt)); + } + if (ki == 1 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); + } + } + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block); + add(aux_reg_inp, jcp.typesize_in * iw); + add(aux_reg_inp_prf, jcp.typesize_in * iw); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block); + add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block; + int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block; + + bool check_last_kh = (jcp.kh > 3); + bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28); + + int oi_ipref_t0 = get_ow_start(0, pad_l); + int ow_end_ipref = get_ow_end(ur_w, 0, pad_r); + + assert(jcp.oc % jcp.nb_oc_blocking == 0); + + auto kernel_offset = [=](int ocb, int ic, int ki) { + int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int ic_offset = ic * jcp.oc_block; + return typesize * (blk_offset + ic_offset); + }; + auto kernel_loads = [=](int ki, int ic, int kk) { + for 
(int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, ic + ii, ki); + vmovups(vmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) { + int aux_inp_offset + = typesize + * ((oi_ipref_t0 * stride_w - pad_l) * ic_block + + (jcp.dilate_h + 1) * jcp.iw * ic_block); + prefetcht0(EVEX_compress_addr(aux_reg_inp, + aux_inp_offset)); + oi_ipref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_ker_prf, reg_ker_prf); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + int kw = jcp.kw; + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1 + && ki == kw - 1 && (ic + 4) == ic_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + + if (oi % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, + prf_count_t0 + ic + 4 + - ic_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, ic + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, ic + + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_inp_next_kh(ki, 2, prf_count_t0, + prf_count_t1); + } + + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + if (oi % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(0, + prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) 
+ for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), + prf_count_t1 = 0, prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + + if (!is_owb_prefetching(jcp)) { + if ((oi % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } else { + if (!(ki == 0 && ic == 0) + && !(ki == kw-1 && ic == 0) && + (oi % 2) && (prf_count_t1 < 4) + ) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker, kernel_offset(kk, + ic + 4 + prf_count_t0, ki))); + prf_count_t0++; + } + } + if (!is_owb_prefetching(jcp)) { + if (pref_current_inp) { + if (ki == 0 && ic == 0 && kk == 0) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + shift_input_ptr)); + } else { + if (ki == 1 && ic == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp_prf, aux_input_offset)); + } + } else { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift + = jcp.typesize_in * ur_w * stride_w * inp_mult; + bool kk_pref_slot = kk ? oi % 2 : !(oi % 2); + if (ki == 0 && ic == 0 && kk_pref_slot) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + + if (ki == kw - 1 && ic == 0 && kk_pref_slot) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + } + } + } + } + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + add(aux_reg_ker_prf, shift_kernel_ptr); + add(aux_reg_inp_prf, shift_input_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, + int pad_l, int pad_r) +{ + bool prf_ker = true; + bool prf_inp = true; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + int id = jcp.id; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = ic_block * nb_oc_block * kw; + int num_ker_prfs = prf_ker ? num_ker_loads : 0; + int num_inp_prfs = prf_inp ? + ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) : + 0; + if (jcp.is_1stconv && prf_inp) { + num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block; + } + int num_prfs = num_ker_prfs + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w; + int prf_inst_spacing + = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1; + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + int inp_mul = !jcp.is_1stconv ? 
ic_block : 1; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id; + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int ic = 0; ic < ic_block; ic++) { + int aux_kernel_offset = 0; + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + aux_kernel_offset = get_kernel_offset(ki, ic, 0, i); + vmovups(vmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + aux_kernel_offset + = get_kernel_offset(ki, ic, 0, load_offset); + vmovups(vmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth); + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + for (int j = j_start; j < j_end; j++) { + size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l); + auto addr = EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true); + vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr); + int fma_idx = step * ur_w + j; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (prf_ker && !ker_prf_inserted + && ker_prfs < num_ker_prfs) { + int ker_prf_offset + = jcp.typesize_in * ker_prfs * jcp.oc_block; + mic_prefetcht2(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else if (prf_inp) { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + size_t inp_prf_stride = nstl::max(kw, stride_w); + size_t inp_prf_offset; + if (!jcp.is_1stconv) { + inp_prf_offset + = ic_block * jcp.typesize_in + * ((inp_prf_idx / kw) + * inp_prf_stride + + (inp_prf_idx % kw)); + } else { + size_t ic_prf_stride = + (size_t)jcp.typesize_in * iw * ih * id; + size_t iw_prf_stride + = jcp.typesize_in * jcp.simd_w; + inp_prf_offset = ((inp_prf_idx / ic_block) + * iw_prf_stride + + (inp_prf_idx % ic_block) + * ic_prf_stride); + } + mic_prefetcht0(EVEX_compress_addr_safe( + aux_reg_inp_prf, inp_prf_offset, + reg_long_offt)); + } + } + } + } + step++; + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block); + if (prf_ker) + add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block); + add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + if (prf_inp) + add(aux_reg_inp_prf, + jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + + if 
(jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, + int pad_l, int pad_r) +{ + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block + * jcp.ic_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * inp_mul; + + + auto input_offset = [=](int oi, int ic, int ki) { + return (size_t)jcp.typesize_in + * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * inp_mul + (size_t)ic + * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + for (int ic = 0; ic < ic_block; ic++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t aux_input_offset = input_offset(jj, ic, ki); + vbroadcastss(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt)); + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = jcp.typesize_in + * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block + * oc_block + ki * ic_block * oc_block + ic * oc_block); + if (jj_end - jj_start > 0) + vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj++) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(vmm_out(jj, ii), + vmm_inp(jj, nb_oc_block), vmm_wei); + else { + size_t aux_input_offset = input_offset(jj, ic, ki); + vfmadd231ps(vmm_out(jj, ii), vmm_wei, + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true)); + } + } + } + } + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w, + int pad_l, int pad_r) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < 
nstl::max(jcp.f_pad, jcp.back_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + } + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + + if (jcp.ver == ver_4fma) + if(jcp.is_1stconv) + compute_loop_4fma_1st(ur_w, pad_l, pad_r); + else + compute_loop_4fma(ur_w, pad_l, pad_r); + else if (jcp.ver == ver_fma) + if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast) + || mayiuse(avx512_mic)) + compute_loop_fma(ur_w, pad_l, pad_r); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1) + compute_loop_fma(ur_w, pad_l, pad_r); + else + compute_loop_fma_core(ur_w, pad_l, pad_r); + else + assert(!"unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +template +void _jit_avx512_common_conv_fwd_kernel::generate() +{ + int iw = jcp.iw; + int ow = jcp.ow; + int ow_block = jcp.ow_block; + int nb_ow = jcp.nb_ow; + int kw = jcp.kw; + int l_pad = jcp.l_pad; + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; + int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult; + int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; + + preamble(); + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + int r_pad = nstl::max( + 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); + int n_oi = ow / ur_w; + int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + + if (!is_ow_threading_on(jcp)) { + // ow is being processed as a whole - with left and right paddings + if (r_pad1 > 0) n_oi--; + + if (ow == ur_w) { + mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]); + mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]); + compute_loop(ur_w, l_pad, r_pad); + } else { + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + if (n_oi == 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, r_pad1); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } else { + xor_(reg_oi, reg_oi); + if (l_pad > 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + inc(reg_oi); + } + if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + 
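For orientation, the FMA variants dispatched in compute_loop above all accumulate the same kind of register tile; expl_bcast pre-broadcasts each input scalar into a register and reuses it across the nb_oc_blocking accumulators, while embd_bcast folds the broadcast into the FMA memory operand. Below is a minimal scalar sketch of one tile of compute_loop_fma_core, assuming the non-1st-conv blocked layout; out, wei and inp are hypothetical flat arrays, not the engine's real data structures.

/* Hypothetical flat layouts: inp[iw][ic_block], out[nb_oc_blocking][ur_w][oc_block],
 * wei[nb_oc_blocking][kw][ic_block][oc_block]. Padding is handled by jj_start /
 * jj_end, mirroring get_ow_start / get_ow_end in the jit code above. */
void fwd_tile_ref(float *out, const float *wei, const float *inp,
        int kw, int ic_block, int oc_block, int nb_oc_blocking, int ur_w,
        int jj_start, int jj_end, int stride_w, int dilate_w, int pad_l)
{
    for (int ki = 0; ki < kw; ki++)
    for (int ic = 0; ic < ic_block; ic++)
    for (int ii = 0; ii < nb_oc_blocking; ii++)       /* one accumulator row per oc block */
    for (int jj = jj_start; jj < jj_end; jj++) {      /* output pixels, at most ur_w */
        int iw_idx = jj * stride_w + ki * (dilate_w + 1) - pad_l;
        for (int oc = 0; oc < oc_block; oc++)         /* the 16 lanes of one zmm */
            out[(ii * ur_w + jj) * oc_block + oc]
                    += wei[((ii * kw + ki) * ic_block + ic) * oc_block + oc]
                    * inp[iw_idx * ic_block + ic];
    }
}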
compute_loop(ur_w_tail, 0, r_pad); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label; + Label oi_loop_label, oi_loop_start_label, oi_loop_end_label; + + assert(ow_block % ur_w == 0); + int n_oi_not_last_ow_block = ow_block / ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + + int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w; + + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + + mov(reg_oi, n_oi_first_ow_block); + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + + if (l_pad > 0) { + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + + L(middle_ow_blocks_label); + + if (l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + add(reg_inp_prf, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + L(oi_loop_start_label); + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + jmp(oi_loop_start_label, T_NEAR); + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + // that is last block + if (!last_ow_block_padded) { + jmp(tail_label, T_NEAR); + } + + // last oi block with right padding + L(last_oi_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? 
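The ow-block bookkeeping above reduces to a handful of integer formulas. The toy program below recomputes them for made-up shapes (iw = ow = 35, kw = 3, stride 1, no dilation, l_pad = 1, ur_w = 7, ow_block = 14; purely illustrative) to show how the residual right padding r_pad1 ends up owned by the last ow-block, whose final ur_w chunk is then peeled off into the last_oi_label path.

#include <stdio.h>

int main(void)
{
    int iw = 35, ow = 35, kw = 3, stride_w = 1, dilate_w = 1; /* jcp.dilate_w + 1 */
    int l_pad = 1, ur_w = 7, ow_block = 14;
    int nb_ow = (ow + ow_block - 1) / ow_block;                 /* 3 */

    int n_oi = ow / ur_w;
    int r_pad = (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1);
    if (r_pad < 0) r_pad = 0;                                   /* 1 */
    int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1);                                 /* 1 */

    int n_oi_not_last = ow_block / ur_w;                        /* 2 */
    int n_oi_last = (ow - ow_block * (nb_ow - 1)) / ur_w;       /* 1 */

    int last_padded = r_pad1 > 0 && n_oi_last > 0;              /* last block owns r_pad1 */
    if (last_padded) n_oi_last--;                               /* its final chunk is peeled */

    printf("r_pad=%d r_pad1=%d nb_ow=%d n_oi_not_last=%d n_oi_last=%d\n",
            r_pad, r_pad1, nb_ow, n_oi_not_last, n_oi_last);
    /* prints: r_pad=1 r_pad1=1 nb_ow=3 n_oi_not_last=2 n_oi_last=0 */
    return 0;
}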
+ jl(end_label, T_NEAR); + + L(tail_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, int nthreads) +{ + using namespace prop_kind; + + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const int regs = 28; + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.is_1stconv = is_1stconv(jcp); + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + const int full_simd_w = cpu_isa_traits::vlen / sizeof(float); + jcp.simd_w = full_simd_w; + bool ok_to_try_xmm = true + && mayiuse(avx512_core) + && src_d.data_type() == data_type::f32 + && !jcp.is_1stconv + && !ok_to_pad_channels + && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0) + && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0); + if (ok_to_try_xmm) + jcp.simd_w = 4; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? 
jcp.ic : jcp.simd_w; + jcp.aligned_threads = 0; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0; + if (!args_ok) + return status::unimplemented; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto src_tag = jcp.is_1stconv + ? pick(ndims - 3, ncw, nchw, ncdhw) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)); + auto dst_tag = (jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o) + : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o)); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, x)); + } + + if (mayiuse(avx512_common) && + src_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && dst_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + + if (jcp.is_1stconv) { + // TODO: fix & remove constraints below + bool not_for_4fma + = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad), + nstl::max(jcp.kw, jcp.kh) < 7); + bool is_dilated + = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); + if (one_of(true, not_for_4fma, is_dilated)) + jcp.ver = ver_fma; + if (jcp.ver == ver_4fma) { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o) + : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o)); + } else { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o) + : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)) + : ((jcp.simd_w == 4) + ? 
pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o)); + } + } + } else { + return status::unimplemented; + } + + if (weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.is_1stconv) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + // avx512_core guard - just to avoid possible regression for other archs + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + for (int ur_w = regs; ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } + } + // TODO (Tanya): currently applied to Segnet convolutions only. + // Need to try for other topologies + if (jcp.ow > 150 && jcp.ur_w < regs/2) + jcp.ur_w = regs; + + int n_oi = (jcp.ow / jcp.ur_w); + int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + if (jcp.l_pad > 0 && r_pad > 0) + n_oi--; + + bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0 + && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw; + int mult = 1; + if (jcp.l_pad > 0) mult += 1; + if (r_pad > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) { + jcp.ur_w = ur_w; + break; + } + } + } + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off + = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ? 
+ jcp.ic : + 1; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + auto is_ow_threading_applicable = [=]() { + return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4) + && IMPLICATION(mayiuse(avx512_mic), + jcp.ver == ver_4fma + && IMPLICATION(jcp.mb != 1, + jcp.ih == 1 && jcp.kh == 1))); + }; + + if (jcp.ver == ver_4fma && !jcp.is_1stconv) { + if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8 + && jcp.oh <= 8 && jcp.ow == jcp.oh) + || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) { + if (jcp.nb_oc % 2 == 0) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } else { + for (int i = jcp.nb_oc; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { + jcp.nb_oc_blocking = i; + break; + } + } + if (jcp.ver == ver_4fma && is_ow_threading_applicable()) { + if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow + && jcp.ow != 2 * jcp.ur_w) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } + } + + jcp.ow_block = jcp.ow; + + auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) { + int nb_ow = div_up(jcp.ow, ow_block); + int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking); + int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow; + float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block); + float thr_eff = disbalance * (float)work_amount + / rnd_up(work_amount, nthreads); + return thr_eff; + }; + + auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) { + int res_ow_block = jcp.ow; + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + if (!is_ow_threading_applicable()) + return res_ow_block; + + int L2_part = (get_cache_size(2) * 7 / 8) / typesize; + if (jcp.ver == ver_4fma) + L2_part /= 2; + int size_src_chunk = jcp.ic_block * ur_w * jcp.kh; + int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w; + int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block + * jcp.kw * jcp.kh; + int nurw_cache = (L2_part - 2 * size_wei_chunk) + / (2 * size_dst_chunk + 2 * size_src_chunk); + // current design of generate() requires ow_block >= 2 * ur_w + int ow_block_cache = ur_w * nstl::max(2, nurw_cache); + + int ow_block_thr = ow_block_cache; + eff = get_thr_eff(nb_oc_blocking, ow_block_thr); + + int max_nb_ow = div_up(jcp.ow, 2 * ur_w); + int start_nb_ow = div_up(jcp.ow, ow_block_thr); + for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); + float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f; + if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + float thr_eff = get_thr_eff(nb_oc_blocking, ow_block); + float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f; + if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) { + ow_block_thr = ow_block; + eff = thr_eff; + } + eff_threshold = (jcp.ver == ver_4fma) ? 
0.9f : 0.98f; + if (eff > eff_threshold) + break; + } + res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr)); + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + return res_ow_block; + }; + + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_oc_blocking = 2; + unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w) + * jcp.ic_block * jcp.kh * jcp.kd; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block + * try_nb_oc_blocking; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_oc_blocking * jcp.kd; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + + bool embd_bcast_condition = true + && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size) + && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192) + && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512); + + if (jcp.mb == 1) { + unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h) + * div_up(jcp.iw, jcp.stride_w) * jcp.ic; + unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; + + // Estimate whether we need to limit the number of threads + // and calculate this number. Includes some heuristic. + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh; + int job_size_min = work_amount / nthreads; + int job_size_max = div_up(work_amount, nthreads); + int ch_max = rnd_up(jcp.oh, job_size_max); + int ch_min = (job_size_min == 0) + ? jcp.oh + : rnd_up(jcp.oh, job_size_min); + bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2 + && (jcp.oh != 8 || ch_max / jcp.oh > 1); + bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2 + && (jcp.oh != 8 || ch_min / jcp.oh > 1); + bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1) + || nthreads > oc_chunks; + if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1 + && wei_size / inp_size > 24 + && (not_aligned_max || not_aligned_min) + && eligible_case) { + // Try to find nthreads > mkldnn_get_max_threads() / 2 such + // that oc_chunks is a multiple of nthreads, or nthreads is a + // multiple of oc_chunks. Otherwise, keep default value. + // TODO: implement a task-based alternative without throttling. 
+ jcp.aligned_threads = nthreads; + for (int i = nthreads; i > nthreads / 2; i--) { + if (oc_chunks % i == 0 || i % oc_chunks == 0) { + jcp.aligned_threads = i; + break; + } + } + } + } + + if (jcp.kw > 3 + || (jcp.stride_w == 1 && jcp.stride_h == 1 + && embd_bcast_condition) + || ((jcp.stride_w != 1 || jcp.stride_h != 1) + && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10) + && embd_bcast_condition))) + || (jcp.mb == 1 + && (jcp.ur_w >= jcp.ow || jcp.is_1stconv + || (jcp.ow <= 147 && jcp.oc <= 96)))) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.ow, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3 + && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0 + && IMPLICATION(jcp.is_1stconv, jcp.mb == 1) + && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) { + jcp.nb_oc_blocking = try_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_ic_blocking = 1; + if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) { + float best_thr_eff = 0.f; + int best_nb_oc_blocking = 1; + for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) { + if (jcp.nb_oc % i == 0) { + float thr_eff; + int ur_w = nstl::min(jcp.ow, 31 / (i + 1)); + get_ow_block(i, ur_w, thr_eff); + if (thr_eff > 1.05f * best_thr_eff) { + best_nb_oc_blocking = i; + best_thr_eff = thr_eff; + } + } + } + jcp.nb_oc_blocking = best_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } + } + + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.l_pad <= jcp.ur_w + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_ic_L2 = jcp.nb_ic; + + float thr_eff; + jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff); + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + const int L2_size = get_cache_size(2, true) / sizeof(float); + // Source and output data needs to fit in L2, + // leaving some space for weights and prefetching. + int h_L2 = int(((0.6f * L2_size) / jcp.simd_w + - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw) + / (jcp.stride_h * jcp.iw + jcp.ow)); + jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2)); + + if (jcp.ver == ver_4fma) { + if (!is_ow_threading_on(jcp)) { + for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; + divf++) { + size_t l2_src + = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; + size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking + * jcp.oh * jcp.od; + size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block + * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.oh == 7) { + jcp.nb_ic_L2 = 1; + break; + } + temp_nb = (jcp.nb_ic_L2 % divf == 0 ? 
jcp.nb_ic_L2 / divf + : jcp.nb_ic_L2); + } else { + jcp.nb_ic_L2 = temp_nb; + break; + } + } + } else if (jcp.ic > 64) { + jcp.nb_ic_L2 = 2; /* according to performance data*/ + } + } + + return status::success; +} + +void jit_avx512_common_conv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + vpxord(zmm, zmm, zmm); + size_t aux_src_offset + = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) + * jcp.ic_block; + mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) +{ + Label no_update_label; + + mov(reg_channel, ptr[param + GET_OFF(channel)]); + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt)); + } + } + + L(no_update_label); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt), zmm); + mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( + int ur_w, int l_overflow, int r_overflow) +{ + int ow = jcp.ow; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * ow * oc_block; + int ii_dpref_t0 = get_iw_start(0, l_overflow); + int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow); + + bool check_last_kh = (jcp.kh > 3); + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + auto kernel_loads = [=](int ki, int oc, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, oc + ii, ki); + vmovups(zmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) { + int aux_dst_offset = typesize * ((ii_dpref_t0 + + jcp.l_pad) * oc_block + jcp.ow * oc_block); + prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + ii_dpref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); 
+ mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + align(16); + L(kh_label); + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1 + && ki == kw - 1 && (oc + 4) == oc_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + + if (ii % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, prf_count_t0 + + oc + 4 - oc_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, oc + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1); + } + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if (ii % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(0, prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + kernel_loads(ki, oc, kk); + + for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if ((ii % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + oc + prf_count_t1, ki))); + prf_count_t1++; + } + if ( ki == 1 && oc == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_dst_prf, aux_dst_offset)); + } + } + } + + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + add(aux_reg_ker_prf, shift_ker_ptr); + sub(aux_reg_dst_prf, shift_dst_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d, 
typesize * jcp.kw * jcp.kh * oc_block * ic_block); + sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( + int ur_w, int l_overflow, int r_overflow) +{ + Label kh_label, kd_label; + int kw = jcp.kw; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int l_pad = jcp.l_pad; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = oc_block * kw; + int num_inp_prfs = ur_w * nstl::min(kw, stride_w) + + nstl::max(0, kw - stride_w); + int num_prfs = num_ker_loads + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w / stride_w; + int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs); + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + L(kh_label); { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int oc = 0; oc < oc_block; oc++) { + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + int aux_kernel_offset = typesize * ((oc + i) * oc_block + + ki * ic_block * oc_block); + vmovups(zmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + int aux_kernel_offset = typesize * ((oc + load_offset) + * oc_block + ki * ic_block * oc_block); + vmovups(zmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + auto zmm_kernel = zmm_ker(step % ker_pipeline_depth); + + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + assert(stride_w != 1 + || jj_start == nstl::max(0, + l_overflow - (kw - 1 - ki) * dilate_w)); + assert(stride_w != 1 + || jj_end == ur_w - nstl::max(0, + r_overflow - ki * dilate_w)); + + for (int jj = jj_start; jj < jj_end; jj += stride_w) { + assert((jj + l_pad - ki * dilate_w) % stride_w == 0); + int aux_dst_offset = typesize * + (((jj + l_pad - ki * dilate_w) + / stride_w) * jcp.oc_block + oc); + vfmadd231ps(zmm_out(jj, 0), zmm_kernel, + EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true)); + + int fma_idx = (step * ur_w + jj) / stride_w; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (!ker_prf_inserted && ker_prfs < num_ker_loads) { + int ker_prf_offset = typesize + * 
ker_prfs * jcp.oc_block; + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + int inp_prf_offset + = ic_block * typesize + * ((inp_prf_idx / kw) * kw + + (inp_prf_idx % kw)); + mic_prefetcht0(EVEX_compress_addr( + aux_reg_dst_prf, inp_prf_offset)); + } + } + } + } + step++; + } + } + + add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block); + add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + sub(aux_reg_dst_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + } + + if (jcp.ndims == 5) + { + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( + int ur_w, int l_overflow, int r_overflow) +{ + int kw = jcp.kw; + int ow = jcp.ow; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + Label kh_label, kd_label; + + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block; + + auto output_offset = [=](int oi, int oc, int ki) { + return typesize * + (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc); + }; + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + for (int oc = 0; oc < oc_block; oc++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_output_offset = output_offset(jj, oc, ki); + vbroadcastss(zmm_inp(jj, nb_ic_block), + ptr[aux_reg_dst + aux_output_offset]); + } + } + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, oc, ki); + if (jj_end - jj_start > 0) + vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_out(jj, ii), + zmm_inp(jj, nb_ic_block), zmm_wei); + else + vfmadd231ps(zmm_out(jj, ii), zmm_wei, + EVEX_compress_addr(aux_reg_dst, + 
output_offset(jj, oc, ki), true)); + } + } + } + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( + int ur_w, int l_overflow, int r_overflow) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + mov(reg_kj, ptr[param + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + + if (jcp.ver == ver_4fma) + compute_loop_4fma(ur_w, l_overflow, r_overflow); + else if (jcp.ver == ver_fma) + if (mayiuse(avx512_mic)) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + compute_loop_fma_core(ur_w, l_overflow, r_overflow); + else + assert("!unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::generate() +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ur_w = jcp.ur_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block; + int src_shift = jcp.typesize_out * ur_w * oc_block; + + preamble(); + + mov(reg_src, ptr[param + GET_OFF(src)]); + mov(reg_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + + mov(reg_kh, ptr[param + GET_OFF(kh_padding)]); + mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]); + mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]); + mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]); + + int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w); + int r_overflow = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad)) / stride_w); + int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w); + + int n_oi = iw / ur_w; + if (r_overflow1 > 0) n_oi--; + + if (ur_w == iw) { + compute_loop(ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(ur_w, l_overflow, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + if (ur_w_tail != 0) + compute_loop(ur_w_tail, 0, r_overflow); + } else { + xor_(reg_oi, reg_oi); + if (l_overflow > 0) { + compute_loop(ur_w, l_overflow, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + } + if ((l_overflow <= 0 && n_oi > 0) + || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); { + compute_loop(ur_w, 0, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + compute_loop(ur_w, 0, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + 
add(reg_dst_prf, dst_shift); + } + if (ur_w_tail != 0) { + compute_loop(ur_w_tail, 0, r_overflow); + } + } + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx512_common)) return status::unimplemented; + + jcp = zero(); + + jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + int ndims = diff_src_d.ndims(); + + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + if ((jcp.dilate_w != 0 && jcp.stride_w != 1) + || (jcp.dilate_d != 0 && jcp.stride_d != 1) + || (jcp.dilate_h != 0 && jcp.stride_h != 1)) + return status::unimplemented; + + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1); + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.aligned_threads = 0; + + jcp.is_1stconv = false; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && diff_src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? 
pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) + : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i); + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_w = jcp.stride_w; + + int regs = 28; + if (jcp.iw <= regs) + jcp.ur_w = jcp.iw; + else { + for (int ur_w = regs; ur_w > 0; --ur_w) + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - jcp.l_pad) / jcp.stride_w); + int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w); + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) n_oi--; + + if (mayiuse(avx512_common) + && diff_dst_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && diff_src_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops) + && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && jcp.ver != ver_fma) + return status::unimplemented; + + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (jcp.ver == ver_4fma) { + if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) { + jcp.nb_ic_blocking = 2; + } else { + for (int i = jcp.nb_ic; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + } + } + + jcp.loop_order = loop_gnc; + + bool large_code_size = (jcp.ur_w != jcp.ow) + && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1)) + && (r_overflow1 > 0) && (l_overflow > 0); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw; + int mult = 1; + if (l_overflow > 0) mult += 1; + if (r_overflow1 > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2 + < max_code_size) { + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + } + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_ic_blocking = 2; + unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block + * try_nb_ic_blocking * jcp.kh; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_ic_blocking; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8) + || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13)) + || ker_total_size > L1_cache_size ))) + || jcp.stride_h > 1 || jcp.stride_d > 1) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.iw, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size + && jcp.ow > 8)) && jcp.stride_h == 1) + if (jcp.nb_ic % try_nb_ic_blocking == 0) { + jcp.nb_ic_blocking = 
try_nb_ic_blocking; + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_oc_blocking = 1; + jcp.nb_ic_blocking = 4; + if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic; + if (jcp.nb_ic % jcp.nb_ic_blocking != 0) + for (int i = jcp.nb_ic_blocking; i > 0; i--) + if (jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + if (l_overflow * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_oc_L2 = jcp.nb_oc; + if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) { + for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc; + divf++) { + size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih + * jcp.id; + size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; + size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh + * jcp.kd * jcp.nb_ic_blocking * temp_nb; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.ih == 7) { + jcp.nb_oc_L2 = 1; + break; + } + temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf + : jcp.nb_oc_L2); + } else { + jcp.nb_oc_L2 = temp_nb; + break; + } + } + } + + args_ok = true + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + return status::success; +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28; + +void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_label; + + /* 'depth' loop count bound by 'kd_work_size' */ + mov(kj, reg_kd_count); + L(kd_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + Label kh_comeback_label, kd_comeback_label; + mov(kj, reg_kh); + L(kh_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? 
jcp.tr_iw : jcp.iw; + sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(Zmm(i_kw * ic_block_step + i_ic), + EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block + + i_ic) * jcp.oc_block + kernel_offset)); + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + if (i_ur == 0) { + vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 0) + * oc_block + output_offset)); + if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block + + output_offset)); + if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block + + output_offset)); + if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + } else if (i_ur + 3 < ur_w) + vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); + if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + + (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + const size_t i_offset = (size_t)input_offset + + (size_t)typesize * (jcp.ver == ver_4fma + ? (i_iw - pad_l + i_ic * jcp.tr_iw) + : (jcp.is_1stconv + ? (i_iw - pad_l) + (size_t)i_ic + * ((size_t)jcp.ih*jcp.iw*jcp.id) + : (i_iw - pad_l) * ic_block + i_ic)); + vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), + Zmm(kw * ic_block_step + i_ur % 4), + EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, + true)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(EVEX_compress_addr(reg_kernel, typesize + * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset), + Zmm(i_kw * ic_block_step + i_ic)); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + // TODO: add prefetches to fma version as well + + assert(jcp.ver == ver_4fma); + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + auto zmm_ker = [=](int i_kw, int i_ic) { + return Zmm(i_kw * ic_block_step + i_ic); + }; + + auto ker_addr = [=](int i_kw, int i_ic) { + size_t local_offset + = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block; + return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); + }; + + auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) { + int stride = jcp.tr_iw * (jcp.is_1stconv ? 
jcp.ih : 1); + int local_offset = jcp.typesize_in * (i_iw + i_ic * stride); + return EVEX_compress_addr(reg_input, + local_offset + input_offset + extra_offset); + }; + + auto zmm_out = [=](int i_iw) { + // TODO: move reg calc to global member funcs + const int out_zmm_base_idx = 28; + return Zmm(out_zmm_base_idx + i_iw % 4); + }; + + auto out_addr = [=](int i_ur) { + return EVEX_compress_addr(reg_output, + jcp.typesize_in * i_ur * oc_block + output_offset); + }; + + auto pf_callback = [=](int i_ur, int i_kw, int i_ic) { + assert(i_ur % 4 == 0); + if (i_ur == 0) + prefetcht1(ker_addr(i_kw, i_ic)); + if (i_ur + 4 >= ur_w) + prefetcht0(ker_addr(i_kw, i_ic)); + + const ptrdiff_t next_input_block_offset + = jcp.typesize_in * ic_block_step * jcp.tr_iw; + if (i_ur % 16 == 4 && i_kw == 0) { + if (i_ur + 16 < ur_w) + prefetcht0(inp_addr(i_ur + 16, i_ic)); + else + prefetcht0(inp_addr(0, i_ic, next_input_block_offset)); + } + if (i_ur % 16 == 4 && i_kw == 1) { + if (input_wraparound) + prefetcht1(inp_addr(i_ur, i_ic, -input_offset)); + else + prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset)); + } + }; + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto zmm = zmm_ker(i_kw, i_ic); + vpxord(zmm, zmm, zmm); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur += 4) { + + for (int i = 0; i < 4; i++) { + auto zmm = zmm_out(i_ur + i); + if (i_ur + i < ur_w) + vmovups(zmm, out_addr(i_ur + i)); + else + vpxord(zmm, zmm, zmm); + prefetcht0(out_addr(i_ur + i + 4)); + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + int i_iw = i_ur + i_kw; + v4fmaddps(zmm_ker(i_kw, i_ic), + zmm_out(i_ur), inp_addr(i_iw, i_ic)); + pf_callback(i_ur, i_kw, i_ic); + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto addr = ker_addr(i_kw, i_ic); + auto zmm = zmm_ker(i_kw, i_ic); + vaddps(zmm, zmm, addr); + vmovups(addr, zmm); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + if (jcp.ver == ver_4fma) + compute_ic_block_step_4fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else if (jcp.ver == ver_fma) + compute_ic_block_step_fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else + assert(!"unknown convolution version"); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow_icblock( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + Label kh_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) { + const int input_offset = jcp.typesize_in + * (jcp.ver == ver_4fma ? 
i_b_ic * iw : i_b_ic); + compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step, + input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, + i_b_ic + ic_block_step >= jcp.ic_block); + } + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, kd_label; + + UNUSED(max_ur_w); + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + + int r_pad = nstl::max(0, + (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + xor_(b_ic, b_ic); + L(ic_block_label); { + compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, + 0, 0, 0); + size_t inp_icblk_stride = jcp.is_1stconv + ? (size_t)jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, ow_block_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad; + + int ur_w = nstl::min(ow, max_ur_w); + int ur_w_trips = ow / ur_w; + int ur_w_tail = ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) + || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + + int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 
1 : ic_block; + int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); { + xor_(b_ic, b_ic); + L(ic_block_label); { + if (l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad) + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + L(ow_block_label); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_label, T_NEAR); + } + } + + if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad, + ic_block_step, 0, 0, 0); + + sub(reg_input, jcp.typesize_in * input_comeback); + sub(reg_output, jcp.typesize_in * output_comeback); + int inp_icblk_stride = jcp.is_1stconv + ? jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_disp() +{ + int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2); + if (jcp.is_1stconv) { + bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0); + ic_block_step + = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1; + } + + bool too_large_to_unroll + = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1) + && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1); + + int ow = jcp.ow; + if (jcp.ndims == 5) { + /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of + * 'movs' must be guaranteed. 
*/ + mov(ki, reg_kd_count); + push(reg_kd_count); + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) + compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w); + else if (ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + pop(reg_kd_count); + od_step_comeback_pointers(); + } else { + oh_step_comeback_pointers(); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() +{ + Label skip_zeroing, zeroing_loop; + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jz(skip_zeroing, T_NEAR); + + Zmm zero = Zmm(0); + vpxord(zero, zero, zero); + xor_(reg_tmp, reg_tmp); + L(zeroing_loop); { + assert(jcp.oc_block * jcp.typesize_out + == cpu_isa_traits::vlen); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block + * jcp.typesize_out], zero); + add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd + * jcp.typesize_out); + jnz(zeroing_loop); + } + + L(skip_zeroing); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() +{ + Label skip_bias, bias_loop, skip_load_bias; + + mov(reg_tmp, ptr[param + GET_OFF(flags)]); + test(reg_tmp,reg_tmp); + jne(skip_bias, T_NEAR); + + mov(reg_bias, ptr[param + GET_OFF(bias)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + vpxord(Zmm(1), Zmm(1), Zmm(1)); + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jne(skip_load_bias, T_NEAR); + vmovups(Zmm(1), ptr[reg_bias]); + + L(skip_load_bias); + + mov(reg_oi, ptr[param + GET_OFF(d_worksize)]); + sub(reg_oi, ptr[param + GET_OFF(d_index)]); + mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out); + imul(reg_oi, reg_tmp); + + xor_(reg_tmp, reg_tmp); + L(bias_loop); { + vmovups(Zmm(0), ptr[reg_output + reg_tmp]); + vaddps(Zmm(1), Zmm(1), Zmm(0)); + add(reg_tmp, jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, reg_oi); + jl(bias_loop); + } + vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1)); + + L(skip_bias); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_loop_common() +{ + int b_pad = jcp.b_pad; + int t_pad = jcp.t_pad; + bool is_dilated = jcp.dilate_h != 0; + int dilate_h = jcp.dilate_h + 1; + int stride_h = jcp.stride_h; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label, + oh_bpad_label, oh_bpad_label_end, od_label, od_label_end, + oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end; + + int ow = jcp.ow; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + /* Compute 'top' edge */ + if (t_pad > 0) { + const int kh_range = 1 + (jcp.kh - 1) * dilate_h; + const int overflow + = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); + const int underflow = div_up(t_pad, dilate_h); + const int initial_inp_ker_overlap = jcp.kh - overflow - underflow; + mov(reg_kh, initial_inp_ker_overlap); + add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block + * jcp.oc_block); + // generate loop to process kernel while it remains within t_pad + ih + if (kh_range < t_pad + jcp.ih) { + if (is_dilated) { + const int tail = t_pad % dilate_h; + const int shift = tail == 0 ? 
0 : dilate_h - tail; + mov(reg_tmp, shift); + if (tail != 0) + add(reg_input, jcp.typesize_in * shift * iw * inp_mult); + } + L(oh_tpad_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_shift, T_NEAR); + // unshift input as new kernel element enters + sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult); + xor_(reg_tmp, reg_tmp); + } + // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_kh, stride_h); + if (is_dilated) { + jmp(oh_dilate_label_noshift, T_NEAR); + L(oh_dilate_label_shift); + // shift input as old kernel element progresses + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + L(oh_dilate_label_noshift); + } + inc(reg_oj); + add(reg_ih_count, stride_h); + + // final number of kernel elements that overlap with input + const int final_inp_ker_overlap + = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h)); + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_label, T_NEAR); + } + } + // need second loop to process kernel if it is larger than the input + // (does not apply to dilations as they must have unit stride) + if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : + t_pad % stride_h)) { + assert(!is_dilated); + mov(reg_kh, jcp.ih); + L(oh_tpad_tail_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h)); + jl(oh_tpad_tail_label, T_NEAR); + } + } + // correct any excess shifts to kernel and input + // (does not apply to dilations as they must have unit stride, + // kernel must fit inside input, and padding is smaller than input) + if (t_pad <= jcp.oh * stride_h) { + // kernel has moved beyond padding (adjust for stride effects) + if (t_pad % stride_h != 0) { + assert(!is_dilated); + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult); + } + } else { + // kernel still overlaps padding (complete reset) + assert(!is_dilated); + sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h) + * jcp.kw * jcp.ic_block * jcp.oc_block); + } + } + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_label, T_NEAR); + + /* Compute middle block(s) */ + mov(reg_kh, jcp.kh); + L(oh_label); { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_label, T_NEAR); + } + L(oh_label_end); + + /* Compute bottom edge */ + if (b_pad > 0) { + cmp(reg_oj, jcp.oh); + jge(oh_bpad_label_end, T_NEAR); + + if (is_dilated) { + mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations + mov(reg_tmp, 0); + } else { + mov(reg_kh, jcp.ihp - b_pad); + sub(reg_kh, reg_ih_count); + } + L(oh_bpad_label); + { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if 
(is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_end, T_NEAR); + xor_(reg_tmp, reg_tmp); + } + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_label_end, T_NEAR); + if (is_dilated) + L(oh_dilate_label_end); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_label, T_NEAR); + } + L(oh_bpad_label_end); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() { + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + const int input_backpad_overlap + = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d); + + const size_t filter_shift + = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block; + const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult; + const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block; + + Label d_loop_label, loop_end_label, common_block_label, fpad_end_label, + backpad_end_label, backpad_label; + + if (jcp.with_bias) bias_kernel(); + + /* initially offset 'kd' by f_pad */ + add(reg_kernel, ptr[param + GET_OFF(kd_offset)]); + + mov(reg_input_d, ptr[param + GET_OFF(src)]); + mov(reg_output_d, ptr[param + GET_OFF(dst)]); + mov(reg_d_index, ptr[param + GET_OFF(d_index)]); + mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]); + + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jge(loop_end_label, T_NEAR); + + L(d_loop_label); + + mov(reg_input, reg_input_d); + mov(reg_output, reg_output_d); + + push(reg_input_d); + push(reg_output_d); + push(reg_d_index); + + compute_oh_loop_common(); + + pop(reg_d_index); + pop(reg_output_d); + pop(reg_input_d); + + /* Compute 'front' edge */ + if (jcp.f_pad > 0) { + + /* Check if within fpad region */ + cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d)); + jge(fpad_end_label, T_NEAR); + + /* Fpad steps */ + sub(reg_kernel, filter_shift * jcp.stride_d); + add(reg_kd_count, jcp.stride_d); + + /* Final number of kernel elements that overlap with input */ + const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id); + cmp(reg_kd_count, inp_ker_overlap); + jl(common_block_label, T_NEAR); + + /* Correct any excess shifts to kernel and input */ + if (jcp.f_pad <= jcp.od * jcp.stride_d) { + /* Filter has moved beyond padding (adjust for stride effects) */ + if (jcp.f_pad % jcp.stride_d != 0) { + int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d; + add(reg_kernel, filter_shift * inp_corr); + add(reg_input_d, input_shift * inp_corr); + } + } else { + /* Filter still overlaps padding (complete reset) */ + sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift); + } + + /* Apply correction */ + mov(reg_kd_count, jcp.kd); + jmp(common_block_label); + + L(fpad_end_label); + } + + /* Compute bottom edge */ + if (jcp.back_pad > 0) { + + /* Check if within back_pad region */ + cmp(reg_d_index, input_backpad_overlap - 1); + jl(backpad_end_label, T_NEAR); + jg(backpad_label, T_NEAR); + + /* Execute overlap correction between the filter and the initial + * back_pad region. 
*/ + mov(reg_kd_count, + jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d); + jmp(backpad_end_label, T_NEAR); + + L(backpad_label); + sub(reg_kd_count, jcp.stride_d); + cmp(reg_kd_count, 0); + jle(loop_end_label, T_NEAR); + + L(backpad_end_label); + } + + /* Compute middle block */ + add(reg_input_d, input_shift * jcp.stride_d); + + /* Execute common block and loop */ + L(common_block_label); + add(reg_output_d, output_shift); + inc(reg_d_index); + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jl(d_loop_label, T_NEAR); + + L(loop_end_label); +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() { + // FIXME: use register mapping from the class declaration + bool ok = jcp.ver == ver_4fma + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && everyone_is(1, jcp.stride_h, jcp.stride_w); + if (!ok) return false; + if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) + return false; + + // General code layout: + // + // Blocking over OH -- top level + // (Reduces L2 pressure; not very useful right now) + // Loop over all KHxKW kernel -- emit_kh_kw_loop() + // Loop over OH block -- emit_h_loop() + // Loop over OW blocks -- emit_fma_block() + // (Supports both fully unrolled and partially unrolled versions to + // reduce code size) + // Loop over OW block -- emit_fma_step() + + int max_working_set_size = 128 * 1024; + int pad_ow = jcp.ow; + + int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in; + int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in; + int row_size = inp_row_size + out_row_size; + + int h_block_size = jcp.oh; + int working_set_size = row_size * h_block_size; + + if (working_set_size > max_working_set_size) { + int opt_working_set_size = 48 * 1024; + assert(opt_working_set_size < max_working_set_size); + + while (working_set_size > opt_working_set_size) { + for (int i = 2; i <= h_block_size; i++) + if (i == h_block_size) + h_block_size = h_block_size / 2; + else if (h_block_size % i == 0) { + h_block_size = h_block_size / i; + break; + } + working_set_size = row_size * h_block_size; + + if (h_block_size == 1 && working_set_size > opt_working_set_size) + return false; + } + } + + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below) + if (h_block_size < nstl::max(1, jcp.t_pad) + || jcp.b_pad > (jcp.oh % h_block_size == 0 ? 
h_block_size + : jcp.oh % h_block_size)) + return false; + + // check that we can use simple arithmetic for prefetch address + // calculations + // TODO: we need some traits for this check (Roma) + int cache_line_size = 64; + assert(jcp.ic_block * typesize == 64); + assert(jcp.oc_block * typesize == 64); + + int num_inp_l2_pfs = jcp.tr_iw * h_block_size; + int avg_h_loop_len = h_block_size; + int num_inp_l2_pfs_per_fma_block + = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + int num_out_l2_pfs = pad_ow * h_block_size; + int num_out_l2_pfs_per_fma_block + = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + + Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors + Reg64 reg_kh = rax; + Reg64 reg_kw = rbx; + Reg64 reg_tmp = abi_not_param1; + Reg32 reg_tmp_w = reg_tmp.cvt32(); + Reg64 reg_ohs = rdx; + Reg64 reg_ihs = rsi; + Reg64 reg_h = r8; + Reg64 reg_i = r9; + Reg64 reg_j = r10; + + Reg64 reg_inp = r13; + Reg64 reg_out = r14; + Reg64 reg_ker = r15; + + Reg64 reg_inp_pf_l1 = rbp; + + Reg64 reg_inp_pf_l2 = r11; + Reg64 reg_out_pf_l2 = r12; + + Xmm reg_inp_pf_save = xmm17; + Xmm reg_out_pf_save = xmm18; + + Reg64 reg_inp_save = abi_param1; + Reg64 reg_out_save = reg_tmp; + + auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); }; + auto zmm_ker = [&](int ic1) { return Zmm(ic1); }; + auto inp_addr = [&](int oi, int ic1) { + return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in]; + }; + auto out_addr = [&](int oi, int oj = 0) { + assert(jcp.ver == ver_4fma); + return ptr[reg_out + + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in]; + }; + auto ker_addr = [&](int ic1) { + return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out]; + }; + + auto emit_block = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row) + { + // TODO: add an fma version (Roma) + auto pad_ow = jcp.ow; + + int ow4u = rnd_up(pad_ow, 4); + int def_step_size = 16; + + bool has_w_tail = (pad_ow % def_step_size != 0 + || pad_ow % 4 != 0); + bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail; + + auto emit_step = [&](int ur_ow, + int num_inp_l1_pfs_per_fma_step, + int num_inp_l2_pfs_per_fma_step, + int num_out_l2_pfs_per_fma_step, bool is_w_tail) + { + bool block_wraparound = is_w_tail && is_last_row; + + assert(ur_ow % 4 == 0); + int tail_size = ow4u % ur_ow; + int this_ur_ow + = (is_w_tail && tail_size) ? tail_size : ur_ow; + int ow_last_chunk4 = pad_ow % 4; + int ow_zero_tail4 = ow_last_chunk4 + ? 4 - ow_last_chunk4 : 0; + + auto emit_out_pf = [&](int oi) { +#if 1 + if (oi + def_step_size < ur_ow || !block_wraparound) + mic_prefetcht0(ptr[reg_out + + ((def_step_size + oi) + * jcp.oc_block * jcp.typesize_in)]); + else { + assert(block_wraparound); + assert(oi + def_step_size >= ur_ow); + mic_prefetcht0(ptr[reg_out_save + + ((oi + def_step_size - ur_ow) + * jcp.oc_block * jcp.typesize_in)]); + } +#else + // XXX: This is an alternative prefetching strategy that + // always prefetches the next row. 
Keeping it here for + // future experiments (Roma) + if (!block_wraparound) + mic_prefetcht0(ptr[reg_out + + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]); + else + mic_prefetcht0(ptr[reg_out + reg_ohs + - ((h_block_size - 1) * jcp.ow + - oi) * jcp.oc_block * jcp.typesize_in]); +#endif + if (oi < num_out_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_out_pf_l2 + + oi * jcp.oc_block * jcp.typesize_in]); + }; + + auto emit_inp_pf = [&](int oi4, int ic1) { + int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block; + int num_pf_slots = jcp.ic_block * ur_ow / 4; + + int num_pfs = num_inp_l1_pfs_per_fma_step + + num_inp_l2_pfs_per_fma_step; + int pf_freq = nstl::max(1, num_pf_slots / num_pfs); + + if (pf_slot_idx % pf_freq) + return; + + int pf_idx = pf_slot_idx / pf_freq; + + if (pf_idx < num_inp_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_inp_pf_l2 + + pf_idx * jcp.ic_block * jcp.typesize_in]); + else { + pf_idx -= num_inp_l2_pfs_per_fma_step; + // prefetch the 'tail' of the cache line because most of + // the accesses are not aligned + mic_prefetcht0(ptr[reg_inp_pf_l1 + + pf_idx * jcp.ic_block * jcp.typesize_in + + cache_line_size - jcp.typesize_in]); + } + }; + + auto numloads = 4; + + int steps = this_ur_ow; + for (int oi4 = 0; oi4 < steps; oi4 += numloads) { + for (int oi1 = 0; oi1 < numloads; oi1++) { + int oi = oi4 + oi1; + if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) { + vmovups(zmm_out(oi), out_addr(oi)); + emit_out_pf(oi); + } else { + auto zmm = zmm_out(oi); + vpxord(zmm, zmm, zmm); + } + } + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + if (jcp.ver == ver_4fma) { + v4fmaddps(zmm_ker(ic1), + zmm_out(oi4), inp_addr(oi4, ic1)); + } else { + assert(!"unknown convolution version"); + } + emit_inp_pf(oi4, ic1); + } + } + }; + + // Input is transposed and padded but we only access about jcp.iw + // elements so use that to compute the # of cache lines in each 'row' + int num_inp_l1_pfs + = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block; + + if (full_w_unroll) { + emit_step(ow4u, num_inp_l1_pfs, + num_inp_l2_pfs_per_fma_block, + num_out_l2_pfs_per_fma_block, true); + add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size); + add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size); + } else { + Label w_loop; + int num_w_iters = pad_ow / def_step_size; + int num_w_iters_full = num_w_iters + has_w_tail; + int num_inp_l1_pfs_per_fma_step + = div_up(num_inp_l1_pfs, num_w_iters_full); + int num_inp_l2_pfs_per_fma_step + = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full); + int num_out_l2_pfs_per_fma_step + = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full); + mov(reg_i, num_w_iters); + L(w_loop); { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, false); + add(reg_inp, def_step_size * jcp.typesize_in); + add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in); + add(reg_inp_pf_l1, + num_inp_l1_pfs_per_fma_step * cache_line_size); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + sub(reg_i, 1); + jnz(w_loop); + } + if (has_w_tail) { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, true); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + } + // reset reg_inp and reg_out because emit_h_loop expects + // unmodified 
pointers + int w_offset = num_w_iters * def_step_size; + sub(reg_inp, w_offset * jcp.typesize_in); + sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in); + } + }; + + auto emit_h_loop = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter) + { + Label h_loop, skip_h_loop; + mov(reg_j, 1); + cmp(reg_j, reg_h); + je(skip_h_loop, T_NEAR); + L(h_loop); { + + lea(reg_inp_pf_l1, + ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]); + emit_block(h_block_size, + is_last_block, is_last_kh_kw_iter, false); + + add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in); + add(reg_j, 1); + cmp(reg_j, reg_h); + jb(h_loop); + } + + L(skip_h_loop); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + mic_prefetcht0(ker_addr(ic1)); + + lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]); + emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true); + }; + + auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block, + int h_block_size) + { + xor_(reg_kh, reg_kh); + Label kh_loop, kh_loop_end; + + int last_oh_block_size + = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size); + int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size; + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size + int ih_block_size = oh_block_size - 1 + jcp.kh + - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad; + + L(kh_loop); { + // determine starting indices for this block + if (is_first_block) { + xor_(reg_tmp, reg_tmp); + mov(reg_ohs, jcp.t_pad); + sub(reg_ohs, reg_kh); + cmovb(reg_ohs, reg_tmp); + + mov(reg_ihs, reg_ohs); + sub(reg_ihs, jcp.t_pad); + add(reg_ihs, reg_kh); + } else { + xor_(reg_ohs, reg_ohs); + mov(reg_ihs, reg_kh); + } + + // determine effective size of block based on padding + mov(reg_tmp, oh_block_size); + sub(reg_tmp, reg_ohs); + mov(reg_h, ih_block_size); + sub(reg_h, reg_ihs); + cmp(reg_tmp, reg_h); + cmovb(reg_h, reg_tmp); + + Label kh_loop_work; + cmp(reg_h, 0); + jg(kh_loop_work, T_NEAR); + + // empty h loop for this jcp.kh: + // - set the output to 0 if necessary + // - move ker pt + // - jump to the end + sub(reg_h, 1); + Label skip_ker_zeroing; + + // The reg_ker ptr has highest bit set if the output needs to be + // zeroed. Those who have byte-aligned their data will suffer the + // consiquences :( + // TODO: move the flag to a mask register? 
(Roma) + test(reg_ker, 1); + jz(skip_ker_zeroing, T_NEAR); + + Label zeroing_loop; + vpxord(zmm0, zmm0, zmm0); + and_(reg_ker, ~1); // temporarily clear the zeroing flag + mov(reg_tmp, jcp.kw); + L(zeroing_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ker_addr(ic1), zmm0); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out); + sub(reg_tmp, 1); + jnz(zeroing_loop, T_NEAR); + } + // restore the zeroing flag (it will be cleared after the end of + // emit_kh_kw_loop, but we may need it until then) + or_(reg_ker, 1); + jmp(kh_loop_end, T_NEAR); + + L(skip_ker_zeroing); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw + * jcp.typesize_out); + jmp(kh_loop_end, T_NEAR); + + L(kh_loop_work); + + mul_by_const(reg_ihs, reg_tmp, + jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + mul_by_const(reg_ohs, reg_tmp, + pad_ow * jcp.oc_block * jcp.typesize_in); + + add(reg_inp, reg_ihs); + add(reg_out, reg_ohs); + + Label kw_loop; + xor_(reg_kw, reg_kw); + L(kw_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vpxord(zmm, zmm, zmm); + mic_prefetcht1(ker_addr(ic1)); + } + + mov(reg_out_save, reg_out); + mov(reg_inp_save, reg_inp); + lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]); + +#if 0 + // XXX: Generate code with special prefetches when switching + // blocks or at the end of the last block. Disabled to reduce + // code size and because there's no performance benefit (Roma) + Label regular_h_loop, end_h_loop; + cmp(reg_kw, jcp.kw - 1); + jne(regular_h_loop, T_NEAR); + cmp(reg_kh, jcp.kh - 1); + jne(regular_h_loop, T_NEAR); + + emit_h_loop(oh_block_size, is_last_block, true); + jmp(end_h_loop, T_NEAR); + + L(regular_h_loop); + emit_h_loop(oh_block_size, is_last_block, false); + + L(end_h_loop); +#else + emit_h_loop(oh_block_size, is_last_block, false); +#endif + + mov(reg_out, reg_out_save); + mov(reg_inp, reg_inp_save); + + Label do_store; + // The reg_ker ptr has highest bit set if the output needs to + // be zeroed. 
Those who have byte-aligned their data will + // suffer the consiquences :( + mov(reg_tmp, reg_ker); + and_(reg_ker, ~1); + test(reg_tmp, 1); + jnz(do_store, T_NEAR); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + if (jcp.ver == ver_4fma) { + vaddps(zmm, ker_addr(ic1)); + } else { + assert(!"unknown convolution version"); + } + } + + L(do_store); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vmovups(ker_addr(ic1), zmm); + } + + mov(reg_ker, reg_tmp); + add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + add(reg_kw, 1); + cmp(reg_kw, jcp.kw); + jl(kw_loop); + } + + sub(reg_inp, reg_ihs); + sub(reg_out, reg_ohs); + + + L(kh_loop_end); + add(reg_kh, 1); + cmp(reg_kh, jcp.kh); + jl(kh_loop); + } + }; + + mov(reg_inp, ptr[param + GET_OFF(src)]); + mov(reg_out, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]); + mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]); + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + or_(reg_ker, reg_tmp); + + bool single_kh_kw_loop = (h_block_size == jcp.oh); + + size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in; + size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad); + size_t inp_block_step = inp_row_step * h_block_size; + size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in + * h_block_size; + + if (!single_kh_kw_loop) { + // Save the original prefetch pointers from the OpenMP driver + vmovq(reg_inp_pf_save, reg_inp_pf_l2); + vmovq(reg_out_pf_save, reg_out_pf_l2); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, first_inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + } + emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size); + + if (!single_kh_kw_loop) { + size_t ker_reset_offset + = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh; + sub(reg_ker, ker_reset_offset); + and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates + + add(reg_inp, first_inp_block_step); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + + int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2; + if (num_innermost_iters > 0) { + Label h_block_loop; + + mov(reg_tmp_w, num_innermost_iters); + kmovw(reg_h_block, reg_tmp_w); + L(h_block_loop); { + emit_kh_kw_loop(false, false, h_block_size); + sub(reg_ker, ker_reset_offset); + add(reg_inp, inp_row_step * h_block_size); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + kmovw(reg_tmp_w, reg_h_block); + sub(reg_tmp_w, 1); + kmovw(reg_h_block, reg_tmp_w); + jnz(h_block_loop); + } + } + + // Restore the original prefetch pointers that came from the OpenMP + // driver + vmovq(reg_inp_pf_l2, reg_inp_pf_save); + vmovq(reg_out_pf_l2, reg_out_pf_save); + emit_kh_kw_loop(false, true, h_block_size); + } + + return true; +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32 + ::flat_4ops_compute() { + const auto &j = jcp; + const bool ok = j.ver == ver_4fma && j.is_1stconv + && everyone_is(0, j.dilate_h, j.dilate_w); + if (!ok) return false; + + Reg64 reg_ptr_tr_src = r8; + Reg64 reg_ptr_dst = r9; + Reg64 reg_ptr_wei = r10; + Reg64 reg_ptr_bia = r11; + + Reg64 reg_kh_step = rax; + Reg64 reg_oh = abi_not_param1; + Reg64 reg_kh = rdx; + + Reg32 
reg_flag_save = ebx; + Reg32 reg_flag = esi; + + Zmm vbia(31); + + auto zmm_wei = [&](int kh, int kw) { + return Zmm(8 + kh * j.kw + kw); + }; + auto zmm_dst = [&](int ow) { + return Zmm(ow % 8); + }; + + auto addr_tr_src = [&](int kh, int iw) { + return ptr[reg_ptr_tr_src + + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in]; + }; + auto addr_dst = [&](int ow) { + return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in]; + }; + auto addr_wei = [&](int kh, int kw) { + return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block + * jcp.typesize_out]; + }; + + auto emit_fma_block = [&](int kh_step) { + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + auto vwei = zmm_wei(kh, kw); + vpxord(vwei, vwei, vwei); + } + } + + for (int ow = 0; ow < j.ow; ow += 4) { + for (int _ow = ow; _ow < ow + 4; ++_ow) { + auto vdst = zmm_dst(_ow); + if (_ow < j.ow) + vmovups(vdst, addr_dst(_ow)); + else + vpxord(vdst, vdst, vdst); + } + + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + const int iw = ow + (kw % j.stride_w) * j.tr_ld + + (kw / j.stride_w); + v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow), + addr_tr_src(kh, iw)); + if (1 && kh == 0 && kw < 4) { + prefetcht1(ptr[reg_ptr_dst + + (j.ow + ow + kw) * jcp.oc_block + * jcp.typesize_in]); + } + if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */ + const int off = kw + 4 - j.kw; + if (off >= 0 && ow + off < j.ow) + vaddps(vbia, vbia, zmm_dst(ow + off)); + } + } + } + } + + Label l_store; + test(reg_flag, FLAG_MB_FIRST); + jnz(l_store, T_NEAR); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vaddps(zmm_wei(kh, kw), addr_wei(kh, kw)); + } + L(l_store); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vmovups(addr_wei(kh, kw), zmm_wei(kh, kw)); + } + }; + + auto emit_kh_loop = [&]() { + const int kh_step_rem = j.kh % j.kh_step; + xor_(reg_kh, reg_kh); + mov(reg_kh_step, j.kh_step); + + Label l_kh_loop; + L(l_kh_loop); { + Label l_done; + + if (kh_step_rem != 0) { + Label l_keep_kh_step; + cmp(reg_kh, j.kh - j.kh_step); + jle(l_keep_kh_step, T_NEAR); + + mov(reg_kh_step, kh_step_rem); + emit_fma_block(kh_step_rem); + jmp(l_done, T_NEAR); + + L(l_keep_kh_step); + } + + emit_fma_block(j.kh_step); + + L(l_done); + + add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out); + add(reg_kh, j.kh_step); + + cmp(reg_kh, j.kh); + jl(l_kh_loop, T_NEAR); + } + + const int kh_steps = rnd_up(j.kh, j.kh_step); + sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in); + sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out); + }; + + auto emit_oh_loop = [&]() { + mov(reg_oh, j.oh); + + Label l_oh_loop; + L(l_oh_loop); { + Label l_restore_mb_flag, l_jump; + + cmp(reg_oh, j.oh); + je(l_restore_mb_flag, T_NEAR); + + and_(reg_flag, ~FLAG_MB_FIRST); + jmp(l_jump, T_NEAR); + + L(l_restore_mb_flag); + mov(reg_flag, reg_flag_save); + + L(l_jump); + + emit_kh_loop(); + + add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in); + + dec(reg_oh); + jnz(l_oh_loop, T_NEAR); + } + }; + + auto emit_bia_store = [&]() { + if (!j.with_bias) return; + + Label l_bia_store, l_bia_skip; + test(reg_flag, FLAG_IC_FIRST); + jz(l_bia_skip); + + test(reg_flag, FLAG_MB_FIRST); + jnz(l_bia_store, T_NEAR); + vaddps(vbia, ptr[reg_ptr_bia]); + L(l_bia_store); + vmovups(ptr[reg_ptr_bia], vbia); + 
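// l_bia_skip: target of the FLAG_IC_FIRST test above; the bias update is bypassed when that flag is clear +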
L(l_bia_skip); + }; + + mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]); + mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]); + mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]); + mov(reg_flag_save, ptr[param + GET_OFF(flags)]); + + vpxord(vbia, vbia, vbia); + emit_oh_loop(); + emit_bia_store(); + + return true; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() +{ + if (flat_4ops_compute()) + return; + if (compute_full_spat_loop()) + return; + + maybe_zero_kernel(); + + if (jcp.ndims == 5) compute_d_loop_common(); + else compute_oh_loop_common(); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::generate() +{ + preamble(); + + mov(reg_input, ptr[param + GET_OFF(src)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + mov(reg_kernel, ptr[param + GET_OFF(filt)]); + + compute_loop(); + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) { + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper diff_weights_d(&diff_weights_md); + const memory_desc_wrapper diff_bias_d(&diff_bias_md); + const memory_desc_wrapper diff_dst_d(&diff_dst_md); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero(); + + jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); + bool ok = true + // general condition to simplify dilations + && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) + && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) + && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) + // special condition to simplify dilations in compute_oh_loop_common + && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih); + if (!ok) + return status::unimplemented; + + jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); + jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); + + /* XXX: currently, does not support dilation_d > 0 */ + if (ndims == 5) + if (jcp.dilate_d > 0) + return status::unimplemented; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.aligned_threads = 0; + + /* check for the 1st convolution */ + jcp.is_1stconv = is_1stconv(jcp); + + jcp.oc_block = jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) + jcp.oc = rnd_up(jcp.oc, jcp.simd_w); + + if (jcp.oc % jcp.oc_block) + return status::unimplemented; + + auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); + + if (diff_dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + /* conditions on bias memory */ + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (diff_bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md, x)); + } + + jcp.nb_oc = jcp.oc / jcp.oc_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... 
*/ + const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; + const bool boundaries_ok = true + && jcp.t_pad <= max_pad + && jcp.b_pad <= max_pad + && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad) + && jcp.f_pad < jcp.kd; + if (!boundaries_ok) + return status::unimplemented; + + /* yet another common check */ + if (jcp.kw > 14) + return status::unimplemented; + + /* setting register strategy */ + for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; } + } + + if (jcp.is_1stconv) { + auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + if (jcp.ic == 1 && jcp.src_tag != src_tag) + jcp.src_tag = src_d.matches_one_of_tag( + pick(ndims - 3, nwc, nhwc, ndhwc)); + } + if (jcp.src_tag == format_tag::undef) + return status::unimplemented; + + const bool src_ok = true + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type()) + && one_of(jcp.ic, 1, 2, 3) + && jcp.ngroups == 1; + if (!src_ok) + return status::unimplemented; + + const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, + jcp.stride_w), 16); + const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1); + const int kh_step_rem = jcp.kh % kh_step; + + const auto wei_4fma_tag = with_groups + ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o); + + auto current_wei_tag = format_tag::undef; + if (diff_weights_d.format_kind() != format_kind::any) + current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag); + + const bool use_4fma = true + && one_of(ndims, 3, 4) + && mayiuse(avx512_mic_4ops) + && mkldnn_thr_syncable() + && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) + && jcp.kw <= 28 - jcp.with_bias + && jcp.stride_w == 4 + && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */ + && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */ + && IMPLICATION(diff_weights_d.format_kind() != format_kind::any, + current_wei_tag == wei_4fma_tag); + + if (use_4fma) { + jcp.ver = ver_4fma; + jcp.kh_step = kh_step; + jcp.tr_ld = tr_ld; + jcp.ic_block = 1; + if (diff_weights_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag)); + jcp.wei_tag = wei_4fma_tag; + } else { + jcp.ver = ver_fma; + jcp.ic_block = jcp.ic; + + wei_tag = with_groups + ? 
pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + jcp.nb_ic = jcp.ic / jcp.ic_block; + } else { + auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + jcp.ic_block = jcp.simd_w; + if (ok_to_pad_channels) + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + if ((mayiuse(avx512_mic) || mayiuse(avx512_core)) + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type())) { + jcp.ver = ver_fma; + if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && + everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && + mkldnn_thr_syncable()) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + if (jcp.ver == ver_4fma) { + jcp.ur_w = jcp.ow; + // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to + // cross the right boundary. The only requirement is not to have + // NaNs there because another multiplicand is always guaranteed to + // be zero. This also may require the top-level driver to allocate + // four extra guarding elements at the very end of the buffer. 
+ // I'm not proud of this hack, but it improves performance by + // about 5-10% depending on the dimensions (Roma) + + const int tr_round = 4; + + jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round); + jcp.tr_src_num_guard_elems = tr_round; // upper bound + } + } + + if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) { + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + } else + return status::unimplemented; + + bool args_ok = true + && jcp.ic % jcp.ic_block == 0 + && jcp.oc % jcp.oc_block == 0 + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + { // balancing + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); + jcp.nthr = nthr; + jcp.nthr_mb = nthr_mb; + jcp.nthr_g = nthr_g; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + + return status::success; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.ver == ver_4fma) { + if (jcp.is_1stconv) { + const size_t tr_src_size = + jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } else { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic; + const size_t min_tr_src_size_per_thr + = jcp.ih * jcp.ic_block * jcp.tr_iw; + const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr + + jcp.tr_src_num_guard_elems; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } + + /* prepare synchronization contexts */ + if (jcp.nthr_oc_b > 1) { + const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * tr_src_bctx_size); + } + } + + if (jcp.nthr_mb > 1) { + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic + * jcp.kh * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const size_t wei_bia_reduction_size = wei_size + bia_size; + + scratchpad.book(key_conv_wei_bia_reduction, + jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1)); + scratchpad.book(key_conv_wei_bia_reduction_bctx, + sizeof(simple_barrier::ctx_t)); + } + + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::balance( + const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, + int &nthr_oc_b_, int &nthr_ic_b_) +{ + nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; + + const int max_threads = mkldnn_get_max_threads(); + + if (max_threads < j.ngroups) { + /* simplification... 
fortunately it doesn't hurt much */ + return; + } + + if (!mkldnn_thr_syncable() && j.ver == ver_4fma) { + // should not happen -- the driver is not ready + // for TBB-like non-synchronous threading yet + return; + } + + if (j.ver == ver_4fma && j.is_1stconv) { + nthr_g_ = 1; + nthr_oc_b_ = 1; + nthr_ic_b_ = nstl::min(j.nb_ic, max_threads); + nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb); + nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_; + return; + } + + nthr_g_ = j.ngroups; + const int nthr = max_threads / nthr_g_; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level optimizer + * tries to minimize memory consumption. few notes: + * (n1) unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + + const int src_coef = j.ver == ver_4fma ? 4 : 1; + const int dst_coef = 1; + const int wei_coef = 8; + + return 0 + + src_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id + / j.stride_d / j.stride_h / j.stride_w /* (n1) */ + + dst_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od + + wei_coef /* (n2) */ + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) + * j.kh * j.kw * j.kd * j.ic_block * j.oc_block; + }; + + int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, j.mb * j.od); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + + if (!mayiuse(avx512_mic)) { + auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + return 1 + * div_up(j.mb, nthr_mb) + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) + * div_up(j.nb_ic, nthr_ic_b); + }; + + /* step 2: search for a thread distribution with lower compute cost. 
+ * the constrains: + * - memory cost cannot exceed 110% of the best found in the step 1 + * - unless compute cost is 133% lower than the current best case + * note: both constants were found empirically */ + int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + const bool opt1 = comp_cost <= best_comp_cost + && mem_cost < 1.1 * best_mem_cost; + const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost; + + if (opt1 || opt2) { + best_comp_cost = comp_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + } + + if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads) + nthr_mb_ = nstl::min(j.mb * j.od, max_threads); + nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; + + assert(nthr_ <= max_threads); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1)); +} + +template struct _jit_avx512_common_conv_fwd_kernel; +template struct _jit_avx512_common_conv_fwd_kernel; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp new file mode 100644 index 000000000000..f76770797ae8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { + + _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_common_conv_fwd_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_inp = r8; + reg64_t reg_ker = r9; + reg64_t reg_out = r10; + + reg64_t reg_inp_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_out_prf = r13; + reg64_t reg_owb = r12; + + reg64_t aux_reg_inp = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_inp_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t reg_channel = rsi; + reg64_t reg_bias = rdx; + + reg64_t aux_reg_ker_d = r9; + reg64_t aux_reg_inp_d = rbx; + reg64_t aux_reg_inp_d_prf = r13; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_relu_ns = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_tmp = rbp; + + reg64_t reg_ic_loop = rdx; + reg64_t reg_inp_loop = rsi; + + reg64_t reg_init_flag = r13; + reg64_t reg_bias_ptr = param; + + reg64_t aux_reg_ic = r12; + reg64_t reg_binp = rax; + reg64_t reg_bout = r11; + reg64_t aux1_reg_inp = rbx; + reg64_t aux_reg_out = abi_not_param1; + + reg64_t reg_long_offt = r11; + reg64_t reg_out_long_offt = r14; + + inline Vmm vmm_ker(int i_ic) { + assert(i_ic < 4); + return Vmm(ker_reg_base_idx + i_ic); + } + + inline Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Vmm(idx); + } + + inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + + Xbyak::Reg64 imm_addr64 = r15; + Vmm vmm_wei = Vmm(31); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); + inline void compute_loop(int ur_w, int pad_l, int pad_r); + + void generate(); + + inline size_t get_output_offset(int oi, int n_oc_block) { + return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh + * jcp.ow * jcp.od + oi) * jcp.oc_block; + } + + inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { + size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; + size_t ic_str = !jcp.is_1stconv ? 
1 : (size_t)jcp.iw * jcp.ih * jcp.id; + return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) + + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); + } + + inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { + return jcp.typesize_in * jcp.oc_block + * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd + + (ic + ker_number) + ki * jcp.ic_block); + } + + inline int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + + inline int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } +}; + +struct jit_avx512_common_conv_fwd_kernel { + + jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_common_conv_fwd_kernel() { + delete xmm_kernel_; + delete zmm_kernel_; + } + + enum { + typesize = sizeof(float) + }; + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + void(*jit_ker)(jit_conv_call_s *); + _jit_avx512_common_conv_fwd_kernel *zmm_kernel_; + _jit_avx512_common_conv_fwd_kernel *xmm_kernel_; +}; + +struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { + + jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_dst = r8; + reg64_t reg_ker = r9; + reg64_t reg_src = r10; + + reg64_t reg_dst_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_src_prf = r13; + + reg64_t aux_reg_dst = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_dst_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t aux_reg_dst_d_prf = r13; + reg64_t aux_reg_dst_d = rbx; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t aux_reg_ker_d = r9; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_channel = rsi; + + reg64_t reg_tmp = rbp; + reg64_t reg_long_offt = r14; + + inline Xbyak::Zmm zmm_ker(int i_ic) { + 
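+        // Register budget, as implied by ker_reg_base_idx (28) and the asserts
+        // in these helpers: zmm0..zmm27 hold the output accumulators handed
+        // out by zmm_out, zmm28..zmm31 stage the kernel values returned here,
+        // and zmm31 doubles as zmm_wei.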
assert(i_ic < 4); + return Xbyak::Zmm(ker_reg_base_idx + i_ic); + } + inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Xbyak::Zmm(idx); + } + + Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { + + jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, + memory_desc_t &diff_dst_md); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum {typesize = sizeof(float)}; + static const int max_ur_w; + + reg64_t param = abi_param1; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t reg_tmp = r14; + reg64_t reg_long_offt = r14; + + reg64_t ki = r11; + reg64_t reg_kd_count = r12; + reg64_t reg_oi = r12; + reg64_t reg_d_index = r13; + reg64_t reg_input_d = r15; + reg64_t reg_output_d = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t reg_bias = rbx; + + inline void bias_kernel(); + inline void maybe_zero_kernel(); + inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, + int max_ur_w); + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_ic_block_step(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound = false); + inline void compute_ic_block_step_fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_ic_block_step_4fma(int ur_w, + int pad_l, int pad_r, int 
ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_step_disp(); + inline void compute_oh_loop_common(); + inline void compute_d_loop_common(); + + inline bool compute_full_spat_loop(); + inline bool flat_4ops_compute(); + + inline void compute_loop(); + + void generate(); + + static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, + int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp new file mode 100644 index 000000000000..1bdcd0d6a8fb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp @@ -0,0 +1,1163 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +#ifndef KERNEL_SIZE_THRESHOLD +#define KERNEL_SIZE_THRESHOLD 16 +#endif + +#define MIN_REQUIRED_DIMN_REG_BLOCK 14 + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. 
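+// Candidates are the divisors of `number`, visited below in pairs (divisor,
+// number / divisor) for divisor up to sqrt(number); with number = 12, for
+// instance, the predicate is evaluated on 1, 12, 2, 6, 3 and 4.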
+// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + if (jcp.ver == ver_4fma) + return jcp.mb >= 32; + else + return jcp.mb >= 16; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: + case L3: + default: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond2(int 
nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block, + float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate( + bool is_beta_zero) +{ + // const int dimK_simd_block = jcp.dimK_reg_block; + + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][tile] += + // A[dimM_block][dimK_block][dimK_reg_block] * + // broadcast(B[dimK_block][tile][dimK_reg_block]); + // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block], + // so we load it before the loop on tile + // 2) the loop on tile must be fully unrolled. Don't know about the one on + // dimK_reg_block. I think it should be + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + + prefetcher_t L1_pf(this, reg_srcB, L1, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + prefetcher_t L2_pf(this, reg_srcB, L2, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { + // First, we zero the accumulators if first nb_ic iteration, + // otherwise we load them + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (is_beta_zero) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + 64 * tile]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_dimK_reg_block; i++) + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + 64 * (offset + i)]); + }; + + // Used when doing double buffering + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block += inc_dimK_reg_block) { + int current; + /* Loading the next vector from A */ + current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block + inc_dimK_reg_block) + % (2 * inc_dimK_reg_block); + load_A(next, dimK_reg_block + inc_dimK_reg_block); + } else { + next = 0; + load_A(next, dimK_reg_block); + } + /* Performing the fmas */ + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (jcp.ver != ver_avx512_core) + L1_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + if (jcp.ver == ver_4fma) + v4fmaddps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4)); + else + vfmadd231ps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, + 
true)); + if (jcp.ver != ver_avx512_core) + L2_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + } + } + + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + } + + + auto store_output = [=](bool output_is_aligned) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(zword[reg_dstC + 64 * tile], zmm); + else + vmovups(zword[reg_dstC + 64 * tile], zmm); + } + }; + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimN_reg_block * 64); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + + if (mayiuse(avx512_core)) + return status::unimplemented; + else if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // Checking conditions not supported by these kernels + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + return status::success; +} + + +status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) { + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK) + && (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block < current_best); + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block); + + if (jcp.dimN_reg_block >= jcp.nb_reg) { + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block > current_best); + }; + + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, 1, test_cond_dimN_reg_block); + } + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + jcp.dimM_simd_block = 16; + /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/ + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + //******************* Choosing 
dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_simd_block, .5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + // TODO: replace double buffering with nuple buffering to maximize register + // usage. + // the choice of the number of buffers will then come after choosing + // dimN_reg_block + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2); + else + jcp.zmm_start = 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + + jcp.sched_policy = WSCHED_INVALID; + set_wsched_DATA_W_S_G_D_avx512_common(jcp); + + assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D); + return status::success; +} + +bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) || + (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { + status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +status_t 
jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate() +{ + auto load_B = [=](int reg_idx, int offset) { + for (int i = 0; i < 4; i++) { + vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]); + } + }; + + preamble(); + int curr = 0; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int origB_offset = (j * alpha + i) * jcp.dimK_4fma; + size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block * + jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float); + mov(reg_transB_idx, transB_offset); + for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) { + /*double buffering to hide load latencies*/ + int next = (curr + 4) % 8; + if (i == 0 && tb == 0) { + load_B(0, origB_offset); + } + if (tb + 4 < (jcp.dimK_4fma -1)) { + load_B(next, origB_offset + 4); + } else if (i < alpha - 1) { + load_B(next, origB_offset + jcp.dimK_4fma); + } + + vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3)); + vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1)); + vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3)); + + vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9)); + vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9)); + + vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1)); + + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * tb * jcp.dimN_reg_block], + Zmm(curr+2)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 1) * jcp.dimN_reg_block], + Zmm(curr+3)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 2) * jcp.dimN_reg_block], + Zmm(8)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 3) * jcp.dimN_reg_block], + Zmm(9)); + curr = next; + + } + } + } + postamble(); + ret(); +} +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate( + bool is_first_tile) +{ + // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) + // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++) + // for (int nb_tile_block_ur = 0; nb_tile_block_ur < + // jcp.nb_tile_block_ur; nb_tile_block_ur++) + // for (int tile_block_ur = 0; tile_block_ur < + // jcp.tile_block_ur; tile_block_ur++) + // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3) + // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] += + // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block] + // * + // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur]) 
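+    // The generated loop nest keeps the C accumulators in zmm registers from
+    // jcp.zmm_start upwards, zeroing them when is_first_tile is set and
+    // reloading them from dstC otherwise. Loads from srcA are double buffered
+    // (the `next` index ping-pongs between two banks of inc_fma registers) so
+    // the next vector is in flight while the current one feeds the fmas, and
+    // srcB is streamed through the L1/L2 prefetchers alongside the compute.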
+ auto inner_loops = [=]() { + int inc_fma = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + prefetcher_t L1_pf(this, reg_srcB, L1, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + prefetcher_t L2_pf(this, reg_srcB, L2, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_fma; i++) { + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + + sizeof(float) * jcp.dimM_simd_block * (offset + i)]); + } + }; + + Label dimM_block_loop, dimK_block_loop, dimN_block_loop; + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { /************* OC_block (M) loop ***********/ + if (jcp.dimN_block > 1) { + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + } + { /*************** IC_block (N) loop *********/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + if (is_first_tile) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * + sizeof(float)]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { /************* nb_tile_ur(K) loop ********/ + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block++) { + int srcB_offset = dimK_reg_block * jcp.dimK_4fma + * jcp.dimN_reg_block; + for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma; + dimK_4fma += inc_fma) { + int current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma) + % (2 * inc_fma); + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma); + } else { + next = 0; + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma); + } + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; + ++dimN_reg_block) { + L1_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + L2_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + if (jcp.ver == ver_4fma) { + int srcB_trans_offset = (dimK_4fma / 4) * 64 + + dimK_4fma % 4; + v4fmaddps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * ( + srcB_offset + + srcB_trans_offset + + (dimN_reg_block % 4) * 16 + + (dimN_reg_block / 4) * 4))); + } else { + vfmadd231ps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * (srcB_offset + dimN_reg_block), + true)); + } + } + } + } + } + + add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma + * jcp.dimM_simd_block * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + /******** Write C back to memory *******/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + vmovups(zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * sizeof(float)], + zmm); + } + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * 
jcp.dimM_simd_block * sizeof(float)); + add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block + * sizeof(float)); + if (jcp.dimN_block > 1) { + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + } + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimN_block * jcp.dimK_block + * jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + // register used to handle long fma encoding + preamble(); + mov(reg_srcA, reg_srcA_const); + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { +bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw; + lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma + * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block + * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} + +bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block; + lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} +} // namespace + +status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp) +{ + /*************** Choose dimN_reg_block (ic_simd_block) + * *******************************/ + jcp.dimN = jcp.ic; + /*Hardcoded to 16 because N = ic for bwd weights and + innermost dimension for ic is assumed 16 in src transforms. This + choice covers load latencies while maintaining simplicity of kernel + for POR topologies. 
FIXME in future??: Will not work for future topologies + when ic%16 != 0*/ + jcp.dimN_reg_block = jcp.ic_simd_block; + + /****************************** Choose dimK_block + * **************************/ + // No freedom for choosing dimM_simd_block because ic_simd_block + // is determined by input data format + jcp.dimM_simd_block = jcp.oc_simd_block; + + auto test_cond1bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond2bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f) + && (dimK_block > current_best); + }; + + auto test_cond2_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block); + if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block); + + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block); + if (jcp.dimK_reg_block < jcp.dimK_block) { + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1_dimK_block); + } + jcp.dimK_block /= jcp.dimK_reg_block; + jcp.dimK_nb_block + = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + /***************************** Chose dimN block + * ****************************/ + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block, + jcp.dimN_reg_block, 0.5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.ic_block = jcp.dimN_block; + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /********************************* Choose dimM block + * ************************/ + jcp.dimM = jcp.oc; + + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1, + jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block, + 1.0f) + && (dimM_block > current_best) + && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2; + }; + + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + jcp.sched_policy = WSCHED_WEI_S_D_G_W; + return status::success; +} + +status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf( + 
jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) +{ + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (mayiuse(avx512_core)) + return status::unimplemented; + if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /*************************** New Kernel Parameters + * *****************************/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + jcp.dimK_4fma = 1; + jcp.tile_4fma_padding = 0; + +#define MAX_4FMA_UR 8 + if (jcp.ver == ver_4fma) { + auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma, + int current_best) { + return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR) + && (dimK_4fma > current_best); + }; + jcp.dimK_4fma = get_divisor_satisfying_cond( + jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma); + if (jcp.dimK_4fma == 1) + jcp.dimK_4fma = 4; + if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0) + jcp.tile_4fma_padding = jcp.dimK_4fma + - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma); + } + + jcp.tile_4fma = jcp.dimK_4fma; + /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src + * transform + * will not work correctly, this is solved by applying padding.*/ + jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2; + else + jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp); + assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W); + + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + + return res; + +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp new file mode 100644 index 000000000000..6c117143f593 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp @@ -0,0 +1,179 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +//alpha determines the output tile_size +constexpr int alpha = 6; +constexpr int tile_size = 4; +//simd length used for vectorization +constexpr int simd_w = 16; + +struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator { + _jit_avx512_common_conv_winograd_data_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter + = (decltype(gemm_loop_ker_first_iter)) this->getCode(); + + //************** Subsequent iterations kernel **************// + if (jcp.dimK_nb_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + +protected: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_beta_zero); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; +}; + +struct jit_avx512_common_conv_winograd_fwd_kernel_f32 + : _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); +}; + +struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32 + : public _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + + //******************* First iter kernel ********************// + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = 
(decltype(gemm_loop_ker_first_iter))addr; + } + + if (jcp.tile_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + if (jcp.ver == ver_4fma) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->transpose_ker_generate(); + transpose_4fma_ker = (decltype(transpose_4fma_ker))addr; + } + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*transpose_4fma_ker)(float *, float *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_first_tile); + void transpose_ker_generate(); + + reg64_t reg_origB = abi_param2; + reg64_t reg_transB = abi_param1; + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA_const = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_sp = rsp; + reg64_t reg_srcA = r9; + reg64_t reg_nb_ic = r10; + reg64_t reg_loop_cpt = r11; + reg64_t reg_transB_idx = r13; + + /* Registers used by new kernel */ + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r12; + reg64_t reg_dimN_block_loop_cnt = r11; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp new file mode 100644 index 000000000000..abddc19221b5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp @@ -0,0 +1,1526 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define PIPELINE(field) \ + do { \ + p.field = p.field ## _prf; \ + p.field ## _prf = field; \ + } while (0) + +inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int kd_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker, + jit_conv_call_s &p, const void *src, const void *dst, const void *filt, + const void *bias, int channel, int kh_padding, int kd_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int d_index, int d_worksize, + int kd_padding /* kd_work_size */, size_t kd_offset) { + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kd_padding); + PIPELINE(d_worksize); + PIPELINE(d_index); + PIPELINE(kd_offset); + + if (p.src) + ker(&p); +} +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? 
(d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +void jit_avx512_common_convolution_fwd_t::prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const { + if (!pd()->wants_padded_bias()) return; + + auto padded_bias = scratchpad.template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding); + utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, + (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding); + bias = padded_bias; +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + auto bias_w = bias ? 
bias + g_oc : nullptr; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src_w, dst_w, wht_w, bias_w, icb, 1, owb); + + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + auto bias_w = bias ? 
bias + g_oc : nullptr; + + for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) { + int ih_b = -jcp.t_pad + oh_b * jcp.stride_h; + + auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s); + auto src_w + = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s); + auto wht_w + = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); + ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_b, ij = ih_b; + oj < min(oh_e, oh_b + jcp.h_blocking); + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up(max(0, ij - jcp.ih + + (jcp.kh - 1) * dilate_h + 1), dilate_h); + int kh_padding = nstl::max( + 0, jcp.kh - i_t_overflow - i_b_overflow); + + auto aux_src = src_c + + i_t_overflow * dilate_h * src_h_stride; + auto aux_wht = wht_w + i_t_overflow * wht_h_stride; + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, aux_src, dst_c, aux_wht, bias_w, icb, + kh_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_3d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + parallel(0, [&](const int ithr, const int nthr) { + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int start{0}, end{0}, start_copy; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh + * jcp.nb_ow; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_d_stride = src_d.blk_off(0, 0, 1); + size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + 
assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + int id_s = -jcp.f_pad + od_s * jcp.stride_d; + + int dilate_d = jcp.dilate_d + 1; + int d_t_overflow = div_up(max(0, -id_s), dilate_d); + int d_b_overflow = div_up( + max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1), + dilate_d); + int kd_padding = nstl::max(0, + jcp.kd - d_t_overflow - d_b_overflow); + + auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s, + iw_s) + d_t_overflow * dilate_d * src_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2) + + d_t_overflow * wht_d_stride; + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_s, ij = ih_s; + oj < oh_e; ++oj, ij += jcp.stride_h) + { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + + 1), + dilate_h); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, + src_c + i_t_overflow * dilate_h * src_h_stride, + dst_c, wht_w + i_t_overflow * wht_h_stride, + bias_w, icb, kh_padding, kd_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template struct jit_avx512_common_convolution_fwd_t; + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_1d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}; + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + 
nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n, + jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w, diff_dst_w, wht_w, 0, ocb, 1); + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups, + n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_2d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? 
jcp.ih : ih_s + work_rem; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_3d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1); + size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1); + size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = 
wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1; + bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + int d_len = 0, d_lo = 0, d_oj = 0; + if (is_fast_path_d) { // dilate == 0 && stride == 1 + int d_t_overflow = max(0, jcp.kd - 1 - id_s + - jcp.f_pad); + int d_b_overflow = max(0, jcp.kd - jcp.id + id_s + - jcp.back_pad); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow; + } else if (jcp.dilate_d != 0) { // stride == 1 + int dilate_d = jcp.dilate_d + 1; + // Note: use div_up to account for "holes" in filter + int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + - id_s - jcp.f_pad), dilate_d); + int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1 + - jcp.id + id_s - jcp.back_pad), dilate_d); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d; + } else { // dilate == 0 + int d_t_overflow = max(0, (jcp.kd - 1 - id_s + - jcp.f_pad) / jcp.stride_d); + int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s + - jcp.back_pad) / jcp.stride_d); + int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1 + + jcp.back_pad - id_s) % jcp.stride_d); + int overflow_kd_lo = (id_s + jcp.f_pad) + % jcp.stride_d; + + d_len = (overflow_kd_hi - overflow_kd_lo) + / jcp.stride_d + 1 - d_t_overflow + - d_b_overflow; + d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d; + d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d; + } + assert(d_len >= 0); + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb) + + id_s * diff_src_d_stride; + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2) + + d_oj * diff_dst_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb) + + d_lo * wht_d_stride; + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path_h) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int 
i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len, d_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1, 1); + }); +} + +template struct jit_avx512_common_convolution_bwd_data_t; + +template +jit_avx512_common_convolution_bwd_weights_t:: +jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) + , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) +{ + const auto &j = pd()->jcp_; + + nthr_ = j.nthr; + nthr_mb_ = j.nthr_mb; + nthr_g_ = j.nthr_g; + nthr_oc_b_ = j.nthr_oc_b; + nthr_ic_b_ = j.nthr_ic_b; + + kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j); + + if (j.ver == ver_4fma) + trans_kernel_ = create_trans_src(&j); + + if (nthr_mb_ > 1) + acc_ker_ = new cpu_accumulator_1d_t(); + + reducer_bias_ = + new cpu_reducer_t(pd()->reducer_bia_conf_); +} + +template +struct jit_avx512_common_convolution_bwd_weights_t::thread_info_t { + const src_data_t *src; + const diff_dst_data_t *diff_dst; + const diff_weights_data_t *diff_weights; + diff_weights_data_t *diff_bias; + + const memory_tracking::grantor_t scratchpad; + + src_data_t *tr_src; + simple_barrier::ctx_t *tr_src_bctx; + + diff_dst_data_t *tr_diff_dst; + simple_barrier::ctx_t *tr_diff_dst_bctx; + + diff_weights_data_t *wei_bia_reduction; + simple_barrier::ctx_t *wei_bia_reduction_bctx; + + int ithr; + int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb; + int ithr_but_oc; + int ithr_but_ic; + + int img_start = 0, img_end = 0, img_work; + int g_start = 0, g_end = 0, g_work; + int oc_b_start = 0, oc_b_end = 0, oc_b_work; + int ic_b_start = 0, ic_b_end = 0, ic_b_work; + + thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self, + const exec_ctx_t &ctx, int ithr) + : scratchpad(self->scratchpad(ctx)), ithr(ithr) + { + diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + diff_bias = self->pd()->wants_padded_bias() + ? 
scratchpad.template get( + key_conv_padded_bias) + : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + + tr_src = scratchpad.template get(key_conv_tr_src); + tr_src_bctx = scratchpad.template get( + key_conv_tr_src_bctx); + + tr_diff_dst = scratchpad.template get( + key_conv_tr_diff_dst); + tr_diff_dst_bctx = scratchpad.template get( + key_conv_tr_diff_dst_bctx); + + wei_bia_reduction = scratchpad.template get( + key_conv_wei_bia_reduction); + wei_bia_reduction_bctx = scratchpad.template get( + key_conv_wei_bia_reduction_bctx); + + ithr_ic_b = ithr % self->nthr_ic_b_; + ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_; + ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_; + ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_; + + ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_ + + ithr_ic_b; + + ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_ + + ithr_oc_b; + + const auto &jcp = self->kernel_->jcp; + + /* reduction dimension */ + balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end); + img_work = img_end - img_start; + + /* independent dimensions */ + balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end); + g_work = g_end - g_start; + + balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, + oc_b_end); + oc_b_work = oc_b_end - oc_b_start; + + balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, + ic_b_end); + ic_b_work = ic_b_end - ic_b_start; + } +}; + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? 
(diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + // TODO: use memory descriptor with the same fmt as src (or use a macro :)) + auto tr_src_off = [&](int ithr_mb, int ic, int ij) { + const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; + const size_t tr_chn_size = tr_row_size * jcp.ih; + const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups; + + return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size; + }; + + auto uker_trans = [&](int img) { + const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih; + + int start{0}, end{0}; + balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); + const int my_work = end - start; + + int g{0}, ic_b{0}, j{0}; + nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih); + g += ti->g_start; + ic_b += ti->ic_b_start; + + const int _ic = g * jcp.nb_ic + ic_b; + src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)]; + src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.iw * jcp.ic_block; + const int tr_src_stride = jcp.tr_iw * jcp.ic_block; + + const int pf_depth = 2; + struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth]; + + for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) { + pf_circ_buf[iwork % pf_depth] = {src1, tr_src1}; + + if (iwork >= pf_depth - 1) { + int old_idx = (iwork - pf_depth + 1) % pf_depth; + auto ctx = jit_trans_src_t::ctx_t(); + ctx.src = pf_circ_buf[old_idx].src; + ctx.tr_src = pf_circ_buf[old_idx].tr_src; + ctx.src_prf = src1; + ctx.tr_src_prf = tr_src1; + (*trans_kernel_)(&ctx); + } + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#if 0 + // reference transposition + const int l_pad = jcp.l_pad; + const int iwlp = l_pad + jcp.iw; + const int tr_iw = jcp.tr_iw; + + for (size_t iwork = start; iwork < end; iwork++) { + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = 0; i < l_pad; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = l_pad; i < iwlp; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] + = (src_data_t)src1[(i - l_pad) * 16 + j]; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = iwlp; i < tr_iw; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#endif + }; + + if (jcp.is_1stconv && jcp.ver == ver_4fma) { + /* prepare contexts */ + auto tr_ctx = jit_trans_src_t::ctx_t(); + tr_ctx.tr_src = ti->tr_src + + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld; + + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1)); + tr_ctx.nthr_oc_b = nthr_oc_b_; + int ih_start{0}, ih_end{0}; + balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end); + tr_ctx.tr_src_ih_start = ih_start; + tr_ctx.tr_src_ih_end = ih_end; + tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc; + + auto p = jit_conv_call_s(); + p.src = tr_ctx.tr_src; + + /* zero diff_bias if applicable */ + if (jcp.with_bias && ti->ithr_ic_b == 0) { + assert(jcp.oc_block == 16); + for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) { + diff_weights_data_t *db = &diff_bia[oc_b * 16]; + for (int o = 0; o < 16; ++o) + db[o] = 0; + } + } + + for (int img = ti->img_start; img < ti->img_end; ++img) { + p.flags = (img == ti->img_start) * FLAG_MB_FIRST; + + for (int g = ti->g_start; g < 
ti->g_end; ++g) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _ic = g * jcp.nb_ic + ic_b; + tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)]; + + (*trans_kernel_)(&tr_ctx); + + if (ic_b == 0) + p.flags |= FLAG_IC_FIRST; + else + p.flags &= ~FLAG_IC_FIRST; + + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + const int _oc = g * jcp.nb_oc + oc_b; + p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + + const size_t off = + wht_blk_off(diff_weights_d, g, oc_b, ic_b); + p.filt = diff_wei + off; + p.bias = diff_bia + _oc * jcp.oc_block; + + kernel_->jit_ker(&p); + } + } + } + } + } else { + for (int img = ti->img_start; img < ti->img_end; ++img) { + auto p = jit_conv_call_s(); + + if (jcp.ver == ver_4fma) { + /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ + using simple_barrier::barrier; + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + uker_trans(img); + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + } + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img, _oc)], + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + 0, (img == ti->img_start), 0); + + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off( + diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + 0, 0, 0); + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights_3d(const thread_info_t *ti) const +{ + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size + = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + const int input_step = jcp.ih * jcp.iw * inp_mult; + const int output_step = jcp.ow * jcp.oh * jcp.oc_block; + int img{0}, od_s{0}; + int img_start = ti->img_start, img_end = ti->img_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + while (img_start < img_end) { + auto p = jit_conv_call_s(); + + int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? 
jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int ik_overlap = nstl::max(0, id_s - jcp.f_pad); + const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s); + const int kd_back_pad + = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd); + int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw + * jcp.ic_block * jcp.oc_block * jcp.typesize_out; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + auto src = &ti->src[src_d.blk_off(img, _ic) + + ik_overlap * input_step]; + auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc) + + od_s * output_step]; + + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst, + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + diff_bia + _oc * 16, (img == img_first), od_s, od_e, + jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off); + + if (ic_b == 0) p.flags = 0; + else p.flags = 1; + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, + &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off(diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + diff_bia, 0, 0, 0, 0, 0); + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kh; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh; + const int kh = sub_ic_b_kh_start % jcp.kh; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh); + + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) { + if (ti->ithr == 0) + acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias, + diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + 
+template +void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights_3d(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw + * jcp.kd; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kd; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd; + const int kd = sub_ic_b_kh_start % jcp.kd; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd); + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias(const thread_info_t *ti) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + auto rb = this->reducer_bias_; + assert(nthr_ == rb->balancer().nthr_); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t( + ti->scratchpad, prefix_reducer_bia); + + const auto &jcp = kernel_->jcp; + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return; + + const int b_job_start = rb->balancer().ithr_job_off(ti->ithr); + const int b_njobs = rb->balancer().ithr_njobs(ti->ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ti->ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc); + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const diff_dst_data_t *d_dst + = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr, + ti->diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0; + for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + + rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad); +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias_3d(const thread_info_t *ti) const { + + const auto &jcp = kernel_->jcp; + + const size_t wei_size = 
(size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh + * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size; + + if (nthr_mb_ > 1) mkldnn_thr_barrier(); + + if (ti->ithr == 0) + { + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::prepare_scratchpad_data(const exec_ctx_t &ctx) const +{ + const auto &j = pd()->jcp_; + auto scratchpad = this->scratchpad(ctx); + + if (j.ver == ver_4fma) { + if (!j.is_1stconv) { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic; + const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw; + + auto tr_src = scratchpad.template get(key_conv_tr_src); + /* to avoid NaNs in computations we zero tail num_guard_elems for + * each possible thread group */ + + for (int ithr = 1; ithr <= max_nthr; ++ithr) { + src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr]; + for (int i = 0; i < j.tr_src_num_guard_elems; ++i) + ts[i] = 0; + } + } + + if (j.nthr_oc_b > 1) { + const int tr_src_bctx_size = j.nthr / j.nthr_oc_b; + auto tr_src_bctx = scratchpad.template get( + key_conv_tr_src_bctx); + for (int i = 0; i < tr_src_bctx_size; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + } + + if (nthr_mb_ > 1) { + simple_barrier::ctx_init(scratchpad.template get( + key_conv_wei_bia_reduction_bctx)); + } + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); +} + +template +void jit_avx512_common_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { + prepare_scratchpad_data(ctx); + + parallel(nthr_, [&](const int ithr, const int nthr) { + assert(nthr_ == nthr); + + thread_info_t thread_info(this, ctx, ithr); + + if (utils::one_of(pd()->ndims(), 3, 4)) { + compute_diff_weights(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights(&thread_info); + if (pd()->with_bias()) compute_diff_bias(&thread_info); + } else if (pd()->ndims() == 5) { + compute_diff_weights_3d(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info); + if (pd()->with_bias()) compute_diff_bias_3d(&thread_info); + } else { + assert(false); + } + }); + + /* TODO: put that into compute_diff_bias() */ + if (pd()->wants_padded_bias()) { + auto diff_bias = scratchpad(ctx).template get( + key_conv_padded_bias); + auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +template struct jit_avx512_common_convolution_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp new file mode 100644 index 000000000000..3341c3ebe08d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance 
with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_transpose_src_utils.hpp" +#include "jit_avx512_common_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad, + jcp_); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_common_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + ~jit_avx512_common_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_forward_1d(ctx); + else if (pd()->ndims() == 4) + execute_forward_2d(ctx); + else if (pd()->ndims() == 5) + execute_forward_3d(ctx); + else + assert(false); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); + + return status::success; + } + +private: + void prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_fwd_kernel *kernel_; +}; + +template +struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) 
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, + data_type::undef, diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_, + *desc(), *diff_src_md(), *weights_md(), *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i, + OIdhw16o16i, gOIdhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx512_common_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_backward_data_1d(ctx); + else if (pd()->ndims() == 4) + execute_backward_data_2d(ctx); + else if (pd()->ndims() == 5) + execute_backward_data_3d(ctx); + else + assert(false); + return status::success; + } + +private: + void execute_backward_data_1d(const exec_ctx_t &ctx) const; + void execute_backward_data_2d(const exec_ctx_t &ctx) const; + void execute_backward_data_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_; +}; + +template +struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_weights_type, + diff_weights_type, diff_dst_type, data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), src_md_, diff_weights_md_, + diff_bias_md_, diff_dst_md_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + auto reducer_bia_scratchpad = 
memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + return status; + } + + jit_conv_conf_t jcp_; + typename cpu_reducer_t::conf_t reducer_bia_conf_; + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + } + }; + + jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd); + ~jit_avx512_common_convolution_bwd_weights_t() { + delete kernel_; + if (trans_kernel_) + delete trans_kernel_; + if (acc_ker_) + delete acc_ker_; + delete reducer_bias_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type diff_weights_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + void prepare_scratchpad_data(const exec_ctx_t &ctx) const; + struct thread_info_t; + void compute_diff_weights(const thread_info_t *) const; + void compute_diff_weights_3d(const thread_info_t *) const; + void reduce_diff_weights(const thread_info_t *) const; + void reduce_diff_weights_3d(const thread_info_t *) const; + void compute_diff_bias(const thread_info_t *) const; + void compute_diff_bias_3d(const thread_info_t *) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_; + + jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_; + jit_trans_src_t *trans_kernel_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp new file mode 100644 index 000000000000..62247c0264db --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp @@ -0,0 +1,1215 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution_winograd.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +namespace { + +unsigned int LLC_cache_size = get_cache_size(3, false); + +void inline load_ps(float *dest, const float *src_mem) { +#ifdef __INTEL_COMPILER + __m512 *Iv512 = (__m512 *)dest; + Iv512[0] = _mm512_load_ps(src_mem); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v]; +#endif +} + +void inline store_output(float *dest, const float *data, bool streamout) { +#ifdef __INTEL_COMPILER + if (streamout) + _mm512_stream_ps(dest, *((__m512 *)data)); + else + _mm512_store_ps(dest, *((__m512 *)data)); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} + +void inline accum_output( + float *dest, float *data, bool streamout, bool with_relu_postsum) { +#ifdef __INTEL_COMPILER + __m512 _data = _mm512_loadu_ps(data); + __m512 _dest = _mm512_loadu_ps(dest); + _data = _mm512_add_ps(_data, _dest); + if (with_relu_postsum) + _data = _mm512_max_ps(_data, _mm512_setzero_ps()); + if (streamout) + _mm512_stream_ps(dest, _data); + else + _mm512_store_ps(dest, _data); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + data[v] += dest[v]; + + if (with_relu_postsum) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + if (data[v] < 0.f) + data[v] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} +} + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) { + float Fw[6][16]; + float T[6][3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +#pragma unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * F[2][i][j][k]; + t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k]; + t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k]; + + T[0][i][k] = 1.13777777777778f * F[0][i][j][k]; + T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k]; + T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k]; + T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k]; + T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k]; + T[5][i][k] = F[2][i][j][k]; + } + } +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * T[i][2][k]; + t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k]; + t2[k] = t0[k] + 0.119514472455649f * T[i][0][k]; + + Fw[0][k] = 1.13777777777778f * T[i][0][k]; + Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k]; + Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k]; + Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k]; + Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k]; + Fw[5][k] = T[i][2][k]; +#pragma unroll + for (int l = 0; l < 6; l++) { + Fw_[i][l][j][k] = Fw[l][k]; + } + } + } + } +} + +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) { + float T[4][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + +#pragma unroll + for (int i = 0; i < 6; 
i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][v] + Mw[2][i][v]; + t1[v] = Mw[3][i][v] + Mw[4][i][v]; + t2[v] = Mw[1][i][v] - Mw[2][i][v]; + t3[v] = Mw[3][i][v] - Mw[4][i][v]; + + T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v]; + T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f; + T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v]; + } + } +#pragma unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = T[i][1][v] - T[i][2][v]; + t3[v] = T[i][3][v] - T[i][4][v]; + + O[i][0][v] = t0[v] + t1[v] + T[i][0][v]; + O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f; + O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v]; + } + } +} + + +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]) +{ + const float rcp3 = 1.0f / 3.0f; + const float rcp4 = 1.0f / 4.0f; + const float rcp6 = 1.0f / 6.0f; + const float rcp12 = 1.0f / 12.0f; + const float rcp24 = 1.0f / 24.0f; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float T[6][4][16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = F[2][i][j] * rcp6; + t1[j] = F[0][i][j] * -rcp6 - t0[j]; + t2[j] = F[0][i][j] * rcp24 + t0[j]; + t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6; + t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3; + + T[0][i][j] = F[0][i][j] * rcp4; + T[1][i][j] = t1[j] - t3[j]; + T[2][i][j] = t1[j] + t3[j]; + T[3][i][j] = t2[j] + t4[j]; + T[4][i][j] = t2[j] - t4[j]; + T[5][i][j] = F[3][i][j]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = T[i][2][j] * rcp6; + t1[j] = T[i][0][j] * -rcp6 - t0[j]; + t2[j] = T[i][0][j] * rcp24 + t0[j]; + t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6; + t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3; + + Fw[i][0][j] = T[i][0][j] * rcp4; + Fw[i][1][j] = t1[j] - t3[j]; + Fw[i][2][j] = t1[j] + t3[j]; + Fw[i][3][j] = t2[j] + t4[j]; + Fw[i][4][j] = t2[j] - t4[j]; + Fw[i][5][j] = T[i][3][j]; + } + } +} + +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[4][6][16]; + float M_[3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l]; + t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l]; + t2[l] = t1[l] * 4.0f + Mw[5][i][j][l]; + + T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l]; + T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) + + 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]); + T[2][i][l] = t0[l] + t2[l]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = T[i][1][l] + T[i][2][l]; + t1[l] = T[i][3][l] + T[i][4][l]; + t2[l] = t1[l] * 4.0f + T[i][5][l]; + + M_[0][l] = T[i][0][l] + t0[l] + t1[l]; + M_[1][l] = (T[i][1][l] - T[i][2][l]) + + 2.0f * (T[i][3][l] - T[i][4][l]); + M_[2][l] = t0[l] + t2[l]; + + for (int k = 0; k < 3; k++) { + M[i][k][j][l] = M_[k][l]; + } + } + } + } +} + +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]) +{ + float T[6][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float t5[16]; + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] 
= I[2][i][v] * -2.25f + I[4][i][v]; + t1[v] = I[1][i][v] * -2.25f + I[3][i][v]; + t2[v] = I[2][i][v] * -0.390625f + I[4][i][v]; + t3[v] = I[1][i][v] * -0.390625f + I[3][i][v]; + t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v]; + t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v]; + + T[0][i][v] = I[2][i][v] * -2.640625f + t4[v]; + T[1][i][v] = t1[v] * 0.625f + t0[v]; + T[2][i][v] = t1[v] * -0.625f + t0[v]; + T[3][i][v] = t3[v] * 1.5f + t2[v]; + T[4][i][v] = t3[v] * -1.5f + t2[v]; + T[5][i][v] = I[3][i][v] * -2.640625f + t5[v]; + } + } + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * -2.25f + T[i][4][v]; + t1[v] = T[i][1][v] * -2.25f + T[i][3][v]; + t2[v] = T[i][2][v] * -0.390625f + T[i][4][v]; + t3[v] = T[i][1][v] * -0.390625f + T[i][3][v]; + t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v]; + t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v]; + + Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v]; + Iw[i][1][v] = t1[v] * 0.625f + t0[v]; + Iw[i][2][v] = t1[v] * -0.625f + t0[v]; + Iw[i][3][v] = t3[v] * 1.5f + t2[v]; + Iw[i][4][v] = t3[v] * -1.5f + t2[v]; + Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v]; + } + } +} + +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]) +{ + float T[6][4][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = F[2][i][v] * 0.26890756302521f; + t1[v] = F[0][i][v] * -0.688403361344538f - t0[v]; + t2[v] = F[0][i][v] * 0.119514472455649f + t0[v]; + t3[v] = F[1][i][v] * 0.430252100840336f + + F[3][i][v] * 0.168067226890756f; + t4[v] = F[1][i][v] * 0.179271708683473f + + F[3][i][v] * 0.403361344537815f; + + T[0][i][v] = F[0][i][v] * 1.13777777777778f; + T[1][i][v] = t1[v] - t3[v]; + T[2][i][v] = t1[v] + t3[v]; + T[3][i][v] = t2[v] + t4[v]; + T[4][i][v] = t2[v] - t4[v]; + T[5][i][v] = F[3][i][v]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * 0.26890756302521f; + t1[v] = T[i][0][v] * -0.688403361344538f - t0[v]; + t2[v] = T[i][0][v] * 0.119514472455649f + t0[v]; + t3[v] = T[i][1][v] * 0.430252100840336f + + T[i][3][v] * 0.168067226890756f; + t4[v] = T[i][1][v] * 0.179271708683473f + + T[i][3][v] * 0.403361344537815f; + + Fw[i][0][v] = T[i][0][v] * 1.13777777777778f; + Fw[i][1][v] = t1[v] - t3[v]; + Fw[i][2][v] = t1[v] + t3[v]; + Fw[i][3][v] = t2[v] + t4[v]; + Fw[i][4][v] = t2[v] - t4[v]; + Fw[i][5][v] = T[i][3][v]; + } + } +} + +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[3][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float M_[3][16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v]; + t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v]; + t2[v] = t1[v] * 2.25f + Mw[5][i][j][v]; + + T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v]; + T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) + + 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]); + T[2][i][v] = t0[v] * 0.390625f + t2[v]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = t1[v] * 2.25f + T[i][5][v]; + + M_[0][v] = T[i][0][v] + t0[v] + t1[v]; + M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) + + 1.5f * (T[i][3][v] - T[i][4][v]); + M_[2][v] = t0[v] * 
0.390625f + t2[v]; + } + +pragma_unroll + for (int k = 0; k < 3; k++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + M[i][k][j][v] = M_[k][v]; + } + } + } + } +} + +template +void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp, bool streamout = true) +{ + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + const int wp_max = inpw + l_pad; + const int hp_max = inph + t_pad; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + + array_offset_calculator input(inp, + jcp.mb, jcp.dimK/simd_w, inph, inpw, + simd_w); + array_offset_calculator output(tinp, + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((t_pad <= ydim) && (ydim < hp_max)) { + float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ; + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((l_pad <= xdim) && (xdim < wp_max)) { + float *pinp_i = pinp_j + (xdim - l_pad) * 16; + load_ps(I[j][i], pinp_i); + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_I_4x4_3x3(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(tile_block, j, i, + nb_tile_block_ur, 0, 0, + tile_block_ur, 0)), + Iw[j][i], streamout); + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + array_offset_calculator input(wp, + jcp.oc/jcp.oc_simd_block, + jcp.ic/jcp.ic_simd_block, + jcp.kh, jcp.kw, + simd_w, simd_w); + array_offset_calculator output(twp, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + simd_w, simd_w); + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + float *base_inp = is_fwd + ? 
&(input(0, 0, j, i, v1, 0)) + : &(input(0, 0, 2 - j, 2 - i, v1, 0)); + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + if (is_fwd) + F[j][i][v1][v2] = *(base_inp + v2); + else + F[j][i][v2][v1] = *(base_inp + v2); + } + } + } + } + + trans_W_4x4_3x3(Fw, F); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2]; + } + } + } + } +} + +template +void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias, + bool streamout = true) { + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + /* Prepare for PostOps */ + bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1; + + array_offset_calculator input(toutp, + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Ow[j][i][v] = input(tile_block, 0, + j, i, + nb_tile_block_ur, 0, + tile_block_ur, v); + } + } + } + + trans_O_4x4_3x3(Ow, O); + + for (int j = 0; j < tile_size; j++) { + int ydim = tj * tile_size + j; + if (ydim < outh) { + float *pout_j = pout_b + ydim * outw * simd_w; + for (int i = 0; i < tile_size; i++) { + int xdim = ti * tile_size + i; + if (xdim < outw) { + float *pout_i = pout_j + xdim * simd_w; + if (is_fwd) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + O[j][i][v] += with_bias ? bias[v] : 0.f; + O[j][i][v] = true + && with_relu_presum && O[j][i][v] < 0.f + ? 
O[j][i][v] + * jcp.eltwise.alpha + : O[j][i][v]; + } + } + if (with_sum) + accum_output(pout_i, O[j][i], streamout, + with_relu_postsum); + else + store_output(pout_i, O[j][i], streamout); + } + } + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *Iw_temp, + void (*transpose_4fma_ker)(float *, float *)) +{ + + const int ifwp = conv.iw + conv.l_pad; + const int ifhp = conv.ih + conv.t_pad; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator Iw_trans_temp(Iw_temp, + alpha, alpha, conv.tile_4fma, simd_w); + array_offset_calculator input(inp, + conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w); + array_offset_calculator output(tinp, + conv.nb_ic, alpha, alpha, + conv.tile_block, conv.ic_block, + conv.nb_tile_block_ur, conv.tile_block_ur, + conv.ic_simd_block * conv.tile_4fma); + + int tile_base_index = + image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding); + int tile_4fma = 0; + int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((conv.t_pad <= ydim) && ydim < ifhp) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((conv.l_pad <= xdim) && xdim < ifwp) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input(0, 0, + ydim - conv.t_pad, + xdim - conv.l_pad, v); + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + trans_I_4x4_3x3(Iw, I); + + if (ver_4fma) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, + tile_4fma, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = Iw[j][i][v]; + } + } + } + tile_4fma++; + if (tile_4fma == conv.tile_4fma) { + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + tile_4fma = 0; + tile_block_ur++; + } + } else { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + } + + if (tile_block_ur == conv.tile_block_ur) { + tile_block_ur = 0; + ++nb_tile_block_ur; + } + if (nb_tile_block_ur == conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } + + if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) { + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = 0; + } + } + } + } + float *outp = &(output(0, 0, 0, + 
tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + } +} + +template +void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *dbias) +{ + + const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator input(inp, + conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block); + array_offset_calculator output(tinp, + conv.nb_oc, alpha, alpha, + conv.tile_block, conv.oc_block, + conv.nb_tile_block_ur, + conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block); + + int tile_base_index = image * total_tiles; + int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma); + int nb_tile_block_ur = + (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if (ydim < conv.oh) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if (xdim < conv.ow) { + float *input_base = &(input(0, 0, ydim, xdim, 0)); + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input_base[v]; + } + if (with_bias && j < tile_size && i < tile_size) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + dbias[v] += input_base[v]; + } + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_W_3x3_4x4_wu(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, + tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + array_offset_calculator input(twp, + conv.nb_ic, conv.nb_oc, + alpha, alpha, + conv.oc_block, conv.ic_block, + conv.ic_simd_block, conv.oc_simd_block); + array_offset_calculator output(wp, + conv.oc/simd_w, conv.ic/simd_w, + conv.kh, conv.kw, + conv.ic_simd_block, conv.oc_simd_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < conv.oc_simd_block; k++) { + Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k); + } + } + } + } + + trans_O_3x3_4x4_wu(Fw, F); + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + store_output(&(output(0, 0, j, i, v, 0)), + F[j][i][v], true); + } + } + } +} + +template +void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = 
attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Note that jcp.with_eltwise is true for both fused conv+relu primitive + * and conv primitive with PostOps with relu before sum + * (PostOps relu after sum is handled later) */ + auto output_transform = jcp.with_bias + ? (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data + : output_transform_data) + : (jcp.with_sum + ? output_transform_data + : output_transform_data)) + : (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data + : output_transform_data) + : (jcp.with_sum + ? output_transform_data + : output_transform_data)); + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator U( + scratchpad.template get(key_wino_U), + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_cache_size ? true : false; + + const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0; + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout); + }); + + parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block, + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? 
&(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr); + }); + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block, + [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) { + + kernel_->gemm_loop_ker_first_iter( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, 0, 0, 0, 0))); + for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, + 0, 0, 0))); + } + + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block, + [&](int img, int M_blk1, int M_blk2) { + + const int M_blk = M_blk1 * jcp.dimM_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), + bias_ptr, output_is_aligned); + + }); + + } +} + +template struct _jit_avx512_common_convolution_winograd_t; +template struct _jit_avx512_common_convolution_winograd_t; + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias[oc] = padded_bias[oc]; + } +} + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const { + auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ? + diff_src_transform_bwd_weights : + diff_src_transform_bwd_weights; + auto diff_dst_transform_bwd_weights_ver = jcp.with_bias + ? diff_dst_transform_bwd_weights + : diff_dst_transform_bwd_weights; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights(ptr_diff_weights, + jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator diff_bias(pd()->wants_padded_bias() + ? 
scratchpad.get(key_conv_padded_bias) : ptr_diff_bias, + jcp.oc/simd_w, simd_w); + + array_offset_calculator U( + scratchpad.get(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, jcp.oc_simd_block); + + array_offset_calculator M( + scratchpad.get(key_wino_M), + jcp.nb_oc, alpha, alpha, + jcp.tile_block, jcp.oc_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma, + jcp.oc_simd_block); + array_offset_calculator V( + scratchpad.get(key_wino_V), + jcp.nb_ic, alpha, alpha, + jcp.tile_block, jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block * jcp.tile_4fma); + + const int trans_buffer_size = alpha * alpha * jcp.tile_4fma + * jcp.ic_simd_block; + array_offset_calculator trans_buffer( + scratchpad.get(key_conv_tr_src), + nthreads, + trans_buffer_size); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), + nthreads, + jcp.oc); + +PRAGMA_OMP(parallel num_threads(nthreads)) + { + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + +PRAGMA_OMP(for nowait) + for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + diff_bias(bofm, v) = 0.0f; + } + } + + const int ithread = mkldnn_get_thread_num(); + + parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block, + [&](int img, int ifm1, int ifm2) { + float *transb = jcp.ver == ver_4fma + ? &(trans_buffer(ithread, 0)) + : NULL; + diff_src_transform_bwd_weights_ver(img, jcp, + &(src(img, ifm1 * jcp.ic_block + ifm2, + 0, 0, 0)), + &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)), + transb, + kernel_->transpose_4fma_ker); + }); + + parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block, + [&](int img, int ofm1, int ofm2) { + float *dbias = jcp.with_bias + ? 
&(diff_bias_prv(ithread, + simd_w * (ofm1 * jcp.oc_block + ofm2))) + : NULL; + diff_dst_transform_bwd_weights_ver(img, jcp, + &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, + 0, 0, 0)), + &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), + dbias); + }); + +PRAGMA_OMP(barrier) + + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) { + parallel_nd_in_omp(alpha, alpha, jcp.nb_oc, + [&](int oj, int oi, int ofm1) { + kernel_->gemm_loop_ker_first_iter( + (float *)&(U(ifm1, ofm1, oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, + 0, 0, 0, 0, 0))); + for (int tile_block = 1; tile_block < jcp.tile_block; + tile_block++) { + kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1, + oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, tile_block, + 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, tile_block, + 0, 0, 0, 0))); + } + }); + } + +PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2) { + diff_weights_transform_bwd_weights(jcp, + &(diff_weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), + &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0))); + }); + + if (jcp.with_bias) { +PRAGMA_OMP(for) + for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) { + for (int ithr = 0; ithr < nthreads; ithr++) { + float* base_bias_ptr = &(diff_bias(ofm1, 0)); + float* base_bias_prv_ptr = &(diff_bias_prv( + ithr * jcp.oc + ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < simd_w; ofm2++) { + base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2]; + } + } + } + } + } + + _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad); +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp new file mode 100644 index 000000000000..6c76f37c727d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp @@ -0,0 +1,318 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_common { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { + const int nthr = mkldnn_get_max_threads(); + + size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr + * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; + scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); + + size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + + size_t padded_bias_sz = + jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; + scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); + } +} +} + +template +struct _jit_avx512_common_convolution_winograd_t { + _jit_avx512_common_convolution_winograd_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); + } + + ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } + + protected: + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_common_convolution_winograd_fwd_t + : _jit_avx512_common_convolution_winograd_t + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + 
winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_fwd_t(){}; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_data_t + : _jit_avx512_common_convolution_winograd_t, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? 
gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, nullptr, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, + hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? 
gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true), kernel_(nullptr) + { + kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + pd()->jcp_); + } + + ~jit_avx512_common_convolution_winograd_bwd_weights_t() + { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); + return status::success; + } + +private: + void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const; + void _maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; +}; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]); +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]); +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]); +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]); +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]); + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp new file mode 100644 index 000000000000..d4a451c02134 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp @@ -0,0 +1,853 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_lrn.hpp" + +#include "jit_generator.hpp" + +#define FWD_RBC 4 +#define BWD_RBC 3 + +#define XMM_SIZE (4*sizeof(float)) +#define ZMM_SIZE (vlen) +#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE) +#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE) +#define SRC_PREV_OFFSET (vlen - XMM_SIZE) + +#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \ + statement;\ +} + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +enum params { vsize = 16, vlen = 64}; + +typedef struct { + const float *src; + float *dst, *ws0, *ws1; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *ws0, *ws1; + float *diff_src; +} jit_args_bwd_t; + +struct nChw16c_across { +/* version: + * -1: channels 0..15, + * 1: channels C-16 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 dst = r8; + Reg64 scratch0 = rdx; + Reg64 scratch1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm zalpha = zmm0; + Xmm xalpha = xmm0; + Zmm zk = zmm1; + Xmm xk = xmm1; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r9; + + int xsrc_prev = 2; + int zsrc = 7; + int xsrc_next = 3; + int zc = 7; + + int za = 2; + int zb = 3; + int zd = 5; + int ze = 6; + int zsum = 4; + int zdst = 2; + int zbase = 3; + int zsum2 = 5; + + prop_kind_t pk; + int use_h_parallelism; + + float alpha, k; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_fwd_t *); + void operator()(jit_args_fwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*FWD_RBC, + prf2_offt = 8*FWD_RBC + }; + + inline void compute_loop(int loop_size_param) + { + // loop_size - param for IRB_LOOP macro + int loop_size = FWD_RBC; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*3 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*7 + i); + }; + + if (!is_first && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen])); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen])); + } + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0, + (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0, + (irb + prf2_offt)*vlen))); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen))); + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1, + (irb + prf0_offt) * vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1, + (irb + prf2_offt) * vlen))); + } + + loop_size = loop_size_param; + 
if (loop_size == 0) + return; + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_prev), + ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET])); + } + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_next), + ptr[src + (irb + HW) * vlen])); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xsrc_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xsrc_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + + assert(zc == zsrc); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc))); + + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze))); + + IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha)); + + IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum))); + + IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2))); + + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + + if (pk != prop_kind::forward_inference) { + IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen), + zreg(irb, zsum))); + } + IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum))); + IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst))); + if (pk != prop_kind::forward_inference) { + /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */ + IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase))); + IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen), + zreg(irb, zsum))); + } + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + prop_kind_t prop_kind, + int use_h_parallel, + float A, + float K, + void *code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , pk(prop_kind) + , use_h_parallelism(use_h_parallel) + , alpha(A) + , k(K) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(dst, ptr[param + 8]); + if (pk != prop_kind::forward_inference) + { + mov(scratch0, ptr[param + 16]); + mov(scratch1, ptr[param + 24]); + } + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == -2; + is_single = J.version == 3; + + W = J.W; + HW = J.W*J.H; + int LSB = use_h_parallelism ? 
W : HW; + + sub(t, FWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(zalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(zk, xk); + + if (is_first || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2); + } + } + if (is_last || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xmm2); + } + } + + int LSREST = LSB % FWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(FWD_RBC); + + add(src, FWD_RBC*vlen); + add(dst, FWD_RBC*vlen); + if (pk != prop_kind::forward_inference) + { + add(scratch0, FWD_RBC*vlen); + add(scratch1, FWD_RBC*vlen); + } + + for(int irb = 0; irb < FWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST); + + add(t, FWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +status_t jit_avx512_common_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, data_d.data_type()) + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc()->prop_kind == forward_training) { + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + } + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float k = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + + use_h_parallelism = H > 28 ? 
1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk, + use_h_parallelism, alpha, k); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk, + use_h_parallelism, alpha, k); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k); + } +} + +jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, c16, C16, h, H); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, c16, C16, h, H); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 diffsrc = r8; + Reg64 diffdst = r9; + Reg64 workspace0 = rdx; + Reg64 workspace1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm znalphabeta = zmm0; + Xmm xnalphabeta = xmm0; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r10; + + int xws1_prev = 1; + int xdiffdst_prev = 2; + int zws1 = 1; + + int zsrc = 1; + int zdiffdst = 5; + int zdiffsrc = 6; + + int xws1_next = 1; + int xdiffdst_next = 3; + + int za = 1; + int zb = 2; + int zd = 3; + int ze = 4; + int zws0 = 2; + + float nalphabeta; + + int use_h_parallelism; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_bwd_t *); + void operator()(jit_args_bwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*BWD_RBC, + prf2_offt = 8*BWD_RBC + }; + + inline void compute_loop(int loop_size_param, int prefetchL1, + int prefetchL2) + { + // loop_size - param for IRB_LOOP macro + int 
loop_size = loop_size_param; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*6 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*6 + i); + }; + +// ---- prefetching ------------------------------------------- + if (!is_first && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + - 2 * HW) * vlen])); + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + - HW) * vlen])); + } + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen])); + + if (!is_last && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + + 2 * HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt + + 2 * HW) * vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + + HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt + + HW) * vlen])); + } + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen])); +// ----------------------------------------------------------- + + if (loop_size_param == 0) + return; + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb + - 2 * HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb + - HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev), + xreg(irb, xws1_prev))); + } + + IRB_LOOP(vmovups(zreg(irb, zws1), + EVEX_compress_addr(workspace1, irb*vlen))); + IRB_LOOP(vmovups(zreg(irb, zdiffdst), + EVEX_compress_addr(diffdst, irb*vlen))); + IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst), + zreg(irb, zws1))); + + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb + + 2 * HW) * vlen])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb + + HW) * vlen])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next), + xreg(irb, xws1_next))); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xdiffdst_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zdiffsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xdiffdst_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, za))); + assert(zsrc == za); + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zb))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zd))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), 
zreg(irb, zdiffsrc), + zreg(irb, ze))); + IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta)); + + IRB_LOOP(vmovups(zreg(irb, zws0), + EVEX_compress_addr(workspace0, irb*vlen))); + IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst), + zreg(irb, zws0))); + IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc), + zreg(irb, zdiffdst))); + + Label unaligned_store, end_store; + test(diffsrc, vlen - 1); + jnz(unaligned_store, T_NEAR); + IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + jmp(end_store, T_NEAR); + L(unaligned_store); { + IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + } + L(end_store); + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2*A*B) + , use_h_parallelism(use_h_parallel) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(diffdst, ptr[param + 8]); + mov(workspace0, ptr[param + 16]); + mov(workspace1, ptr[param + 24]); + mov(diffsrc, ptr[param + 32]); + + W = J.W; + HW = J.H*J.W; + int LSB = this->use_h_parallelism ? W : HW; + + sub(t, BWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(znalphabeta, xnalphabeta); + + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == +2; + is_single = J.version == 3; + + if (is_first || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1); + } + } + if (is_last || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1); + } + } + + int LSREST = LSB % BWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(BWD_RBC, 1, 1); + + add(src, BWD_RBC*vlen); + add(diffsrc, BWD_RBC*vlen); + add(diffdst, BWD_RBC*vlen); + add(workspace0, BWD_RBC*vlen); + add(workspace1, BWD_RBC*vlen); + + for(int irb = 0; irb < BWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1); + + add(t, BWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } + +}; + +status_t jit_avx512_common_lrn_bwd_t::pd_t::init() { + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? 
success : unimplemented; +} + +jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float beta = pd()->desc()->lrn_beta; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), + alpha, beta, use_h_parallelism); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), + alpha, beta, use_h_parallelism); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism); + } +} + +jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, h, H, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, h, H, c16, C16); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp new file mode 100644 index 000000000000..37fbb9b3e54e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the 
Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP +#define CPU_JIT_AVX512_COMMON_LRN_HPP + +#include "c_types_map.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_fwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_fwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_bwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_bwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp new file mode 100644 index 000000000000..c58d3fa0a68b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp @@ -0,0 +1,1103 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_2x3.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = + reinterpret_cast(const_cast(getCode())); + } + + void generate(); + + Zmm vreg_inp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(15 - i); + } + + Zmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + +}; + +void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() { + Label ic_block_label; + + const int load_block = 16; + int out_offset = 0, inp_offset = 0; + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, src); + READ_PARAM(reg_aux_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + } + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + for (int y = 0; y < jcp.alpha; y++) { + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm = vreg_inp(y * jcp.alpha + x); + + vxorps(zmm, zmm, zmm); + kandw(r_mask, y_mask, x_mask(x)); + inp_offset = sizeof(float) + * ((-jcp.t_pad + y) * jcp.iw * load_block + + (-jcp.l_pad + x) * load_block); + vmovups(zmm | r_mask, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + } + for (int y = 0; y < jcp.alpha; y++) { + vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0), + vreg_inp(y * jcp.alpha + 2)); + vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 2)); + vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2), + vreg_inp(y * jcp.alpha + 1)); + vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 3)); + } + for (int x = 0; x < jcp.alpha; x++) { + vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0), + vreg_tmp(x + jcp.alpha * 2)); + vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + 
jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 2)); + vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2), + vreg_tmp(x + jcp.alpha * 1)); + vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 3)); + } + + for (int i = 0; i < 16; i++) { + out_offset = sizeof(float) * (jcp.inp_stride * i); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), + vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block); + add(reg_aux_ptr_dst, sizeof(float) * load_block); + } + dec(reg_ic_block); + cmp(reg_ic_block, 0); + jg(ic_block_label, T_NEAR); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast( + const_cast(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); + return Zmm(31 - id_reg_stg); + } + + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id_reg_out < jcp.alpha * jcp.alpha + 12); + return Zmm(31 - id_reg_out); + } + + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_prev_dst = Zmm(0); + Zmm vreg_bias = Zmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1); + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() { + Label oc_block_label; + + const int load_block = 16; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? 
&p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for (int i = 0; i < 16; i++) { + int internal_offset = sizeof(float) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for (int y = 0; y < jcp.alpha; y++) { + vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1)); + vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2)); + + vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2)); + vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3)); + } + for (int x = 0; x < jcp.m; x++) { + vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1)); + vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2)); + + vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2)); + vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3)); + } + + + if (jcp.with_bias) { + auto bias_addr = ptr [ reg_ptr_bias ]; + vmovups(vreg_bias, bias_addr); + } + for (int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]); + for (int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = sizeof(float) * + (y * jcp.ow * jcp.oc_block + x * jcp.oc_block); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + + if (maybe_relu(0)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + if (p_sum_scale) { // post_op: sum + vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + vmovups(vreg_prev_dst | r_mask, addr); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + + vmovups(addr, zmm | r_mask); + } + } + }; + + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, wino_dst); + READ_PARAM(reg_aux_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha * jcp.alpha; i++) + vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i)); + + for (int i = 0; i < jcp.alpha; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + + int oc_blocks = 1; + oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); + { + loop_body(); + add(reg_aux_ptr_src, sizeof(float) * load_block); + add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, jcp.typesize_bia * load_block); + } + dec(reg_oc_block); + cmp(reg_oc_block, 0); + jg(oc_block_label, T_NEAR); + + sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const 
call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = reinterpret_cast( + const_cast(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, + memory_desc_t& expect_wei_md); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert (31 - jcp.n2_block * jcp.m_block - i > 1); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + + Reg64 reg_mb = rbx; + Reg64 reg_nnb = rdx; + Reg64 reg_K = rsi; + +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + preamble(); +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); +#undef READ_PARAM + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + + mov(reg_K, jcp.k_chunks); + L(K_loop_label); { + int wei_offset = 0; + for (int _i = 0; _i < jcp.k2_block; _i++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + if (jcp.small_mb) { + int wei_offset = sizeof(float) + * ((nb2 * jcp.nb_ic * jcp.ic_block + * jcp.oc_block) + + _i * jcp.oc_block); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } else { + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, + sizeof(float) * wei_offset)); + wei_offset += jcp.oc_block; + } + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = sizeof(float) * (m * jcp.K + _i); + if (jcp.n2_block > 1) { + vbroadcastss(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2), + vreg_src); + } else { + vfmadd231ps(vreg_out(0, m), vreg_wei(0), + EVEX_compress_addr(reg_aux_src2, inp_offset, true)); + } + } + } + add(reg_aux_src2, 
sizeof(float) * jcp.ic_block); + if (jcp.small_mb) + add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block); + else + add(reg_aux_wei2, + sizeof(float) * jcp.k2_block * jcp.n2_block + * jcp.oc_block); + } + dec(reg_K); + cmp(reg_K, 0); + jg(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = sizeof(float) * + (m * jcp.N + nb2 * jcp.oc_block); + vmovups(EVEX_compress_addr(reg_aux_dst2,offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K); + add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N); + } + dec(reg_mb); + cmp(reg_mb, 0); + jg(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block); + add(reg_aux_wei, + sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block + * jcp.oc_block); + + dec(reg_nnb); + cmp(reg_nnb, 0); + jg(nnb_loop_label, T_NEAR); + } + postamble(); +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + return jcp.mb >= 4; +} +} + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &wei_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, memory_desc_t &expect_wei_md) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + int simdw = 16; + + format_tag_t dat_tag = format_tag::nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simdw); + jcp.ic = rnd_up(jcp.ic, simdw); + } + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core))) + return status::unimplemented; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (src_d.data_type() != data_type::f32) + return status::unimplemented; + if (wei_d.data_type() != data_type::f32) + return status::unimplemented; + if (dst_d.data_type() != data_type::f32) + return status::unimplemented; + + jcp.ic_block = simdw; + jcp.oc_block = simdw; + + bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 + && jcp.oc % 
jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad + && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0 + && jcp.l_pad < 2 && jcp.l_pad >= 0; + if (!ok) + return status::unimplemented; + + const int L2_cap = get_cache_size(2, true) / sizeof(float); + const int L3_capacity = get_cache_size(3, false) / sizeof(float); + int a = jcp.alpha; + int aa = a * a; + int mb = jcp.mb; + int ic = jcp.ic; + int oc = jcp.oc; + int ih = jcp.ih; + int iw = jcp.iw; + auto wei_sz = (float)aa * ic * oc; + auto inp_sz = (float)mb * ih * iw * ic; + auto sp_sz = (float)mb * ih * iw; + + /* Heuristics here. Numbers '28','196' is an observation from data. */ + if (wei_sz / inp_sz > 5) + jcp.small_mb = true; + else + jcp.small_mb = false; + + if (mb > nstl::min(jcp.nthr, 28) + || (!jcp.small_mb + && (wei_sz >= 0.9f * L2_cap + || inp_sz > L2_cap * jcp.nthr + L3_capacity)) + || (jcp.small_mb && sp_sz > 196)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + const int skx_free_regs = 30; + + auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block, + int &n2_block, float ®_eff) { + M = (xb * yb) / jcp.alpha; + int max_m_block = m_block = nstl::min(M, skx_free_regs); + int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs); + reg_eff = 0; + for (int im = max_m_block; im > 0; im--) { + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = in2 * im + in2; + float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f; + if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs + || cur_reg_eff <= reg_eff) + continue; + reg_eff = cur_reg_eff; + m_block = im; + n2_block = in2; + } + } + }; + + int oh = jcp.oh; + int ow = jcp.ow; + int nb_oc = jcp.nb_oc; + int Z = ic + oc; + int Y = ic * oc; + const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float); + + /* Selecting xb and yb blocking */ + int min_yb = jcp.alpha; + int min_xb = jcp.alpha; + int max_yb = nstl::max(min_yb, rnd_up(ih, 2)); + int max_xb = nstl::max(min_xb, rnd_up(iw, 2)); + float best_eff = 0.f; + for (int ix = max_xb; ix >= min_xb; ix -= 2) { + if (rnd_up(ow, ix) < iw - 2) + continue; + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + if (rnd_up(oh, iy) < ih - 2) + continue; + int ex_y = rnd_up(oh, iy); + int ex_x = rnd_up(ow, ix); + float work_eff = (float)(ih * iw) / (ex_y * ex_x); + + int M, m_block, n2_b; + float reg_eff, thr_eff, par_eff, mem_eff, req_mem; + + find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff); + + /* outer parallelization */ + int nblocks = mb * div_up(ih, iy) * div_up(iw, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + + mem_eff = 1.f; + req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y; + if (req_mem > L2_cap / 2) { + if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7) + mem_eff /= (n2_b + 1) / 2.f; + else + mem_eff /= (n2_b + 1) / 3.f; + } + + float outer_eff = thr_eff + work_eff + reg_eff + mem_eff; + + /* inner parallelization */ + int bsz = iy * ix / a; + int gemmw = aa * (nb_oc / n2_b); + int bsz_r = rnd_up(bsz, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y); + + req_mem = (float)ix * iy * (ic + simdw * n2_b) + 
simdw * n2_b * ic; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr)); + int oc_per_thr = + nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr)); + req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z; + if (req_mem > L2_cap) + mem_eff = 0.1f; + par_eff = 1 / (2.f * nblocks); + + float inner_eff = thr_eff + work_eff + mem_eff + par_eff; + + float eff = jcp.small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.M = M; + jcp.m_block = m_block; + jcp.n2_block = n2_b; + } + } + } + + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + assert(jcp.M % jcp.m_block == 0); + assert(jcp.nb_oc % jcp.n2_block == 0); + + jcp.n_chunks = jcp.nb_oc / jcp.n2_block; + jcp.k2_block = jcp.ic_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format + = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = 1.f; + size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::pd_t::jit_conf(memory_desc_t& expect_wei_md) { + return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t:: + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + const size_t wino_size_offset = + (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb); + const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16; + const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + 
utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get(key_wino_V); + auto ptr_M = scratchpad.get(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = ptr_V + size_wino_src * ithr; + auto wino_dst = ptr_M + size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t + ::call_params_t(); + auto dst_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t + ::call_params_t(); + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw + * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 
0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow + * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get(key_wino_V); + auto ptr_M = scratchpad.get(key_wino_M); + + for (int mb = 0; mb < jcp.mb; mb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t :: + call_params_t(); + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = ptr_V + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + gemm_p.src = ptr_V + jcp.inp_stride * tile_ij; + gemm_p.dst = ptr_M + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t :: + call_params_t(); + + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 
0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = ptr_M + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp new file mode 100644 index 000000000000..7e38b07f5af5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp @@ -0,0 +1,144 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t; +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t; +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t; + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_2x3_fwd_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + memory_desc_t expect_wei_md = *weights_md(); + status_t jit_conf_result = jit_conf(expect_wei_md); + if (jit_conf_result != status::success) return jit_conf_result; + set_default_alg_kind(alg_kind::convolution_winograd); + + if (weights_md_.format_kind == format_kind::any) + weights_md_ = expect_wei_md; + if (weights_md_ != expect_wei_md) + return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + jit_conv_conf_2x3_wino_t jcp_; + + 
protected: + status_t jit_conf(memory_desc_t& expect_wei_md); + + void init_scratchpad() { + using namespace memory_tracking::names; + + auto scratchpad = scratchpad_registry().registrar(); + + int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb; + + size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K); + + size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K); + + if (wants_padded_bias()) { + assert(jcp_.ngroups == 1); + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc); + } + } + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nChw16c, any, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd); + ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + if (pd()->jcp_.small_mb) + execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx)); + + return status::success; + } + +private: + void execute_forward_small_mb(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_; + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_; + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp new file mode 100644 index 000000000000..96325e3ade24 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp @@ -0,0 +1,1020 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t +::weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const +{ + float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f, + 1.13777777777778f, 0.430252100840336f, 0.179271708683473f}; + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + float T[alpha][3][simd_w]; + auto p = jit_wino_transform_call_s(); + + p.src = wp; + p.dst = twp; + p.G = G; + p.M = F; + p.Mw = Fw; + p.T = T; + + kernel_->weights_transform_data_ker(&p); +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::output_transform_data +(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = pout_b; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->output_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t +::output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = outp; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? 
jcp.oh : jcp.ih; + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block) + * outh * outw * jcp.dimM_simd_block; + + kernel_->output_transform_data_ker(&p); + + tile_index++; + } + } +} + + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t + ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + + p.src = inp; + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t + ::input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? 
jcp.iw : jcp.ow; + + array_offset_calculator input(inp, + jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); + array_offset_calculator output(tinp, + alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + auto p = jit_wino_transform_call_s(); + + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + float *pinp_b = &(input(img, 0, 0, 0, 0)); + + p.src = pinp_b; + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_index++; + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get(key_wino_U); + + array_offset_calculator U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? 
scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, + 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + }); + + if (jcp.prop_kind != prop_kind::forward_inference) { + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), + (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, + [&](int N_blk1, int oj, int oi, int M_blk1) { + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; + K_blk1++) + for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++) + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, 0, 0, 0)), K_blk1); + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block), + [&](int img, int M_blk1, int M_blk2) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + output_transform_data(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), bias_ptr); + }); + + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_SGD( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get(key_wino_U); + + array_offset_calculator U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? 
scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + 0, jcp.dimM_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + 0, alpha, alpha, jcp.dimN_block, + jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + if (jcp.prop_kind != prop_kind::forward_inference) { + + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.tile_block, [&](int tile_block) { + int ithr = mkldnn_get_thread_num(); + + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) { + + input_transform_tileblock_data( + tile_block, jcp, + &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + } + } + + for (int oj = 0; oj < alpha; oj++) { + for (int oi = 0; oi < alpha; oi++) { + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) + for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) + kernel_->gemm_loop_ker( + (float *)&(M(ithr, M_blk1, oj, oi, + N_blk, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, K_blk1, + 0, 0, 0, 0)), + (const float *)&(V(ithr, oj, oi, + N_blk, K_blk1, 0, 0, 0)), K_blk1); + } + } + + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) { + for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; + M_blk2++) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? 
last_slice_bias : &bias(M_blk, 0); + + output_transform_tileblock_data(tile_block, jcp, p_ops, + &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(0, M_blk, 0, 0, 0)), bias_ptr); + } + } + }); +} + +template struct _jit_avx512_core_fp32_wino_conv_4x3_t; +template struct _jit_avx512_core_fp32_wino_conv_4x3_t; + +namespace { + +void subarray_sum(size_t num_arrs, float *output, size_t nelems, + float *input_ptrs[], size_t input_starts[], size_t input_ends[]) { + using namespace nstl; + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const int ithr = mkldnn_get_thread_num(); + const int nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} + +const int max_threads_number = 1024; + +// Sum to the first buffer array +void array_sum(size_t num_arrs, float *output, + size_t nelems, float *input_ptrs[], bool reduce_to_first = true) { + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const size_t ithr = mkldnn_get_thread_num(); + const size_t nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; 
e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} +} //bwdw namespace + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_SDGtWo(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights(ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator Us(scratchpad.get(key_wino_U), + 0, alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc + * jcp.ic / jcp.nb_ic; + array_offset_calculatordiff_weights_prv( + scratchpad.get(key_wino_U) + U_sz, + 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator M(scratchpad.get(key_wino_M), + 0, alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator V(scratchpad.get(key_wino_V), + 0, alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); + + auto trans_ker_p = jit_wino_transform_call_s(); + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f, + 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f, + 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + +PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc / simd_w, + [&](int ithr, int ofm){ + float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + pdbias[v] = 0.0f; + } + }); + } + + int ithr = mkldnn_get_thread_num(); + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) { + int first_tblk = 0; +PRAGMA_OMP(for) + for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) { + int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur; + int img = tile_index / (jcp.itiles * jcp.jtiles); + trans_ker_p.ti = tile_index % jcp.itiles; + trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles; + trans_ker_p.M = I; + trans_ker_p.T = T; + trans_ker_p.G = G_I_3x3_4x4; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + } + + for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) { + trans_ker_p.G = G_W_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + 
trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias && ifm1 == 0) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + } + + for (int oj = 0; oj < alpha; ++oj) { + for (int oi = 0; oi < alpha; ++oi) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(M(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(V(ithr, oj, oi, 0, 0, 0, 0))); + } + } + trans_ker_p.G = G_O_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block + + ofm3; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(Us(ithr, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights_prv(ithr, + ofm, ifm, 0, 0, 0, 0)); + if (first_tblk == 0) { + kernel_->diff_weights_transform(&trans_ker_p); + } else { + kernel_->diff_weights_transform_accum(&trans_ker_p); + } + } + } + } + } + ++first_tblk; + } + } +} + + // Reduce diff-weights + { + float *output = ptr_diff_weights; + float *input_base = scratchpad.get(key_wino_U) + U_sz; + int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + nelems * i; + } + array_sum(nthreads, output, nelems, input_ptrs, false); + + if (jcp.with_bias) { + output = ptr_diff_bias; + input_base = scratchpad.get(key_conv_bia_reduction); + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + jcp.oc * i; + } + array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs, + false); + } + } +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_S_D_Giot_W(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights((float *)ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator diff_bias((float *)ptr_diff_bias, jcp.oc); + + array_offset_calculator U(scratchpad.get(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_size = jcp.oc * jcp.ic * alpha * alpha; + array_offset_calculator Us( + scratchpad.get(key_wino_U) + U_size, + 0, jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator M(scratchpad.get(key_wino_M), + jcp.nb_oc, + jcp.tile_block, + alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur , + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator V(scratchpad.get(key_wino_V), + jcp.nb_ic, + jcp.tile_block, + alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); + + size_t input_starts[max_threads_number] = {0}; + size_t 
input_ends[max_threads_number] = {0}; + size_t first_tblk = 0; + + auto trans_ker_p = jit_wino_transform_call_s(); + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, + 0.119514472455649f, 0.430252100840336f, 0.168067226890756f, + 0.179271708683473f, 0.403361344537815f, 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + +PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + } + + trans_ker_p.G = G_I_3x3_4x4; + trans_ker_p.M = I; + trans_ker_p.T = T; + + parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb, + [&](int ifm1, int ifm2, int img){ + size_t ifm = ifm1 * jcp.ic_block + ifm2; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + }); + + int ithr = mkldnn_get_thread_num(); + trans_ker_p.G = G_W_3x3_4x4; + parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb, + [&](int ofm1, int ofm2, int img){ + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + }); + + PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block, + [&](int ifm1, int ofm1, int oj, int oi, int tblk1){ + if (first_tblk == 0) { + input_starts[ithr] = + (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, + 0, 0)) + - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0, + 0, 0, 0)); + input_ends[ithr] = input_starts[ithr] + + jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + else if (tblk1 == 0) { + input_ends[ithr] += jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + + if (first_tblk == 0 || tblk1 == 0) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } else { + kernel_->gemm_loop_ker( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } + ++first_tblk; + }); +} + + // Reduce diff-weights + { + float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0)); + size_t nelems = jcp.ic * jcp.oc * alpha * alpha; + float 
*input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) + input_ptrs[i] = output + nelems * (i + 1); + subarray_sum(nthreads, output, nelems, input_ptrs, + input_starts, input_ends); + } + + trans_ker_p.G = G_O_3x3_4x4; +PRAGMA_OMP(parallel firstprivate(trans_ker_p)) + { + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){ + int ofm = (ofm1 * jcp.oc_block + ofm2) + * jcp.oc_reg_block + ofm3; + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm, + 0, 0, 0, 0)); + kernel_->diff_weights_transform(&trans_ker_p); + }); + } + + if (jcp.with_bias) { + parallel_nd(jcp.oc / simd_w, [&](int ofm1) { + float* pbias = &(diff_bias(ofm1 * simd_w)); + float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w)); + + const int blk_sz = ofm1 == jcp.oc / simd_w - 1 + ? jcp.oc_without_padding - ofm1 * simd_w : simd_w; + + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] = pbias_prv[ofm2]; + } + + for (int ithr = 1; ithr < nthreads; ++ithr) { + pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] += pbias_prv[ofm2]; + } + } + }); + } +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp new file mode 100644 index 000000000000..f1a56aac70ce --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp @@ -0,0 +1,386 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_core { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace utils; + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles + * jcp.jtiles; + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles + * jcp.jtiles; + + switch (jcp.sched_policy) { + case WSCHED_DATA_W_SGD: + V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.ic; + M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.oc; + break; + case WSCHED_WEI_SDGtWo: + U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc + * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw); + M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.oc / jcp.nb_oc); + V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.ic / jcp.nb_ic); + break; + case WSCHED_WEI_S_D_Giot_W: + U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc; + M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles; + V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles; + break; + default: break; + } + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) { + size_t br_sz = (size_t)jcp.nthr * jcp.oc; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + } +} +} + +template +struct _jit_avx512_core_fp32_wino_conv_4x3_t { + + _jit_avx512_core_fp32_wino_conv_4x3_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp); + } + + ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; } + + protected: + void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const; + void input_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void output_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, + float *bias) const; + void output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const; + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_; + const primitive_attr_t *attr_; +}; + +struct 
jit_avx512_core_fp32_wino_conv_4x3_fwd_t + : _jit_avx512_core_fp32_wino_conv_4x3_t + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = desc()->prop_kind == prop_kind::forward_training + ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + default: + break; + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t + : _jit_avx512_core_fp32_wino_conv_4x3_t, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_data + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + ::init_conf(jcp_, *desc(), 
*diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + default: + break; + } + + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) + return status::unimplemented; + + status_t status = + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? 
gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , kernel_(nullptr) + { + kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + pd()->jcp_); + } + + ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t() + { + delete kernel_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + switch (kernel_->jcp.sched_policy) { + case WSCHED_WEI_SDGtWo: + _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + case WSCHED_WEI_S_D_Giot_W: + _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + default: + assert(kernel_->jcp.sched_policy != WSCHED_INVALID); + break; + } + return status::success; + } + +private: + void _execute_backward_weights_SDGtWo(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_backward_weights_S_D_Giot_W(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp new file mode 100644 index 000000000000..0d64a2d13a46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp @@ -0,0 +1,2596 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. 
+// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + /* Determines if current winograd implementation is faster than direct. + Following conditions are empirical and based on performance data */ + unsigned int ncores_per_socket = + cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + unsigned int nthreads = mkldnn_get_max_threads(); + + if (jcp.prop_kind == prop_kind::forward_inference) { + return jcp.mb >= 4; + } else if (nthreads > ncores_per_socket) { + double src_dst_transforms_per_core = alpha * alpha + * (jcp.ic + jcp.oc) + * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size) + * ((jcp.ow + tile_size - 1) / tile_size) + * sizeof(float) / 1024. / 1024. / nthreads; + double wei_transform = alpha * alpha + * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.; + + if (jcp.prop_kind == prop_kind::backward_weights) { + if (src_dst_transforms_per_core < 0.3 + || (src_dst_transforms_per_core <= 28 && wei_transform < 4)) + return false; + else + return true; + } else { + if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02) + return false; + } + } + + return jcp.mb > 8; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: cache_latency = 250; break; + case L3: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + 
int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, + int dimN_block, float C2_min, float C2_max) { + float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic) + * dimN_block * jcp.dimN_reg_block + + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float); + float L2_lb = C2_min * L2_cache_size; + float L2_ub = C2_max * L2_cache_size; + return (block_size > L2_lb && block_size < L2_ub); +} + +bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block, + int dimM_block, float C1_min, float C1_max) { + float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block + * jcp.dimK_reg_block * jcp.dimM_reg_block + + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block + + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block) + * (float)sizeof(float); + float L1_lb = C1_min * L1_cache_size; + float L1_ub = C1_max * L1_cache_size; + return (gemm_block_size > L1_lb && gemm_block_size < L1_ub); +} +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block + * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block, + int dimM_simd_block, float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} + +bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block, + int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) +{ + float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK + * (float)sizeof(float); + float B_size = dimN_block * dimN_reg_block * dimK + * (float)sizeof(float); + return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate() +{ + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block; + // dimM_reg_block++) // unrolled + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) // unrolled + // for (int tile =0; tile < jcp.dimN_reg_block; 
tile++) + // C[dimM_block][dimM_reg_block][tile] += + // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block] + // * broadcast(B[dimK_block][tile][dimK_reg_block]); + // Notes: + // jcp.kernel_kind defines embedded or explicit broadcast + // dimM_reg_block=1 for embedded bcast kernel + + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + auto zmm_srcB = [=](int tile) { + int idx = 1 + tile; + assert(idx < 1 + jcp.dimN_reg_block); + return Xbyak::Zmm(idx); + }; + auto zmm_dstC = [=](int dimM_reg_block, int tile) { + int idx{0}; + if (jcp.kernel_kind == embd_bcast) + idx = 1 + tile; + else + idx = 1 + jcp.dimN_reg_block + + dimM_reg_block * jcp.dimN_reg_block + tile; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + + auto prepare_output = [=]() { + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block, tile); + vpxord(zmm, zmm, zmm); + } + } + }; + auto store_output = [=](bool output_is_aligned) { + Label save; + cmp(reg_is_beta_zero, 0); + je(save, T_NEAR); + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset)); + } + } + + L(save); + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + + // In W_SGD, output will be reused. + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && jcp.sched_policy == WSCHED_DATA_W_S_G_D + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm); + else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + + prepare_output(); + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block ++) { + + if (jcp.kernel_kind == expl_bcast) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + vbroadcastss(zmm_srcB(tile), + ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]); + } + } + + /* Performing the fmas */ + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + + vmovups(zmm_srcA(), + zword[reg_srcA + + jcp.dimK_reg_block * jcp.dimK_block * 64 + * dimM_reg_block + + dimK_reg_block * 64] + ); + + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + zmm_srcB(tile)); + else + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, true)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); 
+ store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64); + if (jcp.kernel_kind == expl_bcast) { + add(reg_srcA, + (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64 + * jcp.dimK_block); + } + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::weights_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int kh = jcp.kh; + int kw = jcp.kw; + + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_zero = Xbyak::Zmm(30); + + auto zmm_M = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_MT = [=](int i) { + return Xbyak::Zmm(i + simd_w); + }; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_F = [=](int i) { + return Xbyak::Zmm(alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(alpha + 3 + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(2 * alpha + 3 + i); + }; + + auto zmm_load = [=](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(wreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < alpha; i++) { + vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]); + } + vpxord(zmm_zero, zmm_zero, zmm_zero); + }; + + auto trans16x16 = [=]() { + for (int i = 0; i < simd_w; i+=2 ) { + vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]); + vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]); + vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1)); + vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1)); + } + for (int i = 0; i < simd_w; i+=4 ) { + vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2)); + vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2)); + vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3)); + vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3)); + } + for (int i = 0; i < simd_w; i += 8) { + vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88); + vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88); + vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88); + vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88); + vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd); + vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd); + vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd); + vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd); + } + { + int i = 0; + int mask = 0x88; + vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0)); + vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1)); + vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2)); + vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3)); + vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4)); + vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5)); + vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6)); + vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7)); + mask = 0xdd; + vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask); + 
vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8)); + vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9)); + vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10)); + vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11)); + vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12)); + vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13)); + vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14)); + vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15)); + } + }; + + auto load_src = [=]() { + mov(wreg_src, ptr[param1 + GET_OFF(src)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + if (is_fwd) { + for (int v1 = 0; v1 < simd_w; v1++) { + int offset_src = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + vmovups(zmm_temp, ptr[wreg_src + offset_src]); + vmovups(ptr[wreg_F + offset_F], zmm_temp); + } + } else { + int offset_src = ((2 - j) * kw * simd_w * simd_w + + (2 - i) * simd_w * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w) * typesize; + lea(wreg_M, ptr[wreg_src + offset_src]); + lea(wreg_MT, ptr[wreg_F + offset_F]); + trans16x16(); + } + } + } + }; + + auto store_dst = [=]() { + mov(wreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + mov(wreg_dst_aux, wreg_dst); + mov(wreg_Fw_aux, wreg_Fw); + + int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimK_block * simd_w * simd_w; + + L(Loop_j); + { + for (int i = 0; i < alpha; i++) { + // touch pages + vmovups(zmm_load(0), ptr[wreg_Fw_aux + + (i * simd_w * simd_w) * typesize]); + mov(wreg_dst_idx, i * dim5 * typesize); + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0)); + } + for (int i = 0; i < alpha; i++) { + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_Fw = (i * simd_w * simd_w + v1 * simd_w) + * typesize; + vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]); + } + mov(wreg_dst_idx, i * dim5 * typesize); + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_dst = v1 * simd_w * typesize; + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst], + zmm_load(v1)); + } + } + add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize); + add(wreg_dst_aux, alpha * dim5 * typesize); + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, alpha); + jl(Loop_j, T_NEAR); + } + }; + + auto trans_W_4x4_3x3 = [=]() { + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmovups(dst, a); + vfmadd231ps(dst, b, c); + }; + auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, b, c); + vsubps(dst, a, zmm_temp); + }; + auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vsubps(dst, zmm_zero, a); + vfnmadd231ps(dst, b, c); + }; + + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + mov(wreg_T, ptr[param1 + GET_OFF(T)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + L(Loop_j); + mov(wreg_F_aux, wreg_F); + mov(wreg_Fw_aux, wreg_Fw); + mov(wreg_temp, wreg_cnt_j); + shl(wreg_temp, 4 + 2); + lea(wreg_F_aux, ptr[wreg_F + wreg_temp]); + lea(wreg_Fw_aux, 
ptr[wreg_Fw + wreg_temp]); + + for (int i = 0; i < 3; i++) { + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w + * simd_w + i * simd_w * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_F(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0)); + + vmulps(zmm_T(0), zmm_G(3), zmm_F(0)); + fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1)); + fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1)); + vmovaps(zmm_T(5), zmm_F(2)); + + for (int idx = 0; idx < 6; idx ++) { + vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w) + * typesize], zmm_T(idx)); + } + } + for (int i = 0; i < 6; i++) { + + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_T(idx), ptr[wreg_T + + (i * 3 * simd_w + idx * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_T(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0)); + + vmulps(zmm_F(0), zmm_G(3), zmm_T(0)); + fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1)); + fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1)); + vmovaps(zmm_F(5), zmm_T(2)); + + for (int l = 0; l < 6; l++) { + vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w + + l * simd_w * simd_w) * typesize], zmm_F(l)); + } + } + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, 16); + jl(Loop_j, T_NEAR); + }; + + auto inner_loops = [=]() { + load_src(); + init_G(); + trans_W_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::output_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? 
jcp.oh : jcp.ih; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + bool with_bias = jcp.with_bias; + bool with_relu = jcp.with_eltwise; + bool with_relu_postsum = jcp.with_relu_postsum; + bool with_sum = jcp.with_sum; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_O = [=](int i) { + return Xbyak::Zmm(1 + alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + 2 * alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + 3 * alpha + i); + }; + + auto init_G = [=]() { + mov(oreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < 6; i++) { + vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_src, ptr[param1 + GET_OFF(src)]); + + mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur, + (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_nb_tile_block_ur); + + mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(oreg_tile_block_ur, oreg_tile_block_ur, + jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block_ur); + + if (not_tiled) { + mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(oreg_tile_block, oreg_tile_block, + jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block + * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block); + } + + int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int j_base_offset = j * alpha * last4dim; + int i_base_offset = i * last4dim; + vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]); + vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_dst = [=]() { + vpxord(zmm_zero, zmm_zero, zmm_zero); + mov(oreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(oreg_ydim, 2); // tj * tile_size (==4) + mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(oreg_xdim, 2); // ti * tilesize (==4) + + if (with_bias) + mov(oreg_bias, ptr[param1 + GET_OFF(bias)]); + + auto store_one = [=](int j, int i, bool is_aligned) { + auto zmm_O = Xbyak::Zmm(31); + auto zmm_relu_ns = Xbyak::Zmm(30); + auto xmm_relu_ns = Xbyak::Xmm(30); + int offset = (j * tile_size * simd_w + i * simd_w) * typesize; + + vmovups(zmm_O, ptr[oreg_O + offset]); + if (is_fwd) { + if (with_bias) { + vaddps(zmm_O, zmm_O, ptr[oreg_bias]); + } + if (with_relu) { + if (jcp.eltwise.alpha == 0) { + vmaxps(zmm_O, zmm_O, zmm_zero); + } else { + Opmask kmask = Opmask(7); + mov(imm_addr64, float2int(jcp.eltwise.alpha)); + vmovq(xmm_relu_ns, imm_addr64); + vbroadcastss(zmm_relu_ns, xmm_relu_ns); + vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os); + vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns); + } + } + } + if (with_sum) { + vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]); + if (with_relu_postsum) // orig: with_relu_postsum + vmaxps(zmm_O, zmm_O, zmm_zero); + } + if (is_aligned) + vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O); + else + vmovups(ptr[oreg_out_j + oreg_temp], zmm_O); + }; + + auto i_loop = [=](int j, bool is_aligned) { + for (int i = 0; i < tile_size; i++) { + Label 
next; + mov(oreg_temp, oreg_xdim); + add(oreg_temp, i); + cmp(oreg_temp, outw); + jge(next, T_NEAR); + shl(oreg_temp, 4 + 2); // * 16 * 4 + + store_one(j, i, is_aligned); + + L(next); + } + }; + + + for (int j = 0; j < tile_size; j++) { + Label next, unaligned; + mov(oreg_temp, oreg_ydim); + add(oreg_temp, j); + cmp(oreg_temp, outh); + jge(next, T_NEAR); + + mov(oreg_out_j, oreg_dst); + imul(oreg_temp, oreg_temp, outw * simd_w * typesize); + add(oreg_out_j, oreg_temp); + + test(oreg_dst, 63); + jnz(unaligned, T_NEAR); + + i_loop(j, true); + jmp(next, T_NEAR); + + L(unaligned); + i_loop(j, false); + + L(next); + } + }; + + auto trans_O_4x4_3x3 = [=]() { + auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){ + vmulps(dst, v1, u1); + vfmadd231ps(dst, v2, u2); + }; + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_T, ptr[param1 + GET_OFF(T)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + + vaddps(zmm_t(0), zmm_O(1), zmm_O(2)); + vaddps(zmm_t(1), zmm_O(3), zmm_O(4)); + vsubps(zmm_t(2), zmm_O(1), zmm_O(2)); + vsubps(zmm_t(3), zmm_O(3), zmm_O(4)); + + vaddps(zmm_T(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_T(0), zmm_T(0), zmm_O(0)); + fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_T(3), zmm_T(3), zmm_O(5)); + + for (int j = 0; j < tile_size; j++) { + vmovups(ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize], zmm_T(j)); + } + } + for (int j = 0; j < tile_size; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + vaddps(zmm_t(0), zmm_T(1), zmm_T(2)); + vaddps(zmm_t(1), zmm_T(3), zmm_T(4)); + vsubps(zmm_t(2), zmm_T(1), zmm_T(2)); + vsubps(zmm_t(3), zmm_T(3), zmm_T(4)); + + vaddps(zmm_O(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_O(0), zmm_O(0), zmm_T(0)); + fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_O(3), zmm_O(3), zmm_T(5)); + + for (int i = 0; i < tile_size; i++) { + vmovups(ptr[oreg_O + (j * tile_size * simd_w + + i * simd_w) * typesize], zmm_O(i)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_O_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::input_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int inpw = is_fwd ? jcp.iw : jcp.ow; + int inph = is_fwd ? jcp.ih : jcp.oh; + int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + int t_pad = is_fwd ? 
jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + int wp_max = inpw + l_pad; + int hp_max = inph + t_pad; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + int G_size = 9; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(1 + G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(ireg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + xor_(ireg_zero, ireg_zero); + vpxord(zmm_zero, zmm_zero, zmm_zero); + + mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(ireg_ydim, 2); // tj * tile_size (==4) + mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(ireg_xdim, 2); // ti * tilesize (==4) + + for (int j = 0; j < alpha; j++) { + mov(ireg_temp, ireg_ydim); + add(ireg_temp, j); + + mov(ireg_mask_j, 0xffff); + cmp(ireg_temp, t_pad); + cmovl(ireg_mask_j, ireg_zero); + cmp(ireg_temp, hp_max); + cmovge(ireg_mask_j, ireg_zero); + + sub(ireg_temp, t_pad); + imul(ireg_temp, ireg_temp, inpw * simd_w * typesize); + mov(ireg_inp_j, ireg_src); + add(ireg_inp_j, ireg_temp); + + for (int i = 0; i < alpha; i++) { + + mov(ireg_temp, ireg_xdim); + add(ireg_temp, i); + + mov(ireg_mask, 0xffff); + cmp(ireg_temp, l_pad); + cmovl(ireg_mask, ireg_zero); + cmp(ireg_temp, wp_max); + cmovge(ireg_mask, ireg_zero); + and_(ireg_mask, ireg_mask_j); + + sub(ireg_temp, l_pad); + shl(ireg_temp, 4 + 2); + + vpxord(zmm_temp, zmm_temp, zmm_temp); + Opmask kmask = Opmask(7); + kmovw(kmask, ireg_mask_32); + vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]); + vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_Iw = [=]() { + + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); + + bool streamout + = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_data_size + ? 
true : false; + + if (not_tiled) { + mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(ireg_tile_block, ireg_tile_block, + alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize); + } + + mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur, + jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block + * jcp.dimK_reg_block * typesize); + + mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(ireg_tile_block_ur, ireg_tile_block_ur, + jcp.dimK_reg_block * typesize); + + add(ireg_output, ireg_nb_tile_block_ur); + add(ireg_output, ireg_tile_block_ur); + if (not_tiled) + add(ireg_output, ireg_tile_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w + + i * simd_w) * typesize]); + + int j_base_offset = + j * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int i_base_offset = + i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + + if (not_tiled && streamout) + vmovntps(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + else + vmovups(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + } + } + }; + + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, a, b); + vaddps(dst, zmm_temp, c); + }; + + auto trans_I_4x4_3x3 = [=]() { + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_T, ptr[param1 + GET_OFF(T)]); + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w + + i * simd_w) * typesize]); + int j_base_offset = + i * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int idx_base_offset = + idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w) + * typesize],zmm_T(idx)); + } + } + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx + * simd_w) * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + 
fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w) + * typesize],zmm_I(idx)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_I_4x4_3x3(); + store_Iw(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + if (!mayiuse(avx512_core)) { + return status::unimplemented; + } + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ver = ver_avx512_core; + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Checking conditions not supported by these kernels + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) { + format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && (one_of(weights_d.format_kind(), + format_kind::any, format_kind::wino) + || (jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0])); + if (!layout_consistency) + return status::unimplemented; + + return status::success; +} + +void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) { + + /* ----------- dimM reg block ---------------------*/ + auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_reg_block, int current_best) { + int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4; + return (dimM_reg_block >= 1) + && (dimM_reg_block <= max_dimM_reg_block ) + && (dimM_reg_block > current_best); + }; + jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block); + + /* ----------- dimN reg block ---------------------*/ + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return jcp.kernel_kind == embd_bcast + ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best + : dimN_reg_block >= 1 + && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block) + < jcp.nb_reg + && dimN_reg_block > current_best; + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimN, 1, test_cond_dimN_reg_block); +} + +status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) { + if (jcp.ver != ver_avx512_core) + return status::unimplemented; + + jcp.kernel_kind = embd_bcast; + + set_kernel_dims_reg_block(jcp); + + /*-------------- L2 blocking for dimN block ---------*/ + + auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_block, int current_best) { + return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0) + && (dimN_block > current_best) + && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) + >= 1.5 * mkldnn_get_max_threads()); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block); + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + + if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2) + && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) { + + /* ------------------- L1 blocking for GEMM --------------*/ + /* -------------------- Choose dimK block ----------------*/ + + auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp, + int dimK_block, int current_best) { + return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block); + + if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) { + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + + /* -------------- Choose dimM block -------------------*/ + auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_block, int current_best) { + return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, + 0.2, 0.5) && (dimM_block > current_best); + }; + + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, + test_cond_dimM_block); + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / 
jcp.dimM_simd_block; + + jcp.sched_policy = WSCHED_DATA_W_SGD; + return status::success; + } + + } + return status::unimplemented; +} + +void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) { + + set_kernel_dims_reg_block(jcp); + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_reg_block, + jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_bis_dimM_block); + jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block + * jcp.dimM_reg_block); + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_reg_block, jcp.dimM_simd_block, .9f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); +} + +status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) { + + jcp.kernel_kind = expl_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK, + .1f, .35f))) { + jcp.kernel_kind = embd_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + } + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + 
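+ // GEMM view of the convolution: dimM/dimN/dimK are oc/ntiles/ic for forward and ic/ntiles/oc for backward-data; the schedulers below choose register and cache blockings for these dimensions.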
jcp.nb_reg = 32; + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + jcp.sched_policy = WSCHED_INVALID; + + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + if (jcp.kernel_kind == embd_bcast) { + jcp.dimM_reg_block = 1; + } + + if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success)) + set_wsched_DATA_W_S_G_D_avx512_core(jcp); + + assert(jcp.sched_policy != WSCHED_INVALID); + return status::success; +} + +bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) + || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_t &src_md, memory_desc_t &weights_md, + const memory_desc_t &dst_md, const primitive_attr_t &attr) { + + status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.ic_reg_block = 1; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + if (cd.prop_kind == mkldnn_forward_inference) { + memory_desc_t expect_wei_md = weights_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_OBaaIBOIio; + wd.r = 3; + wd.alpha = 6; + + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.dimK_reg_block; + wd.oc_block = jcp.dimM_simd_block; + wd.ic2_block = jcp.dimK_block; + wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block; + size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + wd.adj_scale = 1.f; + + if (weights_md.format_kind == format_kind::any) + weights_md = expect_wei_md; + if (weights_md != expect_wei_md) + return status::unimplemented; + } + + return res; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf( + 
jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.ic_reg_block = jcp.dimM_reg_block; + jcp.oc_reg_block = 1; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + return res; +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +src_transform_generate() { + constexpr int G_size = 9; + const size_t ifwp = jcp.iw + jcp.l_pad; + const size_t ifhp = jcp.ih + jcp.t_pad; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + xor_(reg_zero, reg_zero); + + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + + for (int j = 0; j < alpha; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.t_pad); + cmovl(reg_maskj, reg_zero); + cmp(reg_ydim, ifhp); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad + imul(reg_src_offset, reg_src_offset, jcp.iw); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + sub(reg_src_offset, jcp.l_pad); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < alpha; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.l_pad); + cmovl(reg_maski, reg_zero); + cmp(reg_xdim, ifwp); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + auto zmm_src = Xbyak::Zmm(31); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src, zmm_src, zmm_src); + vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]); + vmovups(ptr[reg_I], zmm_src); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + add(reg_I, simd_w * typesize); + } + add(reg_ydim, 1); + } + }; + + auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) { + vmovups(dst, c); + vfmadd231ps(dst, a, b); + }; + + auto trans_I_3x3_4x4 = [=]() { + //Use 24 registers + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + mov(reg_T, ptr[reg_transp + GET_OFF(T)]); + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + size_t I_off = (j * alpha + i) * simd_w * typesize; + vmovups(zmm_I(j), ptr[reg_I + I_off]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), 
zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int j = 0; j < alpha; j++) { + vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize], + zmm_T(j)); + } + + } + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int i = 0; i < alpha; i++) { + size_t dst_off = (j * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur) + * simd_w * typesize; + vmovups(ptr[reg_dst + dst_off], zmm_I(i)); + } + } + }; + + auto compute_transform_SDGtWo = [=]() { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + xor_(reg_tile_count, reg_tile_count); + Label loop_mb, loop_jtiles, loop_itiles, done; + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(done); + + add(reg_dst, simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize); + jmp(loop_mb); + } + L(done); + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, simd_w * typesize); + add(reg_dst, reg_temp); + + Label loop_jtiles, loop_itiles, next_tile_block, next_tile; + L(loop_jtiles); + + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * simd_w * typesize); + size_t tblk_off = alpha * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + * simd_w * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + 
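+ // Row of tiles finished: reset the tile column index and advance to the next tile row.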
add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + }; + + preamble(); + init_G(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + compute_transform_SDGtWo(); + else + compute_transform(); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_dst_transform_generate(bool with_bias) { + + constexpr int G_size = 8; + auto zmm_G = [](int i) { + return Xbyak::Zmm(31); + }; + + auto zmm_src = [=](int j, int i) { + return Xbyak::Zmm(G_size + j * 4 + i); + }; + + auto zmm_bias = Xbyak::Zmm(31); + + auto load_src = [=]() { + if (with_bias) vmovups(zmm_bias, ptr[reg_bias]); + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + for (int j = 0; j < tile_size; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.oh); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + imul(reg_src_offset, reg_src_offset, jcp.ow); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < tile_size; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.ow); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i)); + vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]); + if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias, + ptr[reg_src + reg_src_offset]); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + } + add(reg_ydim, 1); + } + if(with_bias) vmovups(ptr[reg_bias], zmm_bias); + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 16 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(j * 4 + i); + }; + + auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) { + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + vmovups(ptr[reg_dst + dst_off], a); + else + vmovntps(ptr[reg_dst + dst_off], a); + }; + + auto trans_W_3x3_4x4 = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < tile_size; i++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7)); + vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4)); + vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4)); + vmovups(zmm_T(5, i), zmm_src(3, i)); + } + + for (int j = 0; j < alpha; j++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + 
typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7)); + vsubps(zmm_t(5), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(1), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(6), zmm_t(2), zmm_t(4)); + vsubps(zmm_t(2), zmm_t(2), zmm_t(4)); + vmovups(zmm_t(3), zmm_T(j, 3)); + + int alpha_offset = (jcp.oc / jcp.nb_oc) + * (jcp.ntiles / jcp.tile_block) * typesize; + int dst_off = j * alpha * alpha_offset; + movps(reg_dst, dst_off, zmm_t(0)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(5)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(1)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(6)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(2)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(3)); + } + + }; + auto compute_transform_SDGtWo = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done; + + L(loop_oc_ur); + { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + xor_(reg_tile_count, reg_tile_count); + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(tiles_done); + + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize); + jmp(loop_mb); + } + + L(tiles_done); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile; + + L(loop_oc_ur); + { + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * 
jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * jcp.oc_reg_block * simd_w * typesize); + int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc) + * (jcp.ntiles/jcp.tile_block) * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + preamble(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + compute_transform_SDGtWo(); + } else { + compute_transform(); + } + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_weights_transform_generate(bool first_tile) { + int G_size = 4; + + auto zmm_G = [](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + }; + + auto zmm_src = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto load_src = [=](int i) { + for (int j = 0; j < alpha; j++) { + size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block + * jcp.ic_block * simd_w * simd_w * typesize; + size_t src_off = (j * alpha + i) * alpha_offset; + vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off)); + } + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 6 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i); + }; + + auto zmm_dst = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto zmm_temp = Xbyak::Zmm(31); + + auto store_dst = [=](int j) { + for (int i = 0; i < jcp.kw; i++) { + size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize; + + if (!first_tile) { + vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off)); + vaddps(zmm_dst(i), zmm_dst(i), zmm_temp); + } + vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i)); + } + }; + + auto compute_transform = [=] () { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + + xor_(reg_ic_simd, reg_ic_simd); + Label loop_ic_simd; + L(loop_ic_simd); + { + for (int i = 0; i < alpha; i++) { + load_src(i); + + vaddps(zmm_t(0), zmm_src(1), zmm_src(2)); + vaddps(zmm_t(1), zmm_src(3), zmm_src(4)); + vmovups(zmm_t(2), zmm_src(5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0)); + vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1)); + vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2)); + vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1)); + vsubps(zmm_temp, zmm_src(3), zmm_src(4)); + vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2)); + vmovups(zmm_T(2, i), zmm_t(2)); + vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3)); + } + + for (int j = 0; j < jcp.kh; j++) { + vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2)); + vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4)); + vmovups(zmm_t(2), zmm_T(j, 5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + 
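+ // Reduce the six transformed columns of row j down to the three kw filter taps before storing.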
vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0)); + vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1)); + vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2)); + vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1)); + vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4)); + vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2)); + vmovups(zmm_dst(2), zmm_t(2)); + vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3)); + + store_dst(j); + } + + add(reg_src, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, simd_w * typesize); + add(reg_ic_simd, 1); + cmp(reg_ic_simd, simd_w); + jl(loop_ic_simd); + } + }; + preamble(); + push(reg_EVEX_max_8b_offt); + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + init_G(); + compute_transform(); + pop(reg_EVEX_max_8b_offt); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate( + bool is_first_tile) +{ + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + + auto zmm_srcB = [=] (size_t N_ur){ + return Xbyak::Zmm(N_ur + 1); + }; + + auto broadcastB = [=](size_t K_ur) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast) + * sizeof(float); + vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off)); + } + }; + + auto load_srcA = [=] (size_t K_ur, int M_ur) { + size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_ur * jcp.dimM_simd_block) * sizeof(float); + vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off)); + }; + + auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){ + size_t idx = 1 // zmm_srcA + + jcp.dimN_bcast_ur // zmm_srcB + + M_reg_ur * jcp.dimN_bcast_ur + N_bcast; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + auto prepare_accumm = [=](){ + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + Zmm zmm = zmm_dstC(M_reg_ur, N_bcast); + vpxord(zmm, zmm, zmm); + } + } + }; + + auto store_dstC = [=](){ + /******** Write C back to memory *******/ + for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) { + for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) { + Zmm zmm = zmm_dstC(M_reg, N_ur); + size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_reg * jcp.dimM_simd_block) * sizeof(float); + if (!is_first_tile) { + vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off)); + vaddps(zmm, zmm, Xbyak::Zmm(0)); + } + vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur; + + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + { /************* OC_block (M) loop ***********/ + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + { /*************** IC_block (N) loop *********/ + + mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur); + L(dimN_bcast_ur); + { + prepare_accumm(); + + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + { + /************* nb_tile_ur(K) loop ********/ + for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) { + + broadcastB(K_ur); + + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + load_srcA(K_ur, M_reg_ur); + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) { + vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(), + zmm_srcB(N_bcast)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); 
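+ // srcA/srcB now point at the next dimK sub-block; keep accumulating until the K loop is exhausted.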
+ sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + store_dstC(); + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float)); + add(reg_dstC, jcp.dimN_bcast_ur + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_nb_dimN_bcast_ur, 1); + jnz(dimN_bcast_ur); + } + + sub(reg_srcB, jcp.dimN_reg_block * sizeof(float)); + add(reg_srcB, jcp.dimK_block + * jcp.dimK_reg_block + * jcp.dimN_reg_block * sizeof(float)); + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + + sub(reg_srcB, jcp.dimN_block + * jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { + +void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) { +/*M params*/ + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + /*N params*/ + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /*K params*/ + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; +} + +status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) { + + size_t K_blk_ur, N_blk, M_blk; + /* IS this strategy feasible? 
*/ + auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) { + size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float); + size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size) + && (jcp.dimK / nthreads >= 1.0); + }; + + auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur, + int max_block=1) { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float); + size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance=true; + if (!(jcp.dimK % nthreads)) { + load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0); + } + return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size) + && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size) + && load_balance + && (M_L2_block < L2_cache_size); + }; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&](){ + size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * K_blk_ur * sizeof(float); + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block + * K_blk_ur * sizeof(float); + size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * N_blk * jcp.dimN_reg_block * sizeof(float); + size_t L2_block = M_L2_block + V_L2_block + U_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size); + }; + + if (test_MV_large_enough(jcp)) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimN_bcast_ur = 8; + /*dimK_block and dimK_ur*/ + size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1); + + jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block; + jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block; + for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) { + if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (!(jcp.dimN_block % N_blk)) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (!(jcp.dimM_block % M_blk) && blocking_ok()) { + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented; + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.sched_policy = WSCHED_WEI_SDGtWo; + set_jcp_WEI_params(jcp); + jcp.nthr = nstl::min(mkldnn_get_max_threads(), + jcp.tile_block); + return status::success; + } + } + } + } + } + } + } + return status::unimplemented; +} + +status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimN_bcast_ur = 8; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block; + jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block; + float C1 = 0.0, C2 = 0.0; + float C1_max = 0.5, C2_max = 1.4; + 
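+ // [C1, C1_max] and [C2, C2_max] bound the fraction of the L1 and L2 caches, respectively, that a candidate blocking may occupy (see blocking_ok below).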
int N_blk, M_blk, K_blk_ur; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&]() -> bool { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float); + bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size) + && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size); + + size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block; + size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block; + size_t nb_K_blk = jcp.dimK / K_blk_ur; + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads; + if (!(nb_K_blk % nthreads)) { + load_balance = load_balance && (nb_K_blk % nthreads == 0); + } + + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float); + + size_t L2_block = V_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size); + return L1_cond && load_balance && L2_cond; + }; + + for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) { + if (jcp.dimK % K_blk_ur == 0) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (jcp.dimN_block % N_blk == 0) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (jcp.dimM_block % M_blk == 0) { + if (blocking_ok()) { + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; + } + } + } + } + } + } + } + jcp.dimK_reg_block = 1; + jcp.dimK_block = 1; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; +} +} // namespace +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) { + if (!mayiuse(avx512_core)) + return status::unimplemented; + else + jcp.ver = ver_avx512_core; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.prop_kind = cd.prop_kind; + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + jcp.mb = src_d.dims()[0]; + jcp.ngroups = with_groups ? 
diff_weights_d.dims()[0] : 1; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /******************Kernel blocking Parameters ***********/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + + jcp.dimK = jcp.ntiles; + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_SDGtWo(jcp); + if (res == status::unimplemented) { + res = set_wsched_WEI_S_D_Giot_W(jcp); + assert(res == status::success); + } + return res; +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp new file mode 100644 index 000000000000..025a554d92a0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp @@ -0,0 +1,291 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP +#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP + +#include "c_types_map.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + : public jit_generator { + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) { + { + this->weights_transform_data_ker_generate(); + weights_transform_data_ker + = (decltype(weights_transform_data_ker)) this->getCode(); + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->input_transform_data_ker_generate(); + input_transform_data_ker = (decltype(input_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->output_transform_data_ker_generate(); + output_transform_data_ker + = (decltype(output_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *, const int); + void (*input_transform_data_ker)(jit_wino_transform_call_s *); + void (*output_transform_data_ker)(jit_wino_transform_call_s *); + void (*weights_transform_data_ker)(jit_wino_transform_call_s *); + +protected: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(); + void input_transform_data_ker_generate(); + void output_transform_data_ker_generate(); + void weights_transform_data_ker_generate(); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + reg64_t reg_is_beta_zero = abi_param4; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; + + /* registers used for transforms*/ + reg64_t param = abi_param1; + + /* registers used for output_transform_data_ker */ + reg64_t oreg_temp = abi_not_param1; + reg64_t oreg_Ow = r9; + reg64_t oreg_src = r11; + reg64_t oreg_tile_block = r12; + reg64_t oreg_tile_block_ur = r13; + reg64_t oreg_nb_tile_block_ur = r14; + reg64_t oreg_O = r8; + reg64_t oreg_T = r10; + reg64_t oreg_dst = r11; + reg64_t oreg_ydim = r14; + reg64_t oreg_xdim = r15; + reg64_t oreg_out_j = r12; + reg64_t oreg_bias = rbx; + reg64_t imm_addr64 = rax; + + /* registers used for input_transform_data_ker */ + reg64_t ireg_temp = abi_not_param1; + reg64_t ireg_jtiles = rax; + reg64_t ireg_itiles = rbx; + reg64_t ireg_I = r8; + reg64_t ireg_src = r13; + reg64_t ireg_ydim = r14; + reg64_t ireg_xdim = r15; + reg64_t ireg_inp_j = r12; + reg64_t ireg_inp_i = rdx; + reg64_t ireg_mask_j = r11; + reg64_t ireg_mask = rsi; + reg32_t ireg_mask_32 = esi; + reg64_t ireg_zero = r9; + reg64_t ireg_Iw = r9; + reg64_t ireg_T = r10; + reg64_t ireg_tile_block = r12; + reg64_t ireg_tile_block_ur = 
r13; + reg64_t ireg_nb_tile_block_ur = r14; + reg64_t ireg_output = r15; + + /* registers used for wei transform */ + reg64_t wreg_temp = abi_not_param1; + reg64_t wreg_F = r8; + reg64_t wreg_src = r9; + reg64_t wreg_MT = r15; + reg64_t wreg_M = r14; + reg64_t wreg_dst = r10; + reg64_t wreg_dst_aux = r9; + reg64_t wreg_dst_idx = r8; + reg64_t wreg_Fw = r11; + reg64_t wreg_T = r12; + reg64_t wreg_cnt_j = rdx; + reg64_t wreg_F_aux = r14; + reg64_t wreg_Fw_aux = r15; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel + : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_t &src_md, + memory_desc_t &weights_md, const memory_desc_t &dst_md, + const primitive_attr_t &attr); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode(); + + align(); + const Xbyak::uint8 *addr = getCurr(); + this->src_transform_generate(); + src_transform = (decltype(src_transform))addr; + + if (jcp.with_bias) { + align(); + addr = getCurr(); + this->diff_dst_transform_generate(true); + diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr; + } + + align(); + addr = getCurr(); + this->diff_dst_transform_generate(false); + diff_dst_transform = (decltype(diff_dst_transform))addr; + + if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) { + align(); + addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + align(); + addr = getCurr(); + this->diff_weights_transform_generate(true); + diff_weights_transform = (decltype(diff_weights_transform))addr; + + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + align(); + addr = getCurr(); + this->diff_weights_transform_generate(false); + diff_weights_transform_accum = + (decltype(diff_weights_transform_accum))addr; + }; + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*src_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *); + void (*diff_weights_transform)(jit_wino_transform_call_s *); + void 
(*diff_weights_transform_accum)(jit_wino_transform_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void src_transform_generate(); + void diff_dst_transform_generate(bool with_bias); + void diff_weights_transform_generate(bool first_tile); + + /*registers common to transforms*/ + reg64_t reg_transp = abi_param1; + reg64_t reg_ti = rbx; + reg64_t reg_tj = abi_not_param1; + reg64_t reg_src = r8; + reg64_t reg_dst = r9; + reg64_t reg_G = rsi; /*TODO: check if this is ok*/ + reg64_t reg_temp = rsi; + + /*registers common to src/diff_dst transform*/ + reg64_t reg_I = r10; + reg64_t reg_ydim = r11; + reg64_t reg_xdim = r12; + reg64_t reg_src_offset = r13; + reg64_t reg_zero = r14; + reg64_t reg_tile_count = r15; + reg64_t reg_maski = rsi; + reg32_t reg_maski_32 = esi; + reg64_t reg_maskj = rdx; + + reg64_t reg_T = rax; + reg64_t reg_oc_ur = rax; + reg64_t reg_ic_simd = r14; + reg64_t reg_bias = r10; + + void gemm_loop_generate(bool is_first_tile); + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r9; + reg64_t reg_dimN_block_loop_cnt = r10; + reg64_t reg_nb_dimN_bcast_ur = r11; + reg64_t reg_dimK_block_loop_cnt = r12; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp new file mode 100644 index 000000000000..002010ffa2a3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp @@ -0,0 +1,1284 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "jit_generator.hpp" + +#include + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + // Below scales are applied to source and weights data accordingly + // because this winograd implementation + // transforms source which may increase values up to 4x + // and transforms weights which may increase values up to 9/4x + const float adj_src_scale = 1.f / 4.f; + const float adj_wei_scale = 4.f / 9.f; + // Winograd transforms need ic and oc to be multiples of 16 + const int load_block = 16; +} + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + void generate(); + + int reg_inp_ind(int i) { + assert(i < jcp.alpha * jcp.alpha); + return (31 - i); + } + + Xmm vreg_inp(int i) { + return Xmm(reg_inp_ind(i)); + } + + Zmm zmm_inp(int i) { + return Zmm(reg_inp_ind(i)); + } + + Xmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(15 - i); + } + Xmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + + int unsign_val_in_wino_domain; + + Reg64 reg_scratch_src_alpha = rdx; + Xmm xmm_src_alpha = Xmm(0); + Zmm zmm_src_alpha = Zmm(0); + + Reg64 reg_shift = rax; + Xmm xmm_shift = Xmm(1); + Xmm xmm_zero = Xmm(0); + + Reg64 reg_maskx = rbx; + Reg64 reg_masky = rsi; + Reg64 reg_nomask = reg_maskx; +}; + +void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { + Label ic_block_label; + Label end_label; + Label mask_label; + Label nomask_label; + + auto load_src = [=](bool mask) { + for (int y = 0; y < jcp.alpha; y++) { + if (mask) + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm_i = zmm_inp(y * jcp.alpha + x); + Xmm vreg_i = vreg_inp(y * jcp.alpha + x); + int inp_offset = sizeof(uint8_t) + * ((-jcp.t_pad + y) * jcp.iw * jcp.ic + + (-jcp.l_pad + x) * jcp.ic); + if (mask) { + kandw(r_mask, y_mask, x_mask(x)); + vmovdqu8(vreg_i | r_mask | T_z, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } else { + vmovdqu8(vreg_i, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + vpmovzxbd(zmm_i, vreg_i); // to int32 + vcvtdq2ps(zmm_i, zmm_i); // to fp32 + 
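+                // Reviewer note (not part of the original sources): a rough
+                // sketch of why the 1/4 factor (adj_src_scale) is applied on
+                // the next instruction. The Winograd input transform used by
+                // this kernel can grow source magnitudes by up to 4x (see the
+                // comment next to adj_src_scale above), so the u8 inputs are
+                // pre-scaled here to stay within the u8 range after the
+                // transform:
+                //   scaled_src = round(src * 1/4)        (the vmulps below)
+                // The factor is undone on the way out: adjust_oscales()
+                // multiplies the output scales by
+                //   1.f / (adj_src_scale * adj_wei_scale) == 9.f
+                // so the final dequantized result is unchanged.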
vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha + vcvtps2dq(zmm_i, zmm_i); // to int32 + vpmovusdb(vreg_i, zmm_i); // to u8 + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +# undef READ_PARAM + + mov(reg_maskx, ptr[reg_ptr_v_x_masks]); + mov(reg_masky, ptr[reg_ptr_v_y_masks]); + test(reg_maskx, reg_maskx); + jz(end_label, T_NEAR); // skip kernel if x mask is all 0's + test(reg_masky, reg_masky); + jz(end_label, T_NEAR); // skip kernel if y mask is all 0's + and_(reg_maskx, reg_masky); + mov(reg_nomask, reg_maskx); + not_(reg_nomask); // zero if x and y masks are all 1's + + xor_(reg_shift, reg_shift); + mov(reg_shift.cvt8(), (int8_t)-128); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + } + + mov(reg_scratch_src_alpha, float2int(adj_src_scale)); + + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + vmovq(xmm_src_alpha, reg_scratch_src_alpha); + vbroadcastss(zmm_src_alpha, xmm_src_alpha); + + test(reg_nomask, reg_nomask); + jz(nomask_label, T_NEAR); + load_src(true); + jmp(mask_label, T_NEAR); + L(nomask_label); + load_src(false); + L(mask_label); + + for(int y = 0; y < 4; y++) { + vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2)); + vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2)); + vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1)); + vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3)); + } + for(int x = 0;x < 4; x++) { + vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2)); + vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2)); + vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1)); + vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3)); + } + + vmovd(xmm_shift, reg_shift.cvt32()); + vpxor(xmm_zero, xmm_zero, xmm_zero); + vpshufb(xmm_shift, xmm_shift, xmm_zero); + + for (int i = 0; i < 16; i++) { + int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i); + if (i != unsign_val_in_wino_domain) + vpsubb(vreg_out(i), vreg_out(i), Xmm(1)); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(uint8_t) * load_block); + add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block); + } + dec(reg_ic_block); + jnz(ic_block_label, T_NEAR); + + L(end_label); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + 
id; + assert(id < 8); + return Zmm(31 - id_reg_stg); + } + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Zmm(31 - id_reg_out); + } + Xmm xmm_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Xmm(31 - id_reg_out); + } + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id < 2); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_bias = Zmm(1); + Zmm vreg_prev_dst = Zmm(2); + Zmm zmm_bias_alpha = Zmm(2); + Xmm xmm_bias_alpha = Xmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_scratch_bias_alpha = r15; + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0) + || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0)); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1) + || jcp.dst_dt == data_type::u8; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { + Label oc_block_label; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? 
&p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for(int i = 0; i < 16; i++) { + int internal_offset = sizeof(int32_t) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for(int y = 0; y < jcp.alpha; y++) { + vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1)); + vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2)); + + vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2)); + vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3)); + } + for(int x = 0; x < jcp.m; x++) { + vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1)); + vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2)); + + vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2)); + vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3)); + } + + + if (jcp.with_bias) { + vmovq(xmm_bias_alpha, reg_scratch_bias_alpha); + vbroadcastss(zmm_bias_alpha, xmm_bias_alpha); + + auto bias_addr = ptr [ reg_ptr_bias ]; + switch (jcp.bia_dt) { + case data_type::f32: + case data_type::s32: vmovups(vreg_bias, bias_addr); break; + case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break; + case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break; + default: assert(!"unsupported dst data type"); + } + if (jcp.bia_dt != data_type::f32) + vcvtdq2ps(vreg_bias, vreg_bias); + vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha + } + for(int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]); + for(int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = jcp.typesize_out * + (y * jcp.ow * jcp.oc + x * jcp.oc); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + Xmm xmm = xmm_out(i); + vcvtdq2ps(zmm, zmm); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + if (maybe_relu(0)) + vmaxps(zmm, vreg_zero, zmm); + if (p_sum_scale) { // post_op: sum + vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(vreg_prev_dst | r_mask, addr); break; + case data_type::s8: + vpmovsxbd(vreg_prev_dst | r_mask, addr); break; + case data_type::u8: + vpmovzxbd(vreg_prev_dst | r_mask, addr); break; + default: assert(!"unknown dst_dt"); + } + if (jcp.dst_dt != data_type::f32) + vcvtdq2ps(vreg_prev_dst, vreg_prev_dst); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) + vmaxps(zmm, vreg_zero, zmm); + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(addr, zmm | r_mask); break; + case data_type::s8: + vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + case data_type::u8: + vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + default: assert(!"unknown dst_dt"); + } + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, wino_dst); + READ_PARAM(reg_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +# undef READ_PARAM + + if (jcp.with_bias) + mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale)); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, 
reg_ptr_dst); + + vpxord(vreg_zero, vreg_zero, vreg_zero); + + for (int i = 0; i < jcp.m; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + + int oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); { + loop_body(); + add(reg_aux_ptr_src, sizeof(int32_t) * load_block); + add(reg_aux_ptr_dst, jcp.typesize_out * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block); + } + dec(reg_oc_block); + jnz(oc_block_label, T_NEAR); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) + { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert(31 - jcp.n2_block * jcp.m_block - i + > (jcp.ver == ver_vnni ? 
0 : 2)); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst_b = r13; + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + Reg64 reg_mb = rbx; + Reg64 reg_nnb = abi_not_param1; + Reg64 reg_scratch = rdx; + Reg64 reg_K = rsi; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vreg_tmp, vreg_src, vreg_wei); + vpmaddwd(vreg_tmp, vreg_tmp, vreg_one); + vpaddd(vreg_acc, vreg_acc, vreg_tmp); + } + }; + + preamble(); +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); + READ_PARAM(reg_aux_dst_b, dst_b); +# undef READ_PARAM + + if (jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(vreg_one, _t); + } + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + int offset = jcp.typesize_acc * nb2 * jcp.n_block; + vmovups(vreg_out(nb2, m), + EVEX_compress_addr(reg_aux_dst_b, offset)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + mov(reg_K, jcp.k_chunks); + L(K_loop_label); + { + for (int k = 0; k < jcp.k2_block; k += 4) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int wei_offset + = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = jcp.typesize_in * m * jcp.K; + vpbroadcastd(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); + } + add(reg_aux_src2, jcp.typesize_in * 4); + add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); + } + } + dec(reg_K); + jnz(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block); + vmovups(EVEX_compress_addr(reg_aux_dst2, offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K); + add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N); + } + dec(reg_mb); + jnz(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block 
* jcp.n_block); + add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K); + + dec(reg_nnb); + jnz(nnb_loop_label, T_NEAR); + } + + postamble(); +} +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + if (jcp.ver == ver_vnni) { + return (jcp.mb <= mkldnn_get_max_threads() + && (jcp.mb > 4 + && jcp.ic > 64 + && !(jcp.oc > 128 && jcp.ih < 14))) + || jcp.mb > mkldnn_get_max_threads(); + } + return true; +} +} + +status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t +::init_conf(jit_conv_conf_2x3_wino_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core) && + src_d.data_type() == data_type::u8 + && wei_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // block sizes needed for GEMM kernel + jcp.ic_block = 4; + jcp.oc_block = 16; + + bool ok = true + && jcp.ngroups == 1 + && jcp.oc % load_block == 0 && jcp.ic % load_block == 0 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && everyone_is(3, jcp.kh, jcp.kw) + && everyone_is(1, jcp.stride_h, jcp.stride_w) + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad + && one_of(jcp.t_pad, 0, 1) + && one_of(jcp.l_pad, 0, 1); + if (!ok) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_acc = sizeof(int32_t); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + + int aa = jcp.alpha * jcp.alpha; + int L1_cap = get_cache_size(1, true); + int L2_cap = get_cache_size(2, true); + // need 1 extra reg for bcast, and 2 tmp regs for non-vnni + int free_regs = jcp.ver == ver_vnni ? 
31 : 29; + + auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float thr_eff; + float Z = (float)jcp.ic + jcp.oc; + float Y = (float)jcp.ic * jcp.oc; + if (small_mb == 0) { // outer par + int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + } else { // inner par + int tranw = iy * ix / jcp.alpha; + int gemmw = aa * (jcp.nb_oc / n2_b); + int tranw_r = rnd_up(tranw, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y); + } + return thr_eff; + }; + + auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float mem_eff, req_mem; + int M = ix * iy / jcp.alpha; + if (small_mb == 0) { // outer parallelization strategy + // memory for wino transforms (other memory has poor reuse) + req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc); + mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f; + } else { // inner parallelization strategy + // memory used during gemm + int N = jcp.oc_block * n2_b; + req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + // memory used during wino transforms + int M_per_thr = div_up(M, jcp.nthr); + req_mem = (float)aa * M_per_thr + * (jcp.ic + jcp.typesize_acc * jcp.oc); + if (req_mem > L2_cap) + mem_eff = 0.1f; + } + return mem_eff; + }; + + auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff, + float mem_eff, float reg_eff) { + // these coefficients are chosen empirically + float mem_fac = 0.1f, reg_fac = 0.2f; + // normalized overhead relative to memory and register components + float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff; + // thread and work components affect all others + tot_eff *= thr_eff * work_eff; + return tot_eff; + }; + + auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff, + int &m_block, int &n2_block, float &tot_eff) { + int M = (ix * iy) / jcp.alpha; + int max_m_block = nstl::min(M, free_regs); + int max_n2_block = nstl::min(jcp.nb_oc, free_regs); + tot_eff = 0.f; + for (int im = max_m_block; im > 0; im--) { + if (M % im) + continue; + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = (im + 1) * in2; + float mem_eff = get_mem_eff(small_mb, ix, iy, in2); + float reg_eff = (float)(im * in2) / (im + in2); + float thr_eff = get_thr_eff(small_mb, ix, iy, in2); + float cur_tot_eff = get_tot_eff( + small_mb, thr_eff, work_eff, mem_eff, reg_eff); + if (jcp.nb_oc % in2 || used_regs > free_regs + || cur_tot_eff <= tot_eff) + continue; + tot_eff = cur_tot_eff; + m_block = im; + n2_block = in2; + } + } + }; + + /* Selecting xb and yb blocking */ + int min_yb = jcp.m; + int min_xb = jcp.m; + int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2)); + int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2)); + float best_eff = 0.f; + for (int ix = min_xb; ix <= max_xb; ix += 2) { + assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2); + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2); + + int m_b[2]; + int n2_b[2]; + bool small_mb; + float inner_eff, outer_eff, work_eff; + + int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix); + work_eff = (float)jcp.oh * jcp.ow / tiled_area; + if (best_eff > 0.f && work_eff < 4.f / 9.f) + continue; // no gain from Winograd transformation + + /* outer parallelization */ + find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff); + + /* inner parallelization */ + find_m_n2_blocks(1, ix, iy, work_eff, 
m_b[1], n2_b[1], inner_eff); + + small_mb = inner_eff > outer_eff; + float eff = small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.m_block = m_b[small_mb]; + jcp.n2_block = n2_b[small_mb]; + jcp.small_mb = small_mb; + } + } + } + + assert((jcp.m_block + 1) * jcp.n2_block <= free_regs); + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.mb_block = 1; + if (jcp.small_mb) { + // For small mb harness, set mb_block as large as possible subject to + // the constraint that winograd activations fit into available L3 cache + int L3_cap = get_cache_size(3, true); + int M = jcp.xb * jcp.yb / 4; + int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in; + int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc; + int max_mb_block = nstl::min( + jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size)); + for (int i = max_mb_block; i > 1; i--) { + if (jcp.mb % i == 0) { + jcp.mb_block = i; + break; + } + } + } + jcp.nb_mb = jcp.mb / jcp.mb_block; + + jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4; + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block; + + // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4 + // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is + // a multiple of load_block = 16, we just use that for now. + jcp.k2_block = load_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + memory_desc_t expect_wei_md = wei_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::s8; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_aaOIoi; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = adj_wei_scale; + + size_t max_size = types::data_type_size(data_type::s8) * + jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + max_size += types::data_type_size(data_type::s32) * + jcp.alpha * jcp.alpha * jcp.oc; + wd.size = max_size; + + if (wei_md.format_kind == format_kind::any) + wei_md = expect_wei_md; + if (wei_md != expect_wei_md) + return status::unimplemented; + + const int tilesize = jcp.alpha * jcp.alpha; + const int numtiles = jcp.M; + const int alltiles = numtiles * tilesize; + + jcp.size_wino_src + = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K) + / jcp.typesize_in; + jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic; + jcp.size_wino_dst = alltiles * jcp.oc; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +template +status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + pd_t::jit_conf() { + return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr()); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t:: +init_scratchpad() { + auto scratchpad = this->scratchpad_registry().registrar(); + + int 
nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr; + scratchpad.book(key_wino_V, + sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K); + scratchpad.book(key_wino_M, + sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K); + + dim_t scale_count = attr()->output_scales_.count_; + scratchpad.book(key_conv_adjusted_scales, + sizeof(float) * nstl::max(scale_count, 16)); +} + +template +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +template +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +template +const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +adjust_oscales(const memory_tracking::grantor_t &scratchpad) const { + const float *oscales = pd()->attr()->output_scales_.scales_; + auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / (adj_src_scale * adj_wei_scale); + if (count == 1) + utils::array_set(loc_scales, oscales[0] * factor, 16); + else + for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor; + return loc_scales; +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const auto &jcp = kernel_->jcp; + if (jcp.small_mb) + execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx)); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src_base = scratchpad.template get(key_wino_V); + auto wino_dst_base = scratchpad.template get(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = wino_src_base + jcp.size_wino_src * ithr; + auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + auto gemm_p = + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = 
x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + mb * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + // start threads at different GEMMs to help bring weights into LLC + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + gemm_p.dst_b = dst_bias + jcp.bia_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + mb * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src = scratchpad.template get(key_wino_V); + auto wino_dst = scratchpad.template get(key_wino_M); + + for (int mbb = 0; mbb < jcp.nb_mb; mbb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = 
nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t:: + call_params_t(); + + gemm_p.src = wino_src + jcp.inp_stride * tile_ij; + gemm_p.dst = wino_dst + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp new file mode 100644 index 000000000000..9e6e57b051c6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp @@ -0,0 +1,128 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t; +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t; +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t; + +template +struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""), + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::u8, data_type::s8, + data_type::undef, dst_data_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_conf(); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + init_scratchpad(); + + return status; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(); + void init_scratchpad(); + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nhwc, any, nhwc); + } + }; + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type acc_data_t; + typedef typename prec_traits::type dst_data_t; + + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd); + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad) + const; + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_; + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_; + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp 
b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp new file mode 100644 index 000000000000..f4ec29ab00e3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -0,0 +1,820 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off)); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + int output_offset = jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep; + + add(aux_reg_output_data, output_offset); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in, + zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? 
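+    // Reviewer note (not part of the original sources): mask_flag is only set
+    // for the last load block when jcp.oc_without_padding has a tail (see
+    // store() and generate(), where ktail_mask is built from
+    // (1 << (oc_without_padding % oc_block)) - 1), so the masked form below
+    // touches just the tail lanes and zero-fills the rest via T_z:
+    //   zmm_in | ktail_mask | T_z   -> masked load, inactive lanes zeroed
+    //   zmm_in                      -> full 16-lane load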
zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load) { + return Zmm(ur * load_loop_blk + i_load); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto zmm_bias_alpha = [=]() { + return Zmm(ur * load_loop_blk); + }; + + auto xmm_bias_alpha = [=]() { + return Xmm(ur * load_loop_blk); + }; + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_bia * jcp.oc_block * i_load); + }; + + auto comp_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_comp_data, + sizeof(int32_t) * jcp.oc_block * i_load); + }; + + auto scale_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_ptr_scales, + jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load)); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + + int offt = (jcp.ic_without_padding * i_ur + i_reduce); + + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + + int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + return EVEX_compress_addr(aux_reg_output_data, + jcp.typesize_out * (jcp.oc_without_padding * i_ur + + i_load * jcp.load_block)); + }; + + auto init = [=]() { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } + }; + + auto store = [=](const bool mask_flag_in) { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? 
&p.entry_[sum_idx].sum.scale + : nullptr; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + if (p_sum_scale && *p_sum_scale != 1.f) { + mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data); + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + } + if (jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_scratch, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_scratch); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1; + auto zmm_bias = zmm_tmp; + auto zmm_comp = zmm_bcast; + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, + EVEX_compress_addr(rsp,reg_bias_data_off)); + cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vcvtdq2ps(r, r); + if (jcp.signed_input) + vaddps(r, r, zmm_comp); + if (jcp.with_bias) + vaddps(r, r, zmm_bias); + + zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r; + vmulps(mask_zmm, r, scale_ptr(i_load)); + } + } + + if (maybe_eltwise(0)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + if (p_sum_scale) { // post_op: sum + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto zmm_prev_dst = zmm_zero; + + auto r = vreg_accum(i_load, i_ur); + cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur), + mask_flag); + + if (*p_sum_scale == 1.f) + vaddps(r, zmm_prev_dst); + else + vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + + if (maybe_eltwise(1)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(r, zmm_zero, r); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(r, r); + } + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + zmm_t r_zmm = mask_flag ? 
r | ktail_mask : r; + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::s8: + vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::u8: + vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off)); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + auto fma_block = [=](bool last_block) { + int reduce_step = 4; + int tail_size = jcp.ic_without_padding % reduce_step; + int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding + ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step) + : jcp.reduce_loop_unroll; + for (int i_reduce = 0; i_reduce < loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load)); + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (last_block && tail_size != 0 + && i_reduce == loop_unroll - reduce_step) { + Xmm xmm_bcast = Xmm(zmm_bcast.getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data + + jcp.ic_without_padding * i_ur + i_reduce + r], r); + vpbroadcastd(zmm_bcast, xmm_bcast); + } else { + vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false)); + } + if (jcp.signed_input) + vpsubb(zmm_bcast, zmm_bcast, zmm_shift); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + compute(vreg_accum(i_load, i_ur), + vreg_load(i_load), zmm_bcast); + } + } + } + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + if (jcp.ic != jcp.ic_without_padding) { + fma_block(true); + } else { + fma_block(false); + } + + if (jcp.oc_without_padding != jcp.oc) { + Label end_store, common_store; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + + /*Check if it is the last load_loop_blk*/ + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + cmp(reg_load_loop_work, 0); + jg(common_store, T_NEAR); + + /*Check if it is the last ocb*/ + test(reg_reduce_pos_flag, FLAG_OC_LAST); + jz(common_store, T_NEAR); + + store(true); + jmp(end_store, T_NEAR); + + L(common_store); + store(false); + + L(end_store); + + add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + } else { + store(false); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() +{ + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + sub(rsp, stack_space_needed); + + if (jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 
regw_tmp = reg_last_load.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + if (jcp.signed_input) { + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off)); + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_bia); + if (jcp.signed_input) + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + add(reg_comp_data, + load_loop_blk * jcp.load_block * sizeof(int32_t)); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + add(reg_ptr_scales, + jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float)); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + add(reg_output_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast); + const int *ur_cases_fma = ur_cases_fma_expl_bcast; + const int *ur_cases = ur_cases_fma; + const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + + for (int _i = 1; _i <= label_idx + 1; _i++) { + prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]); + prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]); + } + + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; 
--idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( + jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_core)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) + || weights_d.data_type() != data_type::s8 + || !one_of(dst_d.data_type(), + data_type::f32, data_type::s32, data_type::s8, data_type::u8)) + return status::unimplemented; + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + format_tag_t dat_tag = format_tag::nhwc; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 16; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? 
cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + const int SMALL_SPATIAL = 7 * 7; + const int BIG_REDUCE_DIM = 1024; + + int load_blocking = 0; + int load_blocking_max = 0; + int bcast_blocking = 0; + int bcast_blocking_max = 0; + int reduce_blocking = 0; + int reduce_blocking_max = 0; + jcp.load_grp_count = 1; + jcp.use_vmovntps = false; + + const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in); + const int L2_capacity = (L2_size * 3) / 4; + + int size_treshold = 28; + int max_regs = 0; + int min_regs = 6; + if (jcp.ver == ver_vnni) + max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold) + && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9; + else + max_regs = 8; + jcp.expl_bcast = true; + + if (jcp.mb == 1 && jcp.ic > 128 + && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) { + if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) + max_regs = min_regs; // mobilenet_v2 performance improvement + jcp.ur = nstl::min(max_regs, jcp.os); + } else { + const int spatial = jcp.oh; + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i--) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + } + + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + jcp.load_loop_iter_step = jcp.load_block; + + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + reduce_blocking = nb_reduce; + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 64; + else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + + bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; + if (cmp_reduce) + jcp.loop_order = reduce_src ? 
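+ // reduce_blocking (the IC/reduce chunk) was rounded above to a divisor of nb_reduce;
+ // when reduction is chunked, the reduce loop is hoisted outermost (loop_rbl / loop_rlb).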
loop_rbl : loop_rlb; + load_blocking = jcp.load_dim; + + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); // + load_blocking = jcp.load_block; + } + + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + + assert(jcp.reduce_loop_unroll % 4 == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + // miniumum size of load dim chunk for work distribution within threads + jcp.nb_load_chunk = 1; + // peformance improvements for googlenet_v3, mb=1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4 + && jcp.ic * jcp.oc <= L2_size) { + jcp.nb_load_chunk = 4; + jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count); + } + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? 
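+ // wei_adj_scale records the pre-scaling applied to s8 weights (0.5 on pre-VNNI targets,
+ // see scale_adjust in the weights descriptor); the bias and output scales are adjusted
+ // with it at execution time.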
weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp new file mode 100644 index 000000000000..22e9732a1f3b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP +#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t) + jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr); + + bool maybe_eltwise(int position); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using mask_t = const Xbyak::Opmask; + + reg64_t reg_bcast_data = r8; + reg64_t reg_ptr_scales = r8; + reg64_t reg_output_data = r9; + reg64_t reg_load_data = r10; + reg64_t reg_ptr_sum_scale = r10; + reg64_t reg_reduce_loop_work = r11; + 
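+ // Note: several GPRs are deliberately aliased (r8, r10, r12, rbx each serve two roles);
+ // the conflicting values are spilled to the stack slots whose offsets are defined below.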
reg64_t reg_bias_data = r12; + reg64_t reg_comp_data = r12; + reg64_t reg_scratch = r13; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = r15; + reg64_t reg_reduce_pos_flag = rax; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t reg_bcast_loop_work = rbx; + reg64_t bcast_loop_iter = rdx; // Note: Fix me + reg64_t reg_load_loop_work = rsi; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reduce_loop_iter = abi_param1; + + reg64_t reg_last_load = r8; + mask_t ktail_mask = k6; + + mask_t vmask = k7; + + Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28); + Xbyak::Zmm zmm_one = Xbyak::Zmm(29); + Xbyak::Zmm zmm_zero = Xbyak::Zmm(30); + Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31); + Xbyak::Zmm zmm_shift = Xbyak::Zmm(30); + + Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31); + Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31); + + int bcast_loop_work_off = 0; + int reg_bias_data_off = 8; + int reg_bcast_data_off = 16; + int reg_load_data_off = 24; + int reg_ptr_sum_scale_off = 32; + int reg_comp_data_off = 40; + int stack_space_needed = 48; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp new file mode 100644 index 000000000000..0bf09fc6776e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp @@ -0,0 +1,292 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +namespace { +template +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const T grp_size = utils::div_up(nthr, nx_divider); + const T grp_count = utils::div_up(nthr, grp_size); + + T grp = ithr / grp_size; + T grp_ithr = ithr % grp_size; + T grp_nthr = grp_size; + T first_grps = nthr % grp_count; + if (first_grps > 0 && grp >= first_grps) { + ithr -= first_grps * grp_size; + grp_nthr--; + grp = ithr / grp_nthr + first_grps; + grp_ithr = ithr % grp_nthr; + } + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} + +/* convolution forward */ +template +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) { + auto local_scales = scratchpad.template get( + key_conv_adjusted_scales); + auto scales = pd()->attr()->output_scales_.scales_; + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, scales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = scales[c] * factor; + } + } + + parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); +} + +template +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t +::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + auto local_scales = scratchpad.get(key_conv_adjusted_scales); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + const int stride_h = pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[1]; + const int pad_t = pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][1]; + + const auto &oscales = pd()->attr()->output_scales_; + + int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block) + * jcp.oc_block * jcp.ic_block; + wei_data_t *w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? 
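+ // For signed (s8) input, the per-output-channel compensation terms are stored right
+ // after the packed weights; 'offset' is the total number of weight elements.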
reinterpret_cast(w + offset) : 0; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t::call_params_t(); + const int nb_oc = jcp.nb_load; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end, + jcp.load_grp_count); + if (jcp.nb_load_chunk > 1) { + ocb_start *= jcp.nb_load_chunk; + ocb_end *= jcp.nb_load_chunk; + } + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + + if (ocb + load_step >= nb_oc) + p.first_last_flag |= FLAG_OC_LAST; + else + p.first_last_flag &= ~FLAG_OC_LAST; + + }; + + auto init_reduce = [&]() + { + p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int n, int g, int oh, int ow, + int ih, int iw) + { + const int icb = 0; // Start from the first IC block + const int _ocb = g * nb_oc + ocb; + const int _icb = g; + + const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow); + + p.output_data = &dst[dst_off]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size]; + p.compensation = (jcp.signed_input) + ? &compensation[_ocb * jcp.oc_block] : 0; + p.scales = (jcp.signed_input && jcp.ver != ver_vnni) + ? 
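+ // Non-VNNI s8 path: use the adjusted scales prepared in execute_forward() (the raw
+ // output scales divided by wei_adj_scale); otherwise take the attribute scales directly.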
&local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block] + : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + init_reduce(); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + init_reduce(); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + +using namespace data_type; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp new file mode 100644 index 000000000000..ad9027ac174d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""), + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t< + src_type, dst_type>); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), format_tag::any, + dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel:: + init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), + *weights_md(1), *attr(), mkldnn_get_max_threads(), + rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + scratchpad, jcp_, *attr()); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 
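+ // s8 src with s8 weights: without VNNI the weights are stored pre-scaled by 0.5 so the
+ // vpmaddubsw intermediate stays within int16 range; the kernel compensates via
+ // wei_adj_scale. With VNNI no adjustment is needed.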
1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp new file mode 100644 index 000000000000..e89d0683028e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t + : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t); + + status_t init_convolution() { + convolution_desc_t cd; + status_t status; + + auto dd = desc(); + status = conv_desc_init(&cd, prop_kind::forward_training, + alg_kind::convolution_direct, &(dd->src_desc), + &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc), + dd->strides, dd->dilates, dd->padding[0], dd->padding[1], + dd->padding_kind); + + if (status == status::success) { + status = mkldnn_primitive_desc::create( + &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr); + } + + if (status == status::success) + status = set_default_params(); + + return status; + }; + + status_t init() { + bool ok = true + && is_fwd() + && desc()->alg_kind == alg_kind::deconvolution_direct + && !has_zero_dim_memory() + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && desc()->weights_desc.data_type == data_type::s8 + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + CHECK(init_convolution()); + + return status::success; + } + + virtual void init_scratchpad_md() override { + const auto conv_1x1_pd = static_cast(conv_pd_); + scratchpad_md_ = *conv_1x1_pd->scratchpad_md(); + } + + protected: + status_t set_default_params() { + auto conv_1x1_pd_ = static_cast(conv_pd_); + src_md_ = *conv_1x1_pd_->src_md(); + dst_md_ = *conv_1x1_pd_->dst_md(); + weights_md_ = *conv_1x1_pd_->weights_md(); + if (with_bias()) + bias_md_ = *conv_1x1_pd_->weights_md(1); + return status::success; + } + + using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t + ::pd_t; + friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t; + primitive_desc_t *conv_pd_; + }; + + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + + ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t() + { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + return conv_p_->execute(ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +} +} +} + +#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */ diff 
--git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp new file mode 100644 index 000000000000..10e98a00c45a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -0,0 +1,1182 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { +void pick_loop_order(jit_conv_conf_t &jcp, int nthr) +{ + jcp.loop_order = loop_cwgn; + if (jcp.ngroups > 1) { + jcp.loop_order = loop_ngcw; + if (jcp.mb < nthr) + jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; + } +} +} + +template +bool _jit_avx512_core_x8s8s32x_fwd_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::prepare_output(int ur_w) +{ + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + for (int k = 0; k < nb_oc_block; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + if (jcp.is_depthwise && !jcp.is_fast_depthwise) { + Reg32 _t32 = reg_scratch.cvt32(); + mov(_t32, (uint32_t)128); + vpbroadcastd(vmm_shift, _t32); + } else { + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)128); + vpbroadcastb(vmm_shift, _t8); + } + } +} + +template +const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel:: + vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) { + return vmm_in; +} + +template<> +const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel:: + vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) { + return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) + : zmm_in; +} + + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in, + const Vmm vmm_in, const Operand &op, bool mask_flag) { + //const Vmm vmm = mask_flag ? 
vmm_in | ktail_mask | T_z : vmm_in; + const Vmm vmm = vmm_mask(vmm_in, mask_flag); + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(vmm, op); break; + case data_type::s8: vpmovsxbd(vmm, op); break; + case data_type::u8: vpmovzxbd(vmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(vmm_in, vmm_in); +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + if (ur_w == jcp.ur_w) + eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w); + else + for (int k = 0; k < nb_oc_block; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::store_output( + int ur_w, bool last_oc_block_flag) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; + + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = nullptr; + if (sum_idx != -1) { + const auto &p_entry = p.entry_[sum_idx]; + p_sum_scale = &p_entry.sum.scale; + } + + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.signed_input && jcp.ver != ver_vnni) { + /* put 'wei_adj_scale = 0.5' for bias calculation */ + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block); + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * k * oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + + cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + /* bias *= 0.5 */ + vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * k * oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + + cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag); + } + /* add to zmm_accum: compensation, bias and permute */ + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.is_fast_depthwise) + vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k)); + vcvtdq2ps(vmm, vmm); + if (jcp.signed_input) + vaddps(vmm, vmm, vmm_comp); + if (jcp.with_bias) + vaddps(vmm, vmm, vmm_bias); + + const Vmm vmm_k = vmm_mask(vmm, mask_flag); + vmulps(vmm_k, vmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + + /* Do post-ops */ + if (maybe_eltwise(0)) compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + Vmm vmm = vmm_out(j, k); + cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(vmm, vmm_prev_dst); + else + 
vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) compute_eltwise(ur_w); + + /* write out register to output_addr */ + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.dst_dt == data_type::u8) { + vpxord(vmm_zero, vmm_zero, vmm_zero); + vmaxps(vmm, vmm_zero, vmm); + } + + if (jcp.dst_dt != data_type::f32) { + /* Note: using Zmm for rounding in Xmm/Ymm kernel + because there is no instruction to do rounding + from Xmm/Ymm -> Xmm/Ymm. + Embedded rounding is not supported for Xmm. + TODO: maybe avoid Zmm if it helps performance.*/ + Zmm zmm = zmm_out(j, k); + vcvtps2dq(zmm, zmm); + } + } + + for (int j = 0; j < ur_w; j++) { + int aux_output_offset = jcp.typesize_out + * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + + Vmm vmm = vmm_out(j, k); + const Vmm r_vmm = vmm_mask(vmm, mask_flag, true); + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_vmm); break; + case data_type::s8: vpmovsdb(addr, r_vmm); break; + case data_type::u8: vpmovusdb(addr, r_vmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + assert(!"invalid group blocking for depthwise convolution"); +} + +template <> +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + + auto input_spatial_index = [=](int oi, int ki) { + return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l); + }; + + auto input_offset2 = [=](int ii, int ci) { + return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block); + }; + + auto input_offset3 = [=](int oi, int ci, int ki) { + return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci); + }; + + auto kernel_offset = [=](int ci, int ki) { + return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + // okay for depthwise since src is zero-extended + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddwd(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + int ii_start = 0; + int ii_end = -1; + if (jcp.is_resrc_depthwise && !h_padded) { + // find bounds of input spatial indices + bool first = true; + for (int ki = 0; ki < jcp.kw; ki++) { + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + for (int oi = oi_start; oi < oi_end; oi++) { + int ii = input_spatial_index(oi, ki); + if (first || ii < ii_start) + ii_start = ii; + if (first || ii > ii_end) + ii_end = ii; + first = false; + } + } + } + + if (jcp.signed_input) { + vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero); + vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift); + } + for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) { + const bool mask_flag = last_ic_block_flag != no_last_block + && ci == jcp.nb_ch_blocking - 1; + if (jcp.is_resrc_depthwise && !h_padded) { + // now we can load input once and reuse up to jcp.kw times + for (int ii = ii_start; ii <= ii_end; ii++) { + int aux_input_offset = input_offset2(ii, ci); + const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking); + const Zmm zmm_inp_msk = 
mask_flag + ? zmm_inp_tmp | ktail_mask | T_z + : zmm_inp_tmp; + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } else { + vpmovzxbd(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift); + } + } + for (int ki = 0; ki < jcp.kw; ki++) { + int aux_kernel_offset = kernel_offset(ci, ki); + if (jcp.is_fast_depthwise) { + vbroadcasti32x4(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei); + } else { + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + if (h_padded) { + assert(jcp.signed_input); + for (int oi = 0; oi < ur_w; oi++) + compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); + } else { + const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src; + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + int start_ = jcp.signed_input ? 0 : oi_start; + int end_ = jcp.signed_input ? ur_w : oi_end; + for (int oi = start_; oi < end_; oi++) { + if (oi >= oi_start && oi < oi_end) { + if (jcp.is_resrc_depthwise) { + int ii = input_spatial_index(oi, ki); + zmm_src = zmm_inp(ii, jcp.nb_ch_blocking); + } else { + int aux_input_offset = input_offset3(oi, ci, ki); + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } else { + vpmovzxbd(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_src, zmm_src, vmm_shift); + } + } else if (jcp.signed_input) { + zmm_src = zmm_shifted_zero; + } + compute(zmm_out(oi, ci), zmm_wei, zmm_src); + } + } + } + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, + int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + if (jcp.is_depthwise) + return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded); + + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ch_block_all = jcp.ch_block * ic_block * oc_block; + + int nb_oc_block = jcp.nb_oc_blocking; + + auto input_offset = [=](int oi, int ic, int ki) { + return jcp.typesize_in + * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * jcp.ic_without_padding * jcp.ngroups + 4 * ic); + }; + auto kernel_offset = [=](int ii, int ic, int ki) { + return jcp.typesize_in + * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all + + 4 * ic * oc_block); + }; + auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vmm_tmp, vreg_src, vreg_wei); + vpmaddwd(vmm_tmp, vmm_tmp, vmm_one); + vpaddd(vreg_acc, vreg_acc, vmm_tmp); + } + }; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + int tail_size = jcp.ic_without_padding % 4; + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */ + int icb = (last_ic_block_flag != no_last_block) + ? 
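+ // Inputs are consumed four s8 values per step (one dword broadcast into all lanes),
+ // so the IC block is walked in groups of 4.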
div_up((jcp.ic_without_padding % ic_block), 4) + : ic_block / 4; + for (int ic = 0; ic < icb; ic++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Vmm inp = vmm_inp(0,nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } else { + for (int jj = _start; jj < _end; jj++) { + int aux_input_offset = input_offset(jj, ic, ki); + if (jj >= jj_start && jj < jj_end) { + if (last_ic_block_flag == last_sp_block + && tail_size != 0 && ic == icb - 1) { + Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_inp + aux_input_offset + r], r); + vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp); + } else { + vpbroadcastd(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr( + aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(vmm_inp(jj, nb_oc_block), + vmm_inp(jj, nb_oc_block), vmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Vmm inp = vmm_inp(jj, nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } + } + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, ic, ki); + vmovups(vmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + for (int jj = _start; jj < _end; jj++) { + Vmm inp = (h_padded == true) + ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block); + compute(vmm_out(jj, ii), vmm_wei, inp); + } + } + } + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::kh_loop( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) { + Label kh_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, + b_overflow_label, no_b_overflow_label; + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ic_without_padding * jcp.ngroups; + + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + if ((jcp.signed_input) || (!jcp.signed_input && + (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) { + cmp(reg_kj, 0); + je(skip_kh_loop, T_NEAR); + } + L(kh_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false); + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_sp_block) +{ + prepare_output(ur_w); + + // IC loop + Label icb_label; + mov(reg_icb, jcp.nb_ic); + L(icb_label); + if (jcp.ngroups % jcp.ch_block != 0 
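+ // With padded channels the last IC block (reg_icb == 1) must run the tail-aware
+ // kernel variant; all other blocks use the common path.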
|| jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + + cmp(reg_icb, 1); // The last IC block + jne(common_ker, T_NEAR); + + kh_loop(ur_w, pad_l, pad_r, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, pad_l, pad_r, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, pad_l, pad_r, no_last_block); + } + // End of IC Loop + int inp_step = jcp.ic_block; + int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block; + add(reg_inp, jcp.typesize_in * inp_step); + add(reg_ker, jcp.typesize_in * ker_step); + + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_label, T_NEAR); + + sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); + sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + + jne(common_store, T_NEAR); + + store_output(ur_w, true); // last oc block + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + } else { + store_output(ur_w, false); + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::generate() +{ + Label permute_index_table; + int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad) + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift = jcp.typesize_in * + (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding + * jcp.ngroups); + int out_shift = jcp.typesize_out * + (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); + preamble(); + + if (jcp.is_depthwise) { + int idx = jcp.max_regs_ur - 1; + if (!jcp.is_resrc_depthwise) + zmm_src = Zmm(++idx); + if (jcp.ver != ver_vnni) + zmm_tmp = Zmm(++idx); + if (jcp.is_fast_depthwise) + zmm_permute = Zmm(++idx); + if (jcp.signed_input) { + zmm_shifted_zero = Zmm(++idx); + ++idx; // due to extra register used for shifts and compensations + } + assert(idx == ker_dw_reg_base_idx); + } + + if (!jcp.is_depthwise && jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t16 = reg_scratch.cvt16(); + mov(_t16, 0x1); + vpbroadcastw(vmm_one, _t16); + } + + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise + ? 
jcp.ngroups % jcp.ch_block + : jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + Reg32 regw_tmp = reg_oi.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + if (jcp.is_fast_depthwise) { + // prepare mask register for blending weights + mov(reg_scratch, 0x8888444422221111); + kmovq(kblend_mask, reg_scratch); + // load permute indices from data section + mov(reg_scratch, permute_index_table); + vmovdqu32(zmm_permute, ptr[reg_scratch]); + } + + int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int n_oi = jcp.ow / jcp.ur_w; + int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + + if (jcp.nb_ow == 1) { + if (r_pad1 > 0 || jcp.ur_w_tail == 0) + n_oi--; + + xor_(reg_oi, reg_oi); + if (jcp.ow == jcp.ur_w) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true); + } else { + if (n_oi == 0) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } else { + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + inc(reg_oi); + } + if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) + { + Label ow_loop_label; + L(ow_loop_label); { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0 || jcp.ur_w_tail == 0) { + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label, + oi_loop_label, oi_loop_end_label; + + assert(jcp.ow_block % jcp.ur_w == 0); + int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + int n_oi_last_ow_block + = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w; + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded + = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded + = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? 
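/* --------------------------------------------------------------------------
 * Editor's note (illustrative sketch, not part of the upstream patch):
 * the r_pad / r_pad1 values computed above decide whether the last ur_w
 * block (or the ur_w tail) of the generated kernel has to handle implicit
 * right padding. A minimal scalar model of that arithmetic, using the same
 * jcp fields that init_conf() fills in, could look like this:
 */
static int right_pad_model(int ow, int stride_w, int kw, int dilate_w,
                           int iw, int l_pad) {
    // Rightmost input column touched by the last output column and kernel tap.
    int last_input_col = (ow - 1) * stride_w + (kw - 1) * (dilate_w + 1);
    // Whatever falls beyond the (left-padded) input extent is right padding.
    int r_pad = last_input_col - (iw + l_pad - 1);
    return r_pad > 0 ? r_pad : 0;
}
/* Example: ow = 14, stride_w = 1, kw = 3, dilate_w = 0, iw = 14, l_pad = 1
 * gives last_input_col = 15 and r_pad = 1, so the final block is generated
 * with a non-zero right-padding argument to icb_loop().
 * ------------------------------------------------------------------------ */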
+ jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + mov(reg_oi, n_oi_first_ow_block); + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + L(middle_ow_blocks_label); + + if (jcp.l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + if (n_oi_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + } + + if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + } + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); { + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + icb_loop(jcp.ur_w, 0, 0, false); + + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + + jmp(oi_loop_label, T_NEAR); + } + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + // that is last block + if (!last_ow_block_padded) + jmp(tail_label, T_NEAR); + + // last oi block with right padding + L(last_oi_label); + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? 
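/* --------------------------------------------------------------------------
 * Editor's note (illustrative sketch, not part of the upstream patch):
 * when jcp.nb_ow > 1, each kernel call only processes one ow block; the
 * branches above pick the unroll count and padding for that block from the
 * owb parameter. The driver side of this contract (see execute_forward_*
 * later in this patch) derives the column offsets roughly as follows:
 *
 *     for (int owb = 0; owb < jcp.nb_ow; ++owb) {
 *         int ow_s = owb * jcp.ow_block;   // first output column of the block
 *         int iw_s = ow_s * jcp.stride_w;  // matching input column
 *         p.owb = owb;                     // kernel selects padding from owb
 *         // p.src / p.dst are advanced by iw_s / ow_s before jit_ker(&p)
 *     }
 *
 * so only owb == 0 can see left padding and only the last one or two blocks
 * can see right padding, which is what the label structure above encodes.
 * ------------------------------------------------------------------------ */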
+ jl(end_label, T_NEAR); + + // ur_w tail + L(tail_label); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); + + if (jcp.is_fast_depthwise) { + align(64); + L(permute_index_table); + const uint32_t _idx[] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i) + dd(_idx[i]); + } +} + +bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) +{ + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) || + (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr, + int nthreads) +{ + using namespace prop_kind; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + bool is_1d = ndims == 3; + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.ic_block = 1; + jcp.oc_block = 1; + } else { + jcp.ch_block = 1; + jcp.ic_block = 16; + jcp.oc_block = 16; + + if (jcp.ngroups == 1) { + /* For non grouped convolutions, pad channels by 16 if needed */ + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) { + /* For grouped convolutions, MKL-DNN doesn't support padding. 
+ Use Ymm when channels per group is multiple of 8, + Xmm when channels per group is multiple of 4 */ + jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4; + jcp.oc_block = jcp.ic_block; + } + if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0) + return status::unimplemented; + } + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core; + jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni + && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16 + // would require byte masking + // for load from src + jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw + && jcp.kw < 4 && jcp.dilate_w == 0; + if (jcp.is_depthwise) { + jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise + - 2 * jcp.signed_input - (jcp.ver != ver_vnni); + } else { + jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28; + } + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + format_tag_t wei_tag; + if (jcp.ic_block == 16 || jcp.ch_block == 16) { + if (is_1d) { + wei_tag = with_groups + ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i + : OIw4i16o4i; + } else { + wei_tag = with_groups + ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i + : OIhw4i16o4i; + } + } else if (with_groups && jcp.ic_block == 8) { + wei_tag = gOIhw2i8o4i; + } else + wei_tag = gOIhw4o4i; + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? 
types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) + int nb_ch_blocking = 4; + for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--) + if (jcp.nb_ch % nb_ch_blocking == 0) + break; + jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; + + // If OC blocking is incommensurate with the number of OC blocks (general + // requirement for all convolutions), or if it results in an unrolling + // factor smaller than the left padding (special requirement for SSD:fc6), + // then search for a smaller OC blocking that satisfies both constraints. + auto is_oc_blocking_ok = [&](int block) { + int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); + return jcp.nb_oc % block == 0 + && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1; + }; + + // choose nb_oc work chunk size for distribution within threads + int max_threading_nb_oc_chunk = 4; + // Performance improvements for googlenet_v3 and resnet_50 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3 + && jcp.stride_w == 1 && jcp.ic % 64 == 0) + max_threading_nb_oc_chunk = 2; + jcp.nb_oc_blocking_thr_chunk = + nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); + for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) { + if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) + break; + } + + // choose oc blocking for computational kernel + jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; + // Performance improvements for googlenet_v3 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + const int size_treshold_for_nb_oc_blocking_reduction = 17; + if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction + && jcp.stride_w == 1 + && !(jcp.kh == 1 && jcp.kw == 3) + && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) { + const int max_nb_oc_blocking = 2; + jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0 + && is_oc_blocking_ok(jcp.nb_oc_blocking)) + break; + } + + if (jcp.is_resrc_depthwise) + jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) + / (jcp.nb_ch_blocking + jcp.stride_w); + else + jcp.ur_w + = jcp.max_regs_ur / (jcp.is_depthwise ? 
jcp.nb_ch_blocking + : jcp.nb_oc_blocking + 1); + if (jcp.ow < jcp.ur_w) + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.ow_block = jcp.ow; + int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh + * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); + float best_thr_eff + = (float)base_work_amount / rnd_up(base_work_amount, nthreads); + int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w); + for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow); + if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block + && best_thr_eff > 0.8f) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + auto work_amount = base_work_amount * nb_ow; + float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads); + if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) { + jcp.ow_block = ow_block; + best_thr_eff = thr_eff; + } + if (best_thr_eff > 0.9f) + break; + } + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0); + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp, nthreads); + + jcp.nb_ic_L2 = jcp.nb_ic; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp new file mode 100644 index 000000000000..d8a05ad53eec --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp @@ -0,0 +1,239 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t) + + enum { STATE_FIRST_DST_LOAD = 0x1U }; + + _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_core_x8s8s32x_fwd_kernel() { + delete eltwise_injector_; + } + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + ker_dw_reg_base_idx = 30, + }; + typedef enum { + no_last_block, + last_ic_block, + last_sp_block, + } ic_block_t; + + /* data regs */ + const Xbyak::Reg64 reg_ptr_scales = rax; + const Xbyak::Reg64 reg_inp = r8; + const Xbyak::Reg64 reg_ker = r9; + const Xbyak::Reg64 reg_out = r10; + const Xbyak::Reg64 aux_reg_inp = r11; + const Xbyak::Reg64 reg_ptr_sum_scale = r11; + const Xbyak::Reg64 aux_reg_ker = r12; + const Xbyak::Reg64 reg_compensation = r14; + /* counter regs */ + const Xbyak::Reg64 reg_bias_alpha = abi_not_param1; + const Xbyak::Reg64 reg_oi = rbx; + const Xbyak::Reg64 reg_bias = rdx; + const Xbyak::Reg64 reg_oc_blocks = rsi; + const Xbyak::Reg64 reg_owb = aux_reg_ker; + const Xbyak::Reg64 reg_scratch = reg_compensation; + const Xbyak::Reg64 reg_kj = reg_ptr_scales; + const Xbyak::Reg64 reg_overflow = reg_ptr_scales; + const Xbyak::Reg64 reg_icb = reg_bias; + + const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3); + + const Vmm vmm_wei = Vmm(31); + /* used during bias section of store_output */ + const Vmm vmm_comp = Vmm(30); // only for signed input + const Vmm vmm_bias = Vmm(31); + /* used during post_op sum section of store_output */ + const Vmm vmm_prev_dst = Vmm(31); + /* used during write-out section of store_output */ + const Vmm vmm_zero = Vmm(31); + + /* used in compute_ker (but set during prepare_output) */ + const Vmm vmm_shift = vmm_comp; // only for signed input + /* used in compute_ker (but only for pre-VNNI machines) */ + const Vmm vmm_tmp = Vmm(28); // not used for depthwise + const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise. + + /* registers use only for depthwise + groups are always blocked by 16(padded if needed), + hence use only Zmm registers */ + const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + Xbyak::Zmm zmm_tmp; + Xbyak::Zmm zmm_src; + Xbyak::Zmm zmm_shifted_zero; + Xbyak::Zmm zmm_permute; + + Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Vmm(idx); + } + Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? 
ker_dw_reg_base_idx : ker_reg_base_idx)); + return Xbyak::Zmm(idx); + } + Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + Vmm vmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Vmm(nb_c_block * jcp.ur_w); + } + Xbyak::Xmm xmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Xbyak::Xmm(nb_c_block * jcp.ur_w); + } + int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } + + bool maybe_eltwise(int position); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block_flag); + void compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded); + void compute_ker(int ur_w, int pad_l, int pad_r, + ic_block_t last_ic_block_flag, bool h_padded = false); + void compute_eltwise(int ur_w); + void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag); + void icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); + void generate(); + void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op, + bool mask_flag); + const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false); +}; + +struct jit_avx512_core_x8s8s32x_fwd_kernel { + + jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + ymm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? 
ajcp.ch_block : ajcp.ic_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 8: + ymm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = ymm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_core_x8s8s32x_fwd_kernel() { + delete xmm_kernel_; + delete ymm_kernel_; + delete zmm_kernel_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + void (*jit_ker)(jit_conv_call_s *); + _jit_avx512_core_x8s8s32x_fwd_kernel *zmm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel *ymm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel *xmm_kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp new file mode 100644 index 000000000000..cdbf333d5eb0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? 
(d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, + nb_groups, n, jcp.mb); + break; + case loop_gncw: + nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, + gg, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.nb_ic * jcp.ic_block; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.dst = dst + dst_d.blk_off(n, g_oc, ow_s); + p.src = src + src_d.blk_off(n, g_ic, iw_s); + p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.oc_blocks = jcp.is_depthwise ? 
gb : ocb; + p.kh_padding = jcp.kh; + p.t_overflow = 0; + p.b_overflow = 0; + p.owb = owb; + + kernel_->jit_ker(&p); + + ++start; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups, + n, jcp.mb); + break; + case loop_gncw: + nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, + nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ch_block == 1); + assert(jcp.nb_ch_blocking == 1); + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? 
reinterpret_cast(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk; + int nb_groups = jcp.nb_ch; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, + occ, oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk; + occ1 += jcp.nb_oc_blocking) { + int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + + int g_ic = g * jcp.nb_ic * jcp.ic_block; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + if (jcp.loop_order == loop_nhwcg) + oh_e = oh_s + 1; // step instead + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias + ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) + : 0; + int32_t *compensation_w = (jcp.signed_input) + ? compensation + g_oc : 0; + + auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + + for (int oj = oh_s, ij = ih_s; oj < oh_e; + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, + div_up(max(0, -ij), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = (!jcp.signed_input) + ? 
i_t_overflow * wht_h_stride : 0; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = ocb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + src_w += src_h_stride * jcp.stride_h; + dst_w += dst_h_stride; + } + } + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + ++start; + nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ, + oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ic_block == 1); + assert(jcp.oc_block == 1); + assert(jcp.nb_ic == 1); + assert(jcp.nb_oc == 1); + assert(jcp.nb_oc_blocking == 1); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + + parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, + [&](int n, int oh_s, int owb, int gg) { + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0; + int32_t *compensation_w = jcp.signed_input ? 
compensation + g : 0; + + auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, gb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g]; + + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, + div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = gb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + }); +} + +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::f32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::f32>; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp new file mode 100644 index 000000000000..203ebdf94207 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""), + jit_avx512_core_x8s8s32x_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type, + data_type::f32, data_type::s32, data_type::s8, + data_type::u8)) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + const auto &_pd = pd(); + if (_pd->ndims() == 3) + execute_forward_1d(ctx); + else if (_pd->jcp_.is_depthwise) + execute_forward_2d_dw(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_2d_dw(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp new file mode 100644 index 000000000000..142af1f54128 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -0,0 +1,1034 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_x8s8s32x_deconvolution.hpp" + +#define GET_OFF(field) offsetof(jit_deconv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +using namespace nstl; + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \ + (d).blk_off(__VA_ARGS__)) + +status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, const bool with_bias, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper bias_d(&bias_md); + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + jcp.signed_input = src_d.data_type() == data_type::s8; + const int ndims = jcp.ndims = dst_d.ndims(); + const bool is_1d = ndims == 3; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; + jcp.is_depthwise = true && with_groups + && utils::everyone_is(1, jcp.ic_without_padding, + jcp.oc_without_padding); + + /* TODO: future work, on hold until depthwise specialized kernel is + * implemented. */ + if (jcp.is_depthwise && jcp.signed_input) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + + format_tag_t wei_tag = is_1d + ? (jcp.is_depthwise + ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i)) + : (jcp.is_depthwise + ? Goihw16g : (with_groups ? 
gOIhw4i16o4i : OIhw4i16o4i)); + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input && !jcp.is_depthwise) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + jcp.with_bias = with_bias; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.prop_kind = cd.prop_kind; + jcp.mb = src_d.dims()[0]; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.oc_block = 1; + jcp.ic_block = 1; + } else { + jcp.ch_block = 1; + jcp.oc_block = 16; + jcp.ic_block = 16; + + if (jcp.ngroups == 1) { + jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); + jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); + } + if (jcp.ic % jcp.ic_block != 0) + return status::unimplemented; + } + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) + || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) + return status::unimplemented; + + /* padding: bottom and right */ + jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.oh + jcp.t_pad - 1); + jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.ow + jcp.l_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.dst_dt = dst_d.data_type(); + jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + /* kernel blocking params */ + const int regs = jcp.ver == ver_vnni ? 
30 : 28; + jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc % jcp.nb_oc_blocking == 0 + && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) + break; + + jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + + if (jcp.ow < jcp.ur_w) { + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = 0; + } else { + for (; jcp.ur_w >= 1; jcp.ur_w--) { + /* ur_w should be multiple of stride_w in order + to simplify logic for get_ow_start and get_ow_end */ + bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0; + + /* boundary conditions: + These conditions ensure all elements close to boundary + are computed in a single call of compute loop */ + bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + int r_overflow_no_tail + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + bool right_boundary_covered + = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; + + if (is_multiple_of_stride && left_boundary_covered + && right_boundary_covered) + break; + else if (jcp.ur_w == 1) + /* The boundary conditions above are also important + to maintain simplicity of calls to icb_loop, + if those conditions are not satisfied, + then special cases will need to be added + to use correct l_overflow/r_overflow values + when different iterations of compute loop + work on the locations close to boundary. + So to keep code simple, return unimplemented + for extreme case when a good ur_w cannot be found. + */ + return status::unimplemented; + } + } + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; + return status::success; +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? 
jcp.nb_ch_blocking : jcp.nb_oc_blocking; + eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w); +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: + return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, + bool h_padded) { + + const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w; + + auto src_offset = [=](int oj, int icb, int ki) { + return jcp.typesize_in + * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) + * jcp.ngroups * jcp.ic_without_padding + + icb * 4); + }; + + auto kernel_offset = [=](int ocb, int icb, int ki) { + return jcp.typesize_in + * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all + + icb * jcp.oc_block * jcp.ic_block / 4 + + ki * ch_block_all); + }; + + auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else if (jcp.is_depthwise) { + vpmulld(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + for (int ki = 0; ki < jcp.kw; ki++) { + + int jj_start = get_ow_start(ki, l_overflow); + int jj_end = get_ow_end(ur_w, ki, r_overflow); + + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + + int tail_size = jcp.ic_without_padding % 4; + int n_ic_blocks = jcp.is_depthwise ? + 1 : + (last_ic_block_flag & ~no_last_block ? 
+ div_up(jcp.ic_without_padding % jcp.ic_block, + 4) : + jcp.ic_block / 4); + + for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Zmm inp = zmm_inp(0, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } else { + + for (int jj = _start; jj < _end; jj += ur_w_stride) { + + int aux_src_off = src_offset(jj, icb1, ki); + + if (jj >= jj_start && jj < jj_end + && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) { + if (jcp.is_depthwise) { + vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } else if ((last_ic_block_flag & last_sp_block) + && tail_size != 0 && icb1 == n_ic_blocks - 1) { + xmm_t xmm_tmp = xmm_t( + zmm_inp(jj, jcp.nb_oc_blocking).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_src + aux_src_off + r], r); + vpbroadcastd( + zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp); + } else { + vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } + if (jcp.signed_input) + vpsubb(zmm_inp(jj, jcp.nb_oc_blocking), + zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } + } + } + } + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + int aux_filt_off = kernel_offset(ocb, icb1, ki); + + if (_end - _start > 0) { + if (jcp.is_depthwise) + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + else + vmovups(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + } + for (int jj = _start; jj < _end; jj += ur_w_stride) { + Zmm inp = (h_padded == true) ? + zmm_inp(0, jcp.nb_oc_blocking) : + zmm_inp(jj, jcp.nb_oc_blocking); + compute(zmm_out(jj, ocb), zmm_wei, inp); + } + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ngroups * jcp.ic_without_padding; + const int stride_h = jcp.signed_input ? 1 : jcp.stride_h; + int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h; + + Label kh_loop_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, b_overflow_label, + no_b_overflow_label; + + mov(aux_reg_src, reg_src); + mov(aux_reg_filt, reg_filt); + + if (jcp.signed_input && jcp.ndims > 3) { + /* Weights are transposed, so first compute 'bottom' padding. 
*/ + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } + + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + if (jcp.signed_input || ((!jcp.signed_input) + && ((min(jcp.t_pad, jcp.b_pad) < 0) + || ((jcp.kh - 1) * (jcp.dilate_h + 1) + < nstl::max(jcp.t_pad, jcp.b_pad))))) { + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + } + + L(kh_loop_label); { + compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); + sub(aux_reg_src, shift_src_ih); + add(aux_reg_filt, shift_filt_kh); + dec(reg_kh); + + /* Insert weight compensation in stride 'holes' */ + if (jcp.signed_input && jcp.stride_h > 1) { + Label kh_comp_loop; + + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + mov(reg_comp_strides, jcp.stride_h - 1); + L(kh_comp_loop); + { + compute_ker( + ur_w, 0, 0, last_ic_block_flag, true); + add(aux_reg_filt, shift_filt_kh); + dec(reg_comp_strides); + cmp(reg_comp_strides, 0); + jg(kh_comp_loop, T_NEAR); + } + } + cmp(reg_kh, 0); + jg(kh_loop_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) { + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vpxord(zmm, zmm, zmm); + } + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps( + data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( + int ur_w, bool last_oc_block) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale + = (sum_idx != -1) ? 
&p.entry_[sum_idx].sum.scale : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + int scale_offset + = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block); + + auto zmm_bias = zmm_tmp; + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag); + } + + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vcvtdq2ps(zmm, zmm); + if (jcp.signed_input) + vaddps(zmm, zmm, zmm_comp); + if (jcp.with_bias) + vaddps(zmm, zmm, zmm_bias); + zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm; + vmulps(mask_zmm, zmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + if (maybe_eltwise(0)) + compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + const bool mask_flag + = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * jcp.oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_dst, aux_output_offset); + Zmm zmm = zmm_out(j, k); + cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(zmm, zmm_prev_dst); + else + vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) + compute_eltwise(ur_w); + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(zmm, zmm_zero, zmm); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + } + for (int ur = 0; ur < ur_w; ur++) { + int aux_dst_off = jcp.typesize_out + * (ur * jcp.ngroups * jcp.oc_without_padding + + ocb * jcp.oc_block); + auto addr = EVEX_compress_addr(reg_dst, aux_dst_off); + + zmm_t zmm = zmm_out(ur, ocb); + zmm_t r_zmm = mask_flag ? 
zmm | ktail_mask : zmm; + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_zmm); break; + case data_type::s8: vpmovsdb(addr, r_zmm); break; + case data_type::u8: vpmovusdb(addr, r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop( + int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { + + int shift_src_icb = jcp.typesize_in * jcp.ic_block; + int shift_filt_icb + = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block; + + prepare_output(ur_w); + + Label skip_icb_loop, icb_loop_label; + + mov(reg_icb, jcp.nb_ic); + L(icb_loop_label); { + + if (jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + cmp(reg_icb, 1); + jg(common_ker, T_NEAR); + + kh_loop(ur_w, l_overflow, r_overflow, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + } + + add(reg_src, shift_src_icb); + add(reg_filt, shift_filt_icb); + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_loop_label, T_NEAR); + } + + /* come-back pointers */ + sub(reg_src, jcp.nb_ic * shift_src_icb); + sub(reg_filt, jcp.nb_ic * shift_filt_icb); + L(skip_icb_loop); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - 1); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + jne(common_store, T_NEAR); + + store_output(ur_w, true); + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + + } else { + store_output(ur_w, false); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise ? 
+ jcp.ngroups % jcp.ch_block : + jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_nur_w.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + mov(reg_src, ptr[param1 + GET_OFF(src)]); + mov(reg_filt, ptr[param1 + GET_OFF(filt)]); + mov(reg_dst, ptr[param1 + GET_OFF(dst)]); + + int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups + * jcp.oc_without_padding; + int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups + * jcp.ic_without_padding; + + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + int r_overflow + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)) + / jcp.stride_w); + + int r_overflow1 + = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + int nur_w = jcp.ow / jcp.ur_w; + if (r_overflow1 > 0) + nur_w--; + + if (jcp.ur_w == jcp.ow) { + icb_loop(jcp.ur_w, l_overflow, r_overflow, true); + } else if (nur_w == 0) { + icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + if (jcp.ur_w_tail != 0) + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } else { + xor_(reg_nur_w, reg_nur_w); + if (l_overflow > 0) { + icb_loop(jcp.ur_w, l_overflow, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + } + if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + cmp(reg_nur_w, nur_w); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast(weights); + int32_t *compensation + = (jcp.signed_input) ? 
reinterpret_cast(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + int n{ 0 }, g{ 0 }, occ{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + + p.dst = dst + dst_d.blk_off(n, g_oc); + p.src = src + src_d.blk_off(n, g_ic); + p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); + p.bias = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.t_overflow = 0; + p.b_overflow = 0; + p.kh_padding = jcp.kh; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + + kernel_->jit_ker(&p); + + ++start; + if (jcp.loop_order == loop_ngc) + nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + } + }); +} + +template +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast(weights); + int32_t *compensation + = (jcp.signed_input) ? 
reinterpret_cast(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + /*loop order = cgn*/ + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, + oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + int work_rem = end - start; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + auto dst_w = dst + dst_d.blk_off(n, g_oc); + auto src_w = src + src_d.blk_off(n, g_ic); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + auto bias_w = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + int32_t *compensation_w + = (jcp.signed_input) ? compensation + g_oc : 0; + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + for (int oj = oh_s; oj < oh_e; oj++) { + int ih_max = 0, kh_lo = 0, kh_len = 0; + if (jcp.dilate_h != 0 && jcp.stride_h == 1) { + /* dilation */ + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int o_t_overflow = div_up( + max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), + dilate_h); + int o_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh + + oj - jcp.b_pad), + dilate_h); + kh_len = jcp.kh - o_t_overflow - o_b_overflow; + kh_lo = o_b_overflow; + ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; + } else { + int o_t_overflow = max( + 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); + int o_b_overflow + = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) + / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 + - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h; + int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; + + kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h + + 1 - o_t_overflow - o_b_overflow; + kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; + ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; + } + + int wei_stride + = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0; + p.src = src_w + ih_max * src_h_stride; + p.dst = dst_w + oj * dst_h_stride; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.t_overflow = max( + 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h + + 1)); + p.b_overflow = kh_lo; + p.kh_padding = kh_len; + p.scales = scales; + p.oc_blocks = jcp.is_depthwise ? 
g : ocb; + kernel_->jit_ker(&p); + } + if (jcp.loop_order == loop_ngc) + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, + jcp.mb, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + }); +} + +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp new file mode 100644 index 000000000000..901038fa4841 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp @@ -0,0 +1,237 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "cpu_primitive.hpp" +#include "cpu_memory.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "nstl.hpp" + +#include "cpu_deconvolution_pd.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + no_last_block = 0x1U, + last_ic_block = 0x2U, + last_sp_block = 0x4U, +} ker_block_t; + +struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); + + jit_avx512_core_x8s8s32x_deconv_fwd_kernel( + const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + generate(); + jit_ker = (void (*)(jit_deconv_call_s *))getCode(); + } + + ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const deconvolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &weights_md, + memory_desc_t &dst_md, + const bool with_bias, + memory_desc_t &bias_md, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + const jit_conv_conf_t &jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_deconv_call_s *); +private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_src = r8; + reg64_t reg_filt = r9; + reg64_t reg_dst = r10; + reg64_t param1 = abi_param1; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_nur_w = rbx; + reg64_t reg_bias = rdx; + reg64_t reg_icb = reg_bias; + reg64_t reg_ptr_scales = rax; + reg64_t reg_oc_blocks = rsi; + + reg64_t aux_reg_src = r11; + reg64_t aux_reg_filt = r12; + + reg64_t reg_compensation = r14; + reg64_t reg_scratch = r14; + reg64_t reg_ptr_sum_scale = r11; + reg64_t reg_bias_alpha = abi_not_param1; + reg64_t reg_overflow = rax; + reg64_t reg_comp_strides = reg_overflow; + + Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + zmm_t zmm_tmp = zmm_t(28); + zmm_t zmm_one = zmm_t(29); + /* used during write-out section of store_output */ + zmm_t zmm_zero = zmm_t(31); + zmm_t zmm_wei = zmm_t(31); + + /* signed input */ + zmm_t zmm_shift = zmm_t(30); + zmm_t zmm_comp = zmm_t(30); + zmm_t zmm_bias = zmm_t(31); + zmm_t zmm_prev_dst = zmm_t(31); + + zmm_t zmm_out(int i_ur, int i_oc) { + int idx = i_ur * jcp.nb_oc_blocking + i_oc; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_bias_alpha() { + return zmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + xmm_t xmm_bias_alpha() { + return xmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + + int get_ow_start(int ki, int l_overflow) { + int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res 
+= jcp.stride_w; + return res; + } + + int get_ow_end(int ur_w, int ki, int r_overflow) { + if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return ur_w - res; + } + bool maybe_eltwise(int position); + void compute_eltwise(int ur_w); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block); + void compute_ker(int ur_w, int l_overflow, int r_overflow, + ker_block_t last_ic_block_flag, bool h_padded = false); + void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); + void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +template +struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""), + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && (desc()->alg_kind & alg_kind::deconvolution_direct) + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + with_bias(), bias_md_, *attr()); + + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status::success; + } + + jit_conv_conf_t jcp_; + }; + + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if(pd()->ndims() == 3) + execute_forward_1d(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp new file mode 100644 index 000000000000..c09592d5c950 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp @@ -0,0 +1,773 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_GENERATOR_HPP +#define CPU_JIT_AVX2_GENERATOR_HPP + +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_utils/jit_utils.hpp" + +#if defined(_WIN32) && !defined(__GNUC__) +# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__ +#else +# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al))) +#endif + +#if defined(_WIN32) +# define OFFSET_SHADOWSPACE 0x28 +#endif + +#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \ + const char *name() const override { return STRINGIFY(jit_name); } \ + const char *source_file() const override { return __FILE__; } + +namespace mkldnn { +namespace impl { +namespace cpu { + +// TODO: move this to jit_generator class? +namespace { + +typedef enum { + PAGE_4K = 4096, + PAGE_2M = 2097152, +} cpu_page_size_t; + +// TODO: move this somewhere else? Although this is only used by jit kernels +// (Roma) +static inline int float2int(float x) { + union { + float vfloat; + int vint; + } cvt; + cvt.vfloat = x; + return cvt.vint; +} + +// TODO: A GPR class that hides ABI details from the JIT kernels and allows +// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and +// stack register (sr). +// +// This will allow using syntax like this: +// +// param = gpr0; +// reg_input = gpr0; +// reg_output = gpr1; +// ... +// +// #ifndef XBYAK64 +// mov(param, ptr[sr]) +// #endif +// +// (Roma) + +#ifdef XBYAK64 +constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = { + Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, + Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15, +#ifdef _WIN32 + Xbyak::Operand::RDI, Xbyak::Operand::RSI, +#endif +}; + +#ifdef _WIN32 +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX), + abi_param2(Xbyak::Operand::RDX), + abi_param3(Xbyak::Operand::R8), + abi_param4(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RDI); +#else +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI), + abi_param2(Xbyak::Operand::RSI), + abi_param3(Xbyak::Operand::RDX), + abi_param4(Xbyak::Operand::RCX), + abi_param5(Xbyak::Operand::R8), + abi_param6(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RCX); +#endif +#endif + +inline unsigned int get_cache_size(int level, bool per_core = true){ + unsigned int l = level - 1; + // Currently, if XByak is not able to fetch the cache topology + // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core. + if (cpu.getDataCacheLevels() == 0){ + const int L1_cache_per_core = 32000; + const int L2_cache_per_core = 512000; + const int L3_cache_per_core = 1024000; + int num_cores = per_core ? 1 : mkldnn_get_max_threads(); + switch(l){ + case(0): return L1_cache_per_core * num_cores; + case(1): return L2_cache_per_core * num_cores; + case(2): return L3_cache_per_core * num_cores; + default: return 0; + } + } + if (l < cpu.getDataCacheLevels()) { + return cpu.getDataCacheSize(l) + / (per_core ? 
cpu.getCoresSharingDataCache(l) : 1); + } else + return 0; +} + +} + +class jit_generator : public Xbyak::CodeGenerator +{ +private: + const size_t xmm_len = 16; +#ifdef _WIN32 + const size_t xmm_to_preserve_start = 6; + const size_t xmm_to_preserve = 10; +#else + const size_t xmm_to_preserve_start = 0; + const size_t xmm_to_preserve = 0; +#endif + + const size_t num_abi_save_gpr_regs + = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]); + + const size_t size_of_abi_save_regs + = num_abi_save_gpr_regs * rax.getBit() / 8 + + xmm_to_preserve * xmm_len; + +public: + enum { + _cmp_eq_oq = 0u, + _cmp_lt_os = 1u, + _cmp_le_os = 2u, + _cmp_neq_uq = 4u, + _cmp_nlt_us = 5u, + _cmp_nle_us = 6u, + + _op_floor = 1u, + _op_mxcsr = 4u, + }; + + Xbyak::Reg64 param1 = abi_param1; + const int EVEX_max_8b_offt = 0x200; + const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; + + inline size_t get_size_of_abi_save_regs() { + return size_of_abi_save_regs; + } + + void preamble() { + if (xmm_to_preserve) { + sub(rsp, xmm_to_preserve * xmm_len); + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i)); + } + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + push(Xbyak::Reg64(abi_save_gpr_regs[i])); + if (mayiuse(avx512_common)) { + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + } + } + + void mic_prefetcht0(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht0(a); + } + + void mic_prefetcht1(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht1(a); + } + + void mic_prefetcht2(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht2(a); + } + + void uni_vzeroupper() { + if (mayiuse(avx) && !mayiuse(avx512_mic)) + vzeroupper(); + } + + void postamble() { + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i])); + if (xmm_to_preserve) { + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]); + add(rsp, xmm_to_preserve * xmm_len); + } + uni_vzeroupper(); + ret(); + } + + template + Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, + T raw_offt, bool bcast = false) + { + using Xbyak::Zmm; + using Xbyak::Reg64; + using Xbyak::Address; + using Xbyak::RegExp; + + assert(raw_offt <= INT_MAX); + auto offt = static_cast(raw_offt); + + int scale = 0; + + if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { + offt = offt - 2 * EVEX_max_8b_offt; + scale = 1; + } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { + offt = offt - 4 * EVEX_max_8b_offt; + scale = 2; + } + + auto re = RegExp() + base + offt; + if (scale) + re = re + reg_EVEX_max_8b_offt * scale; + + if (bcast) + return zword_b [re]; + else + return zword [re]; + } + + Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt, + const Xbyak::Reg64 &tmp_reg, bool bcast = false) { + if (offt > INT_MAX) { + mov(tmp_reg, offt); + return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg]; + } else { + return bcast ? 
ptr_b[reg_out + offt] : ptr[reg_out + offt]; + } + } + + Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base, + size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) { + if (raw_offt > INT_MAX) { + return make_safe_addr(base, raw_offt, reg_offt, bcast); + } else { + return EVEX_compress_addr(base, raw_offt, bcast); + } + } + + void safe_add(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + add(base, reg_offt); + } else { + add(base, raw_offt); + } + } + + void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + sub(base, reg_offt); + } else { + sub(base, raw_offt); + } + } + + // Disallow char-based labels completely + void L(const char *label) = delete; + void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } + + void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + pxor(x2, op); + } + void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpxor(x1, x2, op); + } else { + vxorps(x1, x2, op); + } + } + void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, + const Xbyak::Operand &op) { + vpxord(x1, x2, op); + } + + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movss(addr, x); + } + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovss(addr, x); + } + void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movss(x, addr); + } + void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovss(x, addr); + } + + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movsd(addr, x); + } + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovsd(addr, x); + } + void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movsd(x, addr); + } + void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovsd(x, addr); + } + + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) { + vmovdqu32(addr, x); + } + + void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) { + movdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) { + vmovdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) { + vmovdqu32(x, addr); + } + + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movups(addr, x); + } + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovups(addr, x); + } + + void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movups(x, op); + } + void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vmovups(x, op); + } + + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movntps(addr, x); + } + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovntps(addr, x); + } + + void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movss(x, op); + shufps(x, x, 0x0); + } + void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (op.isMEM() || mayiuse(avx2)) { + vbroadcastss(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) 
movss(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movsd(x, op); + pshufd(x, x, 0x0); + } + void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpbroadcastd(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movsd(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpss(x, op); + } + void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) { + Xbyak::Xmm x1_(x1.getIdx()); + Xbyak::Xmm x2_(x2.getIdx()); + vrcpss(x1_, x1_, x2_); + } + void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) { + Xbyak::Xmm x_(x.getIdx()); + vrcpss(x_, x_, op); + } + + void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpps(x, op); + } + void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vrcpps(x, op); + } + void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) { + vrcp14ps(x, op); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + divps(x, op2); + } + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vdivps(x, op1, op2); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + divps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vdivps(x, op1, op2); + } + + void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + addps(x, op2); + } + void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vaddps(x, op1, op2); + } + + void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, + const Xbyak::Operand& op) { + assert(x1.getIdx() == x2.getIdx()); + psignd(x1, op); + } + void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2, + const Xbyak::Operand& op) { + vpsignd(x1, x2, op); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + subps(x, op2); + } + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vsubps(x, op1, op2); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + subps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vsubps(x, op1, op2); + } + + void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + mulps(x, op2); + } + void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmulps(x, op1, op2); + } + + void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x1, x2); + addps(x1, op); + } + void uni_vfmadd213ps(const 
Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd213ps(x1, x2, op); + } + + void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + addps(x1, x2); + } + void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd231ps(x1, x2, op); + } + + void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + subps(x1, x2); + } + + void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfnmadd231ps(x1, x2, op); + } + + void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + sqrtps(x, op); + } + void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vsqrtps(x, op); + } + + void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + paddd(x2, op); + } + void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + vpaddd(x1, x2, op); + } + + void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + andps(x1, op); + } + void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vandps(x1, x2, op); + else + vpandd(x1, x2, op); + } + + void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + orps(x1, op); + } + void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vorps(x1, x2, op); + else + vpord(x1, x2, op); + } + + void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + pslld(x, imm); + } + void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpslld(x, op, imm); + } + + void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + psrld(x, imm); + } + void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpsrld(x, op, imm); + } + + void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + maxps(x, op2); + } + void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmaxps(x, op1, op2); + } + + void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + minps(x, op2); + } + void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vminps(x, op1, op2); + } + + void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nle_us); + } + + void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpgtps(x1, x2, op); + } + + void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nlt_us); + } + + void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const 
Xbyak::Operand &op) { + vcmpps(x1, x2, op, _cmp_nlt_us); + } + + void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) { + ptest(x1, op); + } + + void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) { + assert(!(x1.isZMM() || op.isZMM())); + vtestps(x1, op); + } + + void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op, const Xbyak::Xmm &msk) { + assert(x1.getIdx() == x2.getIdx()); + assert(msk.getIdx() == 0); + blendvps(x1, op); + } + void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op, const Xbyak::Ymm &msk) { + vblendvps(x1, x2, op, msk); + } + + void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + roundps(x, op, imm); + } + void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vroundps(x, op, imm); + } + + void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtps2dq(x, op); + } + void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtps2dq(x, op); + } + + void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtdq2ps(x, op); + } + void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtdq2ps(x, op); + } + + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) { + movmskps(x1.cvt64(), x2); + } + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) { + vmovmskps(x1, x2); + } + + void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packssdw(x1, op); + } + void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackssdw(x1, x2, op); + } + + void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packuswb(x1, op); + } + void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackuswb(x1, x2, op); + } + + + void mul_by_const(const Xbyak::Reg &out, + const Xbyak::Reg64 &tmp, int value) { + // Generates a shift + add sequence for multiplicating contents of the + // out register by a known JIT-time value. Clobbers the tmp register. + // + // Pros compared to mul/imul: + // - does not require using known registers + // - not microcoded on Intel(R) Xeon Phi(TM) processors + // Still, there are probably a lot of cases when mul/imul is faster on + // Intel(R) Core(TM) processors. Not intended for critical path. 
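The comment above describes the shift+add decomposition that mul_by_const() emits as JIT code: one shift and one add per set bit of the JIT-time constant. As a rough illustration only (this scalar helper is not part of the upstream sources; the name mul_by_const_scalar is invented here), the same bit-by-bit decomposition can be written as plain C++:

    #include <cassert>
    #include <cstdint>

    // Scalar sketch of the shift+add multiply: for every set bit of `value`,
    // `out` is shifted up to that bit position and accumulated into `tmp`,
    // so the result equals the original `out` times `value`.
    static uint64_t mul_by_const_scalar(uint64_t out, unsigned value) {
        uint64_t tmp = 0;
        int p = 0;     // current power of two being examined
        int old_p = 0; // last power of two at which `out` was shifted
        while (value) {
            if (value & 1) {
                int shift = p - old_p;
                if (shift) {
                    out <<= shift; // bring `out` up to the current bit position
                    old_p = p;
                }
                tmp += out;        // add the partial product for this bit
            }
            value >>= 1;
            ++p;
        }
        return tmp; // the JIT version finishes with mov(out, tmp)
    }

    int main() {
        assert(mul_by_const_scalar(7, 12) == 84); // 7 * 12
        return 0;
    }

As in the emitted code, a constant of zero simply leaves the accumulator at zero, and only a single temporary is clobbered rather than the fixed registers an imul would require.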
+ + // TODO: detect when overflow is emminent (Roma) + // TODO: detect when using mul/imul is a better option (Roma) + + int p = 0; // the current power of 2 + int old_p = 0; // the last seen power of 2 such that value[old_p] != 0 + + xor_(tmp, tmp); + while (value) { + if (value & 1) { + int shift = p - old_p; + if (shift) { + shl(out, shift); + old_p = p; + } + add(tmp, out); + } + value >>= 1; + p++; + } + mov(out, tmp); + } + +public: + jit_generator( + void *code_ptr = nullptr, + size_t code_size = 256 * 1024 + ) : Xbyak::CodeGenerator(code_size, code_ptr) + { + } + virtual ~jit_generator() {} + + virtual const char *name() const = 0; + virtual const char *source_file() const = 0; + + const Xbyak::uint8 *getCode() { + const Xbyak::uint8 *code = CodeGenerator::getCode(); + size_t code_size = getSize(); + jit_utils::register_jit_code(code, code_size, name(), source_file()); + return code; + } + + template const F getCode() { + return (const F)getCode(); + } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp new file mode 100644 index 000000000000..56d7f592e2a2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp @@ -0,0 +1,481 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_PRIMITIVE_CONF_HPP +#define JIT_PRIMITIVE_CONF_HPP + +#include + +#include "common/primitive_attr.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/* convolution */ +enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni}; +enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn, + loop_ngcw, loop_nhwcg, loop_nwcg}; +enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr, + loop_brl}; +enum conv_kernel_kind_t {embd_bcast, expl_bcast}; + +enum { + FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1, + FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3, + FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5, + FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7, + FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9, + FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips + loading weights-data from memory; this + needs to happen on the first Group/16 + iteration. 
*/ + FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skips + loading bias data from memory */ + FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution + pass */ +}; + +struct jit_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + conv_loop_order_t loop_order; + + int simd_w; + int ndims; + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int id, ih, iw, od, oh, ow; + int f_pad, l_pad, t_pad; + int back_pad, r_pad, b_pad; + int kd, kh, kw; + int stride_d, stride_h, stride_w; + int dilate_d, dilate_h, dilate_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + + int idp, ihp, iwp, ohp, owp; + int nb_ic, ic_block; + int nb_oc, oc_block; + int nb_ow, ow_block; + int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking + into account vector registers distribution */ + int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work + within threads */ + int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work + int nb_ic_L2; + int h_blocking; + int nb_oc_L2; + int ur_h, ur_w; + int ur_w_tail; + bool is_1stconv; + int nonblk_group_off; + /* fma avx512_core */ + conv_kernel_kind_t kernel_kind; + /* 4fma */ + int tr_iw; + int tr_src_num_guard_elems; + /* 1st conv: 4fma */ + int tr_ld; + int kh_step; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* avx512_u8s8u8 */ + int ic_nb1, ic_nb2; + int oc_nb1; + int ur_ow_max, ur_ow, ur_ow_tail; + int ur_ow_nsteps; + data_type_t bia_dt; + data_type_t dst_dt; + /* avx512: max possible value is nregs(32) - aux_regs(4) */ + int src_offsets[28]; + int src_count; + bool expl_bcast; + bool large_spatial; + int is_oc_scale; + int max_regs_ur; // maximum accumulation registers + // dw conv + int nb_ch, ch_block, nb_ch_blocking; + bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; + int aligned_threads; + // large spatial + int oh_blk_size; + // s8s8 convolution + bool signed_input; + float wei_adj_scale; +}; + +struct jit_conv_conf_2x3_wino_t { + conv_version_t ver; + + int m; + int r; + int alpha; + int tile_h, tile_w; + + int mb; + int ngroups, ic, oc, oc_without_padding; + int ih, iw, oh, ow; + int l_pad, t_pad; + int r_pad, b_pad; + int kh, kw; + int stride_h, stride_w; + int dilate_h, dilate_w; + + int nb_ic, ic_block; + int nb_oc, oc_block; + + int w_block_size, h_block_size; + + data_type_t bia_dt; + data_type_t dst_dt; + + int is_oc_scale; + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + + format_tag_t src_tag, dst_tag; // temporary workaround + bool with_bias; + bool small_mb; + + int xb, yb; + int inp_stride; + int out_stride; + int wei_stride; + int bia_stride; + + int M, N, K; + int m_block, n_block, k_block; + int n2_block, n_chunks; + int k2_block, k_chunks; + + int mb_block, nb_mb; + + size_t size_wino_src, size_wino_wei, size_wino_dst; + + int nthr; +}; + +/* + Winograd sched policy: + + Computation Unit: + W: weights transform + S: src transform + D: dst transform + G: gemm + + Thread grouping by: + i: nb_ic + o: nb_oc + t: tile_block + e: element in tile + + Note: 'i' and 'o' are omitted if + i. not combined with t or + ii. 
with discrete transforms + + Current policies supported: +*/ +enum winograd_sched_t { + WSCHED_INVALID = 0, + + /* Forward & backward-data */ + /* W_S_G_D implements discrete transforms */ + WSCHED_DATA_W_S_G_D, + /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/ + WSCHED_DATA_W_SGD, + + /* Backward-weights */ + WSCHED_WEI_S_D_G_W, + WSCHED_WEI_SDGtWo, + WSCHED_WEI_S_D_Giot_W, + WSCHED_WEI_SDGt_W, +}; + +struct jit_conv_winograd_conf_t : public jit_conv_conf_t { + int itiles; + int jtiles; + int ntiles; + int ic_simd_block=16; + int tile_4fma_padding; + int tile_4fma; + int oc_simd_block=16; + int oc_reg_block; + int ic_reg_block; + int tile_block; + int tile_block_ur; + int nb_tile_block_ur; + + bool double_buffering; + bool with_relu_postsum; + int zmm_start; + int nb_reg; + + int dimK; + int dimK_4fma; + int dimK_reg_block; + int dimK_block; + int dimK_nb_block; + + int dimM; + int dimM_reg_block; + int dimM_simd_block; + int dimM_block; + int dimM_nb_block; + + int dimN; + int dimN_reg_block; + int dimN_bcast_ur; + int dimN_block; + int dimN_nb_block; + + winograd_sched_t sched_policy; +}; + +struct jit_conv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *src_prf; + const void *dst_prf; + const void *filt_prf; + const void *bias_prf; + const void *scales; + const void *acc_s32; + const void *compensation; + size_t kd_offset; + size_t kd_offset_prf; + size_t d_index; + size_t d_index_prf; + size_t d_worksize; + size_t d_worksize_prf; + size_t kd_padding; + size_t kd_padding_prf; + size_t kh_padding; + size_t kh_padding_prf; + size_t owb; + size_t owb_prf; + size_t kw_padding; + size_t channel; + size_t channel_prf; + size_t oc_blocks; + size_t ur_w; + size_t ur_str_w; + size_t ch_blocks; + size_t t_overflow; + size_t b_overflow; + int flags; +}; + +struct jit_deconv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *scales; + const void *compensation; + size_t t_overflow; + size_t b_overflow; + size_t kh_padding; + size_t oc_blocks; +}; + +struct jit_dw_conv_call_s { + const void *input; + const void *output; + const void *filter; + const void *bias; + size_t kh_count; + size_t oh_count; + size_t oh_index; + size_t filter_pad_off; + unsigned char + exec_flags; /* Flags passed by driver execution to inner kernel */ +}; + +struct jit_wino_transform_call_s { + size_t tile_block; + size_t tile_block_ur; + size_t nb_tile_block_ur; + size_t tile_count; + size_t tj; + size_t ti; + void *src; + void *dst; + void *Mw; + void *M; + void *T; + void *G; + void *bias; +}; + +struct jit_1x1_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int iw, ih, ow, oh; + int l_pad, t_pad; + int kh, kw; + int stride_h, stride_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int is, os; + int ic_block, oc_block; + + int ur, ur_tail; + + int reduce_dim, reduce_block, nb_reduce, + nb_reduce_blocking, nb_reduce_blocking_max; + int load_dim, load_block, nb_load, + 
nb_load_blocking, nb_load_blocking_max, nb_load_chunk; + int bcast_dim, bcast_block, nb_bcast, + nb_bcast_blocking, nb_bcast_blocking_max; + + int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; + int load_loop_load_step, load_loop_iter_step; + int bcast_loop_output_step, bcast_loop_output_substep; + int bcast_loop_bcast_step, bcast_loop_bcast_substep; + int fma_step; + int load_grp_count; + conv_1x1_loop_order_t loop_order; + bool use_vmovntps; + /* avx512 core */ + bool expl_bcast; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* 4fma */ + bool transpose_src; + int tr_is; + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + int is_oc_scale; + data_type_t bia_dt; + data_type_t dst_dt; + bool signed_input; + float wei_adj_scale; +}; + +struct jit_gemm_conv_conf_t { + prop_kind_t prop_kind; + + int mb; + int ngroups, ic, oc; + int iw, ih, id, ow, oh, od; + int l_pad, t_pad, f_pad; + int kh, kw, kd; + int stride_h, stride_w, stride_d; + int dilate_h, dilate_w, dilate_d; + bool with_bias; + + int is, os, ks; + int ic_block, oc_block; + + int nthr; + ptrdiff_t im2col_sz; + bool need_wei_reduction; + bool signed_input; + int oh_block; + int ow_block; + bool outer_threading; +}; + +struct jit_1x1_conv_call_s { + const void *bcast_data; + const void *load_data; + const void *output_data; + const void *bias_data; // used in forward and backward_weights only + const void *acc_s32; + const void *scales; + const void *compensation; + + size_t load_dim; + size_t bcast_dim; + size_t reduce_dim; + + size_t output_stride; // used in backward_weights only + + size_t first_last_flag; +}; + +/* pooling */ +struct jit_pool_conf_t { + int ndims; + int mb, c; + int id, ih, iw, od, oh, ow; + int stride_d, stride_h, stride_w; + int kd, kh, kw; + int f_pad, t_pad, l_pad; + alg_kind_t alg; + bool is_training; + bool pad_w_is_null; + bool is_backward; + bool simple_alg; + data_type_t ind_dt; + + int c_block, c_tail, nb_c; + int ur_c, ur_c_tail; + int ur_w; + int ur_w_tail; + size_t tail[4]; + data_type_t src_dt; + data_type_t dst_dt; +}; + +struct jit_pool_call_s { + const float *src; + const float *dst; + const void *indices; + const float *src_prf; + const float *dst_prf; + const void *indices_prf; + size_t oh; + size_t kd_padding; + size_t kh_padding; + size_t kh_padding_shift; + size_t kd_padding_shift; + size_t kw_padding; + const float* init_value; + float ker_area_h; +}; + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp new file mode 100644 index 000000000000..94d2101d6eab --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp @@ -0,0 +1,677 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto reg_load = [=](int i, int n) { + return Xmm(2*ur * load_loop_blk + 2*i + n + 1); + }; + + auto reg_accum = [=](int i, int j, int n) { + return Xmm(2*j * load_loop_blk + 2*i + n + 1); + }; + + auto bias_ptr = [=](int i, int n) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? 
(height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i, int n) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)]; + }; + + auto output_ptr = [=](int i, int j, int n) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) { + movups(reg_accum(i, j, 0), bias_ptr(i, 0)); + movups(reg_accum(i, j, 1), bias_ptr(i, 1)); + } + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + xorps(r0, r0); + xorps(r1, r1); + } + + L(init_done); + + // load weights + for (int i = 0; i < load_loop_blk; ++i) { + movups(reg_load(i, 0), load_ptr(0, i, 0)); + movups(reg_load(i, 1), load_ptr(0, i, 1)); + } + + movss(reg_bcast, bcast_ptr(0, 0)); + shufps(reg_bcast, reg_bcast, 0); + }; // init() + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + addps(r0, output_ptr(i, j, 0)); + addps(r1, output_ptr(i, j, 1)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(1, + 2 * ur * load_loop_blk + 1); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + movups(output_ptr(i, j, 0), reg_accum(i, j, 0)); + movups(output_ptr(i, j, 1), reg_accum(i, j, 1)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + mulps(reg_load(i, 0), reg_bcast); + mulps(reg_load(i, 1), reg_bcast); + addps(reg_accum(i, j, 0), reg_load(i, 0)); + addps(reg_accum(i, j, 1), reg_load(i, 1)); + + if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) { + movups(reg_load(i, 0), load_ptr(u + 1, i, 0)); + movups(reg_load(i, 1), load_ptr(u + 1, i, 1)); + } + } + if (j < ur - 1) { + movss(reg_bcast, bcast_ptr(u, j + 1)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for ur + 
if (!last_block || u < jcp.reduce_loop_unroll - 1) { + movss(reg_bcast, bcast_ptr(u + 1, 0)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for reduce_loop_unroll + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} // reduce_loop() + +void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i, int n) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i, int n) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = diff_bias_reg(i, 0); + auto r1 = diff_bias_reg(i, 1); + xorps(r0, r0); + xorps(r1, r1); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) { + movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0)); + movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1)); + } + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) { + addps(diff_bias_reg(i, 0), load_ptr(u, i, 0)); + addps(diff_bias_reg(i, 1), load_ptr(u, i, 1)); + } + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) { + movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0)); + movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1)); + } + + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_sse42_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, 
ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) + return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 
0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 4; + jcp.ic_block = jcp.oc_block = simd_w*2; + + args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ur = 1; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + 
jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp new file mode 100644 index 000000000000..b314a5098c3c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP +#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_conv_kernel_f32: public jit_generator { + jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_sse42_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32) + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t imm_addr64 = load_loop_iter; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + xmm_t reg_bcast = xmm_t(15); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp new file mode 100644 index 000000000000..30c137641e3c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp @@ -0,0 +1,134 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_1x1_convolution.hpp" +#include "utils.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void jit_sse42_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int ndims = src_d.ndims(); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + parallel(0, [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto par_conv = jit_1x1_conv_call_s(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + const int bcast_step_rem = jcp.nb_bcast - osb; + int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max + ? bcast_step_rem : jcp.nb_bcast_blocking; + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int ow = os % jcp.ow; + const int oh = os / jcp.ow; + const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0); + const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0); + + par_conv.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step_rem = jcp.nb_load - ocb; + const int load_step = load_step_rem < jcp.nb_load_blocking_max + ? 
load_step_rem : jcp.nb_load_blocking; + + const size_t _ocb = g * nb_oc + ocb; + par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + par_conv.output_data = &dst[dst_off]; + + par_conv.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + par_conv.first_last_flag = 0 + | (icb == 0) * FLAG_REDUCE_FIRST + | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST; + + par_conv.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking * jcp.ic_block); + + const size_t _icb = g * nb_ic + icb; + const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw); + par_conv.bcast_data = &src[src_off]; + + par_conv.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + kernel_->jit_ker(&par_conv); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp new file mode 100644 index 000000000000..b32b1e4784da --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP +#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""), + jit_sse42_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_1x1_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + } + ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_1x1_conv_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp new file mode 100644 index 000000000000..17cabc11866c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp @@ -0,0 +1,497 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l); + else + inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + + for (int jj = jj_start; jj < jj_end; jj++) + { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * ker_off]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), 
xmm0); + } + } + } + } +} + +void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l); + else + inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ifm2 * oc_blk; + for (int jj = jj_start; jj < jj_end; jj++) { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ? + dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) + ? dilate_h : ic_blk * dilate_h; + const int inp_off = one_of(jcp.src_tag, ncw, nchw) + ? 
dilate_w : ic_blk * dilate_w; + + xor_(simd_iter, simd_iter); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + Label init_simd_iter_loop; + Label init_done; + Label init_first; + + L(init_simd_iter_loop); + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), xword[reg_output + + sizeof(float) * (ii * oh * ow + jj) * oc_blk]); + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + addps(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1)); + } + + L(init_done); + + Label skip_kh_loop; + mov(kj, reg_kh); + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.with_eltwise) { + Label regular_store; + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off = (ii * oh * ow + jj) * oc_blk; + + Xmm reg_out = Xmm(ur_w * ii + jj + 1); + movups(xword[reg_output + sizeof(float) * o_off], reg_out); + } + } + + mov(aux_reg_kernel, reg_kernel); + mov(aux_reg_input, reg_input); + add(aux_reg_kernel, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + add(reg_bias, sizeof(float) * 4); + + inc(simd_iter); + cmp(simd_iter, 2); + jl(init_simd_iter_loop, T_NEAR); + + sub(reg_output, sizeof(float) * 8); + sub(reg_bias, sizeof(float) * 8); +} + +inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 
1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, oc_blocks); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail" +} + +void jit_sse42_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail); + } + + L(exit); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 
0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0]; + jcp.dilate_w = cd.dilates[ndims - 3]; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + bool args_ok = true + && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o)) + && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c); + if (!args_ok) return status::unimplemented; + + const int simd_w = 8; // 2 SSE vectors processing at once + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + // kernel needs 1 temporary YMM register + const int num_avail_regs = 15; + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? 
jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp new file mode 100644 index 000000000000..33c26ef0811d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp @@ -0,0 +1,93 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP +#define JIT_SSE42_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_conv_fwd_kernel_f32: public jit_generator { + jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_sse42_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32) + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t simd_iter = r15; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = reg_oc_blocks; + Xbyak::Reg32 reg_ci_flag = r13d; + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void solve_common(int oc_blocks); + + void generate(); +}; + +} +} +} + +#endif diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp new file mode 100644 index 000000000000..5f77d692f51f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_convolution.hpp" +#include "mkldnn_thread.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() \ + ? (f).blk_off(g, __VA_ARGS__) \ + : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kh, kw) \ + pd()->ndims() == 3 \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : wht_blk_off_(f, g, oc, ic, kh, kw) + +void jit_sse42_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 
0 : _ic, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + } + icbb += icb_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp new file mode 100644 index 000000000000..d2f0a38c5cef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP +#define CPU_JIT_SSE42_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_sse42_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", sse42, ""), + jit_sse42_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() == 3; + auto src_tag = flat + ? 
utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_sse42_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_conv_fwd_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp new file mode 100644 index 000000000000..0e734f72650f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp @@ -0,0 +1,1192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "jit_generator.hpp" +#include "cpu_barrier.hpp" + +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +#define GET_OFF(x) offsetof(ctx_t, x) + +struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t) + + jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t k3333 = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kCCCC = k4; + opmask_t k0F0F = k5; + opmask_t kF0F0 = k6; + opmask_t kTail = k7; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src, + (transpose_size + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + int offset = (transpose_size) * typesize + i * tr_src_stride; + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset)); + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, + offset + 64)); + }; + + auto pf_src_t1 = [=](int i) { + if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf, + i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, + i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(16 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = r; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) + add(reg_tr_src_tmp, l_pad * typesize); + + if (tail != transpose_size) + kmovw(kTail, (1 << tail) - 1); + + // Xbyak does not allow k0 to be specified explicitly via the '|' + // operator, so we have to do this via a method call (implicitly + // EVEX encoding uses k0 to mean 'no mask') + bool partial_store = nrows < 16; + auto k = partial_store ? 
kTail : k0; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + if (nontemporal_stores && !partial_store) + vmovntps(addr, r); + else + vmovups(addr, r); + + if (r_pad > 0) { + add(reg_tr_src_tmp, tail * typesize); + padding(reg_tr_src_tmp, r_pad); + } + + if (l_pad > 0) { + padding(reg_tr_src, l_pad); + } + }; + + auto transpose16x8 = [=](int base_idx) { + assert(base_idx == 0 || base_idx == 8); + + // swap 1 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i * 2; + int src_idx1 = src_idx0 + 1; + + int next_src_idx0 = src_idx0 + 2; + int next_src_idx1 = src_idx1 + 2; + bool load_next = base_idx == 0 || i < 3; + + if (base_idx == 0 && i == 0) { + load(src_idx0); + load(src_idx1); + } + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx1); + auto src0 = src_zmm(src_idx0); + auto src1 = src_zmm(src_idx1); + + if (next_src_idx0 < nrows && load_next) + load(next_src_idx0); + valignd(tmp0, src0, src0, 0x1); + pf_src_t1(base_idx + i); + + if (next_src_idx1 < nrows && load_next) + load(next_src_idx1); + valignd(tmp1, src1, src1, 0xf); + pf_src_t0(base_idx + i); + + vmovaps(src0 | kAAAA, tmp1); + vmovaps(src1 | k5555, tmp0); + } + // swap 2 + for (int i = 0; i < 4; i++) { + int select_half = (i < 2) ? 0 : 2; + int src_idx0 = base_idx + i + select_half + 0; + int src_idx2 = src_idx0 + 2; + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx2); + auto src0 = src_zmm(src_idx0); + auto src2 = src_zmm(src_idx2); + + valignd(tmp0, src0, src0, 0x2); + pf_src_t1(base_idx + 4 + i); + valignd(tmp1, src2, src2, 0xe); + pf_src_t0(base_idx + 4 + i); + vmovaps(src2 | k3333, tmp0); + vmovaps(src0 | kCCCC, tmp1); + } + + // swap 4 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i; + int src_idx4 = src_idx0 + 4; + + auto tmp0 = tmp_zmm(src_idx0); + auto src0 = src_zmm(src_idx0); + auto src4 = src_zmm(src_idx4); + + vmovaps(tmp0, src0); + vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); + pf_tr_src_t1(base_idx / 2 + i); + vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); + pf_tr_src_t0(base_idx / 2 + i); + } + }; + + auto fixup16x16 = [=]() { + // swap 8 + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0x44); + store(tmp, i); + if (i % 2 == 0) { + pf_tr_src_t1(8 + i / 2); + pf_tr_src_t0(8 + i / 2); + } + } + + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(8 + i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0xee); + store(tmp, 8 + i); + if (i % 2 == 0) { + pf_tr_src_t1(12 + i / 2); + pf_tr_src_t0(12 + i / 2); + } + } + }; + + transpose16x8(0); + transpose16x8(8); + fixup16x16(); +} + +void jit_trans_iw_ic_t::generate() { + preamble(); + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 
1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(k3333, 0x3333); // 0011001100110011 + kmovw(k5555, 0x5555); // 0101010101010101 + kmovw(kAAAA, 0xaaaa); // 1010101010101010 + kmovw(kCCCC, 0xcccc); // 1100110011001100 + kmovw(k0F0F, 0x0f0f); // 0000111100001111 + kmovw(kF0F0, 0xf0f0); // 1111000011110000 + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t) + jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf): + jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFFFF = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kAA = k4; + opmask_t k55 = k5; + opmask_t kCC = k6; + opmask_t k33 = k7; + opmask_t kTail = k1; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + Xbyak::Zmm vidx1 = zmm31; + Xbyak::Zmm vidx2 = zmm30; + Xbyak::Zmm vidx3 = zmm29; + Xbyak::Zmm vidx4 = zmm28; + Xbyak::Zmm vidx5 = zmm27; + Xbyak::Zmm zmm_tmp = zmm26; + + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto store = [=](Zmm r, int i) { + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto 
base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = zmm_tmp; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + int store_tail = (nrows%2) ? nrows+1 : nrows; + + int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2; + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) { + padding(reg_tr_src, store_pad); + add(reg_tr_src_tmp, l_pad * typesize); + } + if (r_pad > 0) { + store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2; + int addr_shift = (r_pad%2) ? 1 : 0; + add(reg_tr_src_tmp, (nrows - addr_shift) * typesize); + padding(reg_tr_src_tmp, store_pad); + } + + mov(reg_tr_src_tmp, reg_tr_src); + add(reg_tr_src_tmp, l_pad * typesize); + + kmovw(kTail, (1 << store_tail/2) - 1); + auto k = kTail; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, r); + + }; + + kmovw(kFFFF, 0xffff); + //all loads + for (int i=0; i<16; i++){ + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // for odd numbers we need to mix row with zeroes + if (nrows%2) { + int i = nrows-1; + auto src0 = src_ymm(i); + auto src1 = src_ymm(i+1); //zero + + auto zmm_src0 = src_zmm(i); + vpxor(src1, src1, src1); + + load_ymm(i); + vpunpckhwd(src0, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src0, 0); + vpxor(src0, src0, src0); + load_ymm(i); + vpunpcklwd(src1, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src1, 1); + vpxord(zmm_src0, zmm_src0, zmm_src0); + vmovups(zmm_src0, zmm_tmp); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // swap 1 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(4*i); + auto zmm1 = src_zmm(4*i+2); + auto tmp0 = src_zmm(4*i+1); + auto tmp1 = src_zmm(4*i+3); + + vmovups(tmp0, zmm0); + vmovups(tmp1, zmm1); + + vpermps(tmp0 | kAAAA, vidx3, zmm1); + vpermps(tmp1 | k5555, vidx3, zmm0); + } + // swap 2 + int base_idx; + base_idx=0; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + base_idx=8; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + + // swap 3 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(2*i); + auto zmm1 = src_zmm(2*i+8); + + auto tmp0 = src_zmm(2*i+1); + auto tmp1 = src_zmm(2*i+9); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kCC, vidx1, zmm1); + vpermpd(tmp1 | k33, vidx1, zmm0); + } + + // all stores + for (int i=0; i<8; i++) + vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1); + + store(src_zmm(1), 0); + store(src_zmm(0), 1); + store(src_zmm(3), 2); + store(src_zmm(2), 3); + store(src_zmm(9), 4); + store(src_zmm(8), 5); + store(src_zmm(11), 6); + store(src_zmm(10), 7); + 
store(src_zmm(5), 8); + store(src_zmm(4), 9); + store(src_zmm(7), 10); + store(src_zmm(6), 11); + store(src_zmm(13), 12); + store(src_zmm(12), 13); + store(src_zmm(15), 14); + store(src_zmm(14), 15); + +} + +void jit_trans_iw_ic_int16_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 2, 3, 0, 1, 6, 7, 4, 5 }; + alignas(64) static constexpr const int64_t idx2[8] + = { 1, 0, 3, 2, 5, 4, 7, 6 }; + alignas(64) static constexpr const int32_t idx3[16] + = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 }; + alignas(64) static constexpr const int32_t idx4[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + alignas(64) static constexpr const int32_t idx5[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFFFF, 0xffff); + kmovw(k5555, 0x5555); + kmovw(kAAAA, 0xaaaa); + kmovw(kAA, 0xaa); + kmovw(k55, 0x55); + kmovw(kCC, 0xcc); + kmovw(k33, 0x33); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + vmovdqa64(vidx2, idx2); + vmovdqa32(vidx3, idx3); + vmovdqa32(vidx4, idx4); + vmovdqa32(vidx5, idx5); + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); + +} + +struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t) + jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + using zmm = const Xbyak::Zmm; + + enum { typesize = sizeof(int16_t), 
transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFF = k1; + + zmm vidx1 = zmm31; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + + auto store = [=](Zmm r, int i) { + auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); + if (nontemporal_stores) + vmovntps(addr, r); + else + vmovups(addr, r); + }; + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, 2*i); + } + if (r_pad > 0) { + auto src0 = src_ymm(nrows-1); + auto src1 = src_ymm(nrows); + auto zmm_src0 = src_zmm(30); + load_ymm(nrows-1); + + vpxor(src1, src1, src1); + vpunpckhwd(src1, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src1, 0); + vpxor(src1, src1, src1); + vpunpcklwd(src0, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src0, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, nrows-1); + } +} + +void jit_trans_ow_oc_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 4, 5, 0, 1, 6, 7, 2, 3 }; + + const int oc_block = conf_->oc_block; + const int ow = conf_->ow; + const int transposes = utils::div_up(ow, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = ow - loop_iters * transpose_size; + + src_stride = oc_block * typesize; + tr_src_stride = oc_block * typesize; + + bool nontemporal_stores = false; + enable_prefetch = ow > small_spatial ? 
1 : 0; + + const int src_step = oc_block * transpose_size * typesize; + const int tr_src_step = oc_block * transpose_size * typesize; + const int right_pad = ow % 2; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFF, 0xFF); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + transpose(tail, 0, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t) + + jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + + void generate(); + enum { typesize = (int)sizeof(float) }; +}; + +/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4] + * required for 1st 4fma backward by weights convolution */ +void jit_trans_iw_x4_4x_t::generate() { + using namespace utils; + + /* TODO: put into code */ + static int mask[16] = { + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, }; + + const auto &c = *conf_; + const int simd_w = cpu_isa_traits::vlen / typesize; + const int niters = c.tr_ld / simd_w; + + assert(niters <= 4); /* [bwd_w:tr_src:r1] */ + + Reg64 reg_ptr_src = r8; + Reg64 reg_ptr_tr_src = r9; + + Reg64 reg_ih = rax; + Reg64 reg_ih_end = rbx; + + Reg64 reg_nthr_oc_b = rsi; + Reg64 reg_ptr_tr_src_bctx = abi_not_param1; + + Reg64 reg_tmp = rdx; + + Zmm vmsk = Zmm(31); + Opmask kmsk = k7; + + auto emit_tr_sync = [&]() { + simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b); + }; + + auto emit_tr_iw = [&]() { + auto vreg = [](int iter, int i) { + assert(4 * iter + i < 24); + return Zmm(4 * iter + i); + }; + auto vtmp = [](int i) { return Zmm(24 + i); }; + + auto emit_load = [&](int iter) { + for (int i = 0; i < 4; ++i) { + auto v = vreg(iter, i); + const int off = (iter * 4 + i) * simd_w; + + if (off + simd_w <= c.iw) + vmovups(v, ptr[reg_ptr_src + off * typesize]); + else if (off < c.iw) + vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]); + else + vpxord(v, v, v); + } + }; + + auto emit_tr = [&](int iter) { + for (int i = 0; i < 4; ++i) + vpermps(vreg(iter, i), vmsk, vreg(iter, i)); + + vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88); + vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd); + vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88); + vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd); + + vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88); + vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd); + vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88); + vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd); + }; + + auto emit_store = [&]() { + for (int i = 0; i < 4; ++i) { + for (int iter = 0; iter < niters; ++iter) { + const size_t off = i * c.tr_ld + iter * simd_w; + vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i)); + } + } + }; + 
+ for (int iter = 0; iter < niters; ++iter) + emit_load(iter); + + for (int iter = 0; iter < niters; ++iter) + emit_tr(iter); + + emit_store(); + }; + + preamble(); + + mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]); + mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]); + + mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]); + mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]); + mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]); + mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]); + + emit_tr_sync(); + + Label l_ih_loop, l_tr_done; + cmp(reg_ih, reg_ih_end); + je(l_tr_done, T_NEAR); + + mov(reg_tmp, (size_t)&mask[0]); + vmovups(vmsk, ptr[reg_tmp]); + + if (c.iw % simd_w) { + const char load_mask = (1 << (c.iw % simd_w)) - 1; + mov(reg_tmp, load_mask); + kmovw(kmsk, reg_tmp.cvt32()); + } + + /* src += ih_start * c.iw; */ + imul(reg_tmp, reg_ih, c.iw * typesize); + add(reg_ptr_src, reg_tmp); + /* tr_src += ih_start * c.stride_w * c.tr_ld; */ + imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize); + add(reg_ptr_tr_src, reg_tmp); + + L(l_ih_loop); { + emit_tr_iw(); + + add(reg_ptr_src, c.iw * typesize); + add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize); + + inc(reg_ih); + cmp(reg_ih, reg_ih_end); + jl(l_ih_loop, T_NEAR); + } + + L(l_tr_done); + + emit_tr_sync(); + + postamble(); +} + +/* +// ------------------------------------------------- +// jit_transpose4x16_src +// ------------------------------------------------- +*/ + +void jit_transpose4x16_src::transpose(int nrows) +{ + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 4, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if (tparams->src_pf0_distance) + prefetcht0(EVEX_compress_addr( + reg_src, (tparams->src_pf0_distance + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + if (tparams->tr_src_pf0_distance) + prefetcht0(EVEX_compress_addr(reg_tr_src, + (tparams->tr_src_pf0_distance + i) * src_stride)); + }; + + auto pf_src_t1 = [=](int i) { + if (tparams->src_pf1) + prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if (tparams->tr_src_pf1) + prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(4 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r); + }; + + auto tmp0 = tmp_zmm(0); + auto tmp1 = tmp_zmm(1); + auto tmp2 = tmp_zmm(2); + auto tmp3 = tmp_zmm(3); + + auto src0 = src_zmm(0); + auto src1 = src_zmm(1); + auto src2 = src_zmm(2); + auto src3 = src_zmm(3); + for (int i = 0; i < nrows; i++) { + load(i); + } + + for (size_t i = nrows; i < 4; i++) { + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + vmovupd(tmp0, src0); + vmovupd(tmp1, src1); + pf_src_t0(0); + vpermpd(tmp0 | kF0, vidx01, src2); + vpermpd(tmp1 | kF0, vidx01, src3); + + valignd(src0, src0, src0, 8); + valignd(src1, src1, src1, 8); + pf_src_t0(1); + vmovupd(tmp2, src0); + vmovupd(tmp3, src1); + pf_src_t0(2); + vpermpd(tmp2 | kF0, vidx10, src2); + vpermpd(tmp3 | kF0, vidx10, src3); + pf_src_t0(3); + + vmovupd(src0, tmp0); + pf_src_t1(0); + vmovupd(src1, tmp2); + pf_src_t1(1); + vmovupd(src2, tmp1); + pf_src_t1(2); + vmovupd(src3, 
tmp3); + pf_src_t1(3); + vpermpd(src0 | kCC, vidx1, tmp1); + vpermpd(src1 | kCC, vidx1, tmp3); + pf_tr_src_t0(0); + vpermpd(src2 | k33, vidx1, tmp0); + vpermpd(src3 | k33, vidx1, tmp2); + pf_tr_src_t0(1); + + vmovupd(tmp0, src0); + vmovupd(tmp1, src2); + pf_tr_src_t0(2); + vmovupd(tmp2, src1); + vmovupd(tmp3, src3); + pf_tr_src_t0(3); + vpermps(tmp0 | kFFFF, vidxP, src0); + pf_tr_src_t1(0); + vpermps(tmp1 | kFFFF, vidxP, src2); + pf_tr_src_t1(1); + vpermps(tmp2 | kFFFF, vidxP, src1); + pf_tr_src_t1(3); + vpermps(tmp3 | kFFFF, vidxP, src3); + pf_tr_src_t1(4); + + store(tmp0, 0); + store(tmp1, 1); + store(tmp2, 2); + store(tmp3, 3); +} + +alignas(64) static constexpr const int64_t idx01[8] + = { 0, 0, 0, 0, 0, 1, 2, 3 }; +alignas(64) static constexpr const int64_t idx10[8] + = { 0, 0, 0, 0, 4, 5, 6, 7 }; +alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 }; +alignas(64) static constexpr const int32_t idxP[16] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + +void jit_transpose4x16_src::generate() +{ + preamble(); + + const int ic_block = params->ic_block; + const int is = params->is; + int tail = is % transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = ic_block * typesize; + + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * transpose_size * typesize; + +#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x) + mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]); + mov(reg_src, ptr[param1 + GET_TR_OFF(src)]); + mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]); + mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]); +#undef GET_TR_OFF + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + kmovw(kF0, 0xf0); // 11110000 + kmovw(kCC, 0xcc); // 11001100 + kmovw(k33, 0x33); // 00110011 + kmovw(kFFFF, 0xffff); // 1111111111111111 + + vmovdqa64(vidx01, idx01); + vmovdqa64(vidx10, idx10); + vmovdqa64(vidx1, idx1); + vmovdqa32(vidxP, idxP); + + Label loop_label; + Label tail_label; + + cmp(reg_loop, transpose_size); + jl(tail_label, T_NEAR); + + L(loop_label); + { + transpose(transpose_size); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, transpose_size); + cmp(reg_loop, transpose_size); + jge(loop_label, T_NEAR); + } + L(tail_label); + transpose(tail); + + postamble(); +} + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) { + if (conf->ver == ver_4fma && !conf->is_1stconv) + return new jit_trans_iw_ic_t(conf); + if (conf->ver == ver_4fma && conf->is_1stconv) + return new jit_trans_iw_x4_4x_t(conf); + assert(!"unsupported configuration"); + return nullptr; +} + +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) { + assert(!"unsupported configuration"); + return nullptr; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp new file mode 100644 index 000000000000..565e97e4fc5e --- /dev/null +++ 
b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp @@ -0,0 +1,145 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_TRANSPOSE_SRC_HPP +#define CPU_JIT_TRANSPOSE_SRC_HPP + +#include "cpu_barrier.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_trans_src_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_src_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_src_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_src_transpose_s { + size_t size; + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; +}; + +struct jit_trans_dst_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_dst_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_dst_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_transpose4x16_src_t { + int src_pf0_distance; + int tr_src_pf0_distance; + bool src_pf1; + bool tr_src_pf1; +}; + +struct jit_transpose4x16_src : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) + + jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, + jit_transpose4x16_src_t *tparams_) + : params(aparams), tparams(tparams_) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + const jit_1x1_conv_conf_t *params; + const jit_transpose4x16_src_t *tparams; + void (*jit_ker)(jit_src_transpose_s *); + + void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } + + static const int transpose_size = 4; +private: + static const int typesize = sizeof(float); + + int src_stride, tr_src_stride; + + Xbyak::Reg64 imm_addr64 = rbx; + + Xbyak::Opmask kF0 = k1; + Xbyak::Opmask kCC = k2; + Xbyak::Opmask k33 = k3; + Xbyak::Opmask kFFFF = k4; + + Xbyak::Zmm vidx01 = zmm31; + Xbyak::Zmm vidx10 = zmm30; + Xbyak::Zmm vidx1 = zmm29; + Xbyak::Zmm vidxP = zmm28; + + Xbyak::Reg64 reg_src = r8; + Xbyak::Reg64 
reg_tr_src = r9; + Xbyak::Reg64 reg_src_prf = r10; + Xbyak::Reg64 reg_tr_src_prf = r11; + Xbyak::Reg64 reg_loop = r12; + Xbyak::Reg64 reg_tr_src_tmp = r13; + Xbyak::Reg32 regw_tmp = r14d; + + void transpose_block(int ur, int nrows); + void transpose(int nrows); + void generate(); +}; + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp new file mode 100644 index 000000000000..53313f9f01b8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp @@ -0,0 +1,327 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_1x1_CONV_UTILS_HPP +#define JIT_UNI_1x1_CONV_UTILS_HPP + +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +struct reduce_to_unit_stride_t { + convolution_desc_t conv_d_; + bool reduce_src_; + size_t space_per_thread_; +}; + +/* 1x1-kernel does not support non-unit strides so far, so the idea is: + * - for fwd or bwd_weights: to copy src to a scratch memory (with strides + * equal to 1) and then call the kernel + * - for bwd_data: reduce the problem to the one with unit stride by + * performing computations in a scratch memory (with strides equal to 1) + * and then copy the result to diff_src */ +template +inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d, + const memory_desc_t *&src_d, const memory_desc_t *dst_d) { + const bool is_bwd_data = self->desc()->prop_kind + == prop_kind::backward_data; + + const int ndims = src_d->ndims; + const auto dat_tag = ndims == 3 + ? 
memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nCw8c, format_tag::nCw16c) + : memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nChw8c, format_tag::nChw16c); + + bool rtus_applicable = true + && utils::pick(ndims - 3, + (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type, + data_type::s32)), + (conv_d->strides[0] != 1 || conv_d->strides[1] != 1)) + && dat_tag != format_tag::undef; + for (int d = 2; d < ndims; ++d) { + /* TODO: relax these conditions (by improving reducer) */ + rtus_applicable = rtus_applicable + && conv_d->padding[0][d - 2] == 0 + && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d]; + } + + if (rtus_applicable) { + self->rtus_.reduce_src_ = true; + conv_d = &(self->rtus_.conv_d_ = *conv_d); + self->rtus_.conv_d_.strides[0] = 1; + if (ndims == 4) + self->rtus_.conv_d_.strides[1] = 1; + utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2); + if (ndims == 4) + utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2); + const int ic = src_d->dims[1]; + if (is_bwd_data) { + src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d); + self->rtus_.conv_d_.diff_src_desc.dims[1] = ic; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.diff_src_desc, dat_tag); + } else { + data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type; + src_d = &(self->rtus_.conv_d_.src_desc = *dst_d); + self->rtus_.conv_d_.src_desc.dims[1] = ic; + self->rtus_.conv_d_.src_desc.data_type = data_type; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.src_desc, dat_tag); + } + } +} + +template +inline void rtus_prepare_space_info(conv_pd_t *self, + memory_tracking::registrar_t &scratchpad) { + const auto &jcp = self->jcp_; + + const int max_threads = mkldnn_get_max_threads(); + const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind, + jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking); + size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->desc())->data_type); + + self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block; + scratchpad.book(memory_tracking::names::key_conv_rtus_space, + typesize * max_threads * self->rtus_.space_per_thread_); +} + +template +struct rtus_driver_t: public jit_generator { + + struct call_params_t { + const void *ws; /* reduced image (w/ strides = 1) */ + const void *src; /* source image (w/ non-unit strides) */ + size_t icb; + size_t os; + size_t iw_start; + }; + + void (*ker_)(const call_params_t *p); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + + Xbyak::Reg64 reg_ws = abi_param1; + Xbyak::Reg64 reg_src = abi_not_param1; + Xbyak::Reg64 reg_icb = rdx; + Xbyak::Reg64 reg_os = r11; + Xbyak::Reg64 reg_iw_start = r8; + + Xbyak::Reg64 reg_cur_os = rax; + Xbyak::Reg64 reg_cur_iw = r9; + Xbyak::Reg64 reg_cur_src = r10; + + int iw_, stride_w_; + int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_; + bool src_to_ws_; + size_t typesize_; + Vmm reg_zero; + Vmm reg_v; + + rtus_driver_t(int iw, int stride_w, int src_step_h, + int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize) + : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h) + , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb) + , src_to_ws_(src_to_ws), typesize_(typesize) + { + using namespace Xbyak; + vlen_ = cpu_isa_traits::vlen; + vlen_shift_ = cpu_isa_traits::vlen_shift; + if (typesize_ == 2) { + vlen_ /= 2; + vlen_shift_--; + } + + reg_zero = Vmm(0); + reg_v = Vmm(1); + + generate(); 
+ } + + void loop_is() { + using namespace Xbyak; + + mov(reg_cur_src, reg_src); + mov(reg_cur_iw, reg_iw_start); + mov(reg_cur_os, reg_os); + + Label is_loop, skip_h_step; + L(is_loop); + + if (src_to_ws_) { + vmovups(reg_v, ptr[reg_cur_src]); + vmovups(ptr[reg_ws], reg_v); + } else { + vmovups(reg_v, ptr[reg_ws]); + vmovups(ptr[reg_cur_src], reg_v); + for (int w = 1; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + } + + add(reg_ws, vlen_); + + add(reg_cur_iw, stride_w_); + add(reg_cur_src, stride_w_ * vlen_); + + cmp(reg_cur_iw, iw_); + jl(skip_h_step); + /* for 1d convolution the loop over h should be skipped */ + if (src_step_icb_ == iw_) jmp(skip_h_step); + + if (src_to_ws_) { + add(reg_cur_src, (src_step_h_ - iw_) * vlen_); + } else { + Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */ + mov(reg_cur_src_fin, reg_cur_src); + add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_); + Label ih_loop; + L(ih_loop); + + for (int w = 0; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + + add(reg_cur_src, stride_w_ * vlen_); + cmp(reg_cur_src, reg_cur_src_fin); + jl(ih_loop); + } + xor_(reg_cur_iw, reg_cur_iw); + + L(skip_h_step); + + sub(reg_cur_os, vlen_); + jnz(is_loop); + + /* restore dst */ + sub(reg_ws, reg_os); + } + + void generate() { + using namespace Xbyak; + assert(isa == avx2 || isa == avx512_common + || isa == avx512_core || isa == avx512_mic); + +#if defined(_WIN32) + assert(reg_src == abi_not_param1 && abi_not_param1 == rdi); + push(rdi); +#endif + +#define READ_PARAM(what) \ + mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)]) + READ_PARAM(src); + READ_PARAM(icb); + READ_PARAM(os); + READ_PARAM(iw_start); + + assert(reg_ws == abi_param1); + READ_PARAM(ws); /* reg_ws should always be read the last */ +#undef READ_PARAM + + shl(reg_os, vlen_shift_); + + if (!src_to_ws_) + uni_vpxor(reg_zero, reg_zero, reg_zero); + + Label icb_loop; + L(icb_loop); + + loop_is(); + + add(reg_ws, ws_step_icb_ * vlen_); + add(reg_src, src_step_icb_ * vlen_); + + dec(reg_icb); + jnz(icb_loop, T_NEAR); + +#if defined(_WIN32) + pop(rdi); +#endif + + uni_vzeroupper(); + ret(); + this->ker_ = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +template +inline void init_rtus_driver(conv_t *self) { + const auto &conf = *self->pd(); + if (!conf.rtus_.reduce_src_) return; + + const auto &cd = *conf.desc(); + const int ndims = conf.ndims(); + const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0]; + const int stride_w = cd.strides[ndims - 3]; + + const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data; + const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md(); + + const int ih = ndims == 3 ? 
1 : src_d.dims[2]; + const int iw = src_d.dims[ndims - 1]; + + const int src_step_h = stride_h * iw; + const int src_step_icb = ih * iw; + const int ws_step_icb = conf.jcp_.is; + const bool src_to_ws = !is_bwd_data; + const size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->pd()->desc())->data_type); + + self->rtus_driver_ = new rtus_driver_t(iw, stride_w, src_step_h, + src_step_icb, ws_step_icb, src_to_ws, typesize); +} + +inline int best_divider(int value, int min_divider, int max_divider, + bool find_max, int step = 1) +{ + max_divider = nstl::max(1, nstl::min(max_divider, value)); + min_divider = nstl::max(1, nstl::min(min_divider, max_divider)); + + auto loss_ratio = [](int total, int chunk) + { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); }; + + float min_loss = FLT_MAX; + int x_divider = max_divider; + for (int divider = max_divider; divider >= min_divider; divider -= step) { + const float loss = loss_ratio(value, divider); + if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) { + min_loss = loss; + x_divider = divider; + } + } + return x_divider; +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp new file mode 100644 index 000000000000..72fe3a810971 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp @@ -0,0 +1,1407 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "jit_uni_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace memory_tracking::names; + +using namespace Xbyak; +namespace barrier = simple_barrier; + +typedef float data_t; + +template +struct jit_bnorm_t: public jit_generator { + struct call_params_t { + // keep all sizes at 8 bytes -- jit code expects this + size_t N_ithr, N_nthr; + size_t coff_max, soff_max; + size_t mb_stride_Bc, spat_size, spat_size_loc; + size_t S_s, S_tail; + size_t is_cblk_tail; + data_t chan_size, eps, one; + const data_t *scale_shift; + const data_t *mean, *var; + const data_t *diff_scale_shift; + const data_t *src, *dst; + const data_t *diff_src, *diff_dst; + const data_t *rbuf1, *rbuf2; + const uint8_t *ws; + barrier::ctx_t *barrier; + }; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional3::type; + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx2) ? 
yword : zword; + + const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; + + const batch_normalization_pd_t *bdesc_; + bool is_spatial_thr_; + + void (*ker)(const call_params_t *); + void operator()(const call_params_t *p) { (*ker)(p); } + + Reg64 reg_param = abi_param1; + + Reg64 reg_scale_shift = rbx; + Reg64 reg_rbuf1 = abi_not_param1; + Reg64 reg_rbuf2 = rdx; + + Reg64 reg_mean = rbp; + Reg64 reg_var = reg_param; + Reg64 reg_diff_scale_shift = rax; + + Reg64 reg_coff = r8; + Reg64 reg_coff_max = r9; + Reg64 reg_soff = r10; + Reg64 reg_soff_max = r11; + Reg64 reg_ctr = r12; + Reg64 reg_roff = r13; + + Reg64 reg_mb_stride_Bc = r14; + + Reg64 reg_src = r15; + Reg64 reg_diff_src = reg_rbuf1; + Reg64 reg_dst = rsi; + Reg64 reg_diff_dst = reg_dst; + + Reg64 reg_tmp_off = reg_roff; + + // Reuse loop counters + Reg64 reg_bar = reg_coff; + Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff + Reg64 reg_tmp = reg_ctr; + + // Relu section + bool with_relu, with_relu_inf_only; + Vmm vzero; // is_fwd() ? vdiff_beta : vbeta + Reg64 reg_ws = reg_roff; + Label l_relu_mask_avx2; + Opmask kstore_mask = Opmask(1); + + // channel tail processing + Opmask ktail_mask = Opmask(2); + + size_t unroll_blocks; + size_t unroll_regs; + Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5); + Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6); + Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7); + Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8); + Vmm vone = Vmm(isa == avx512_common ? 24 : 9); + Vmm vmean = Vmm(isa == avx512_common ? 25 : 10); + Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11); + Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12); + Vmm veps = Vmm(isa == avx512_common ? 28 : 13); + Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14); + Vmm vtail_mask = Vmm(isa == avx512_common ? 
30 : 15); + + size_t t0_pf_offt; + size_t t1_pf_offt; + size_t spat_size; + size_t chan_data_offt; + + enum { + stack_off_N_nthr = 0, + stack_off_N_ithr = 8, + stack_off_src = 16, + stack_off_dst = 24, + stack_off_diff_src = 32, + stack_off_diff_dst = 40, + stack_off_diff_scale_shift = 48, + stack_off_ws = 56, + stack_off_barrier = 64, + stack_off_spat_size_loc = 72, + stack_off_s_s = 80, + stack_off_s_tail = 88, + stack_off_is_cblk_tail = 96, + stack_size_required = 104, + }; + + bool is_c_padded() const { + const memory_desc_wrapper data_d(bdesc_->src_md()); + return bdesc_->C() != data_d.padded_dims()[1]; + } + + void compute_static_strides() { + spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H(); + chan_data_offt = bdesc_->C() * sizeof(data_t); + + if (isa == avx512_mic) { + t0_pf_offt = 4096; + t1_pf_offt = 0; + } else { + t0_pf_offt = 0; + t1_pf_offt = 0; + } + } + + void load_common_params() { +# define PARAM_OFF(x) offsetof(call_params_t, x) + mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]); + if (bdesc_->is_bwd()) + mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]); + mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]); + mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]); + mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]); + shl(reg_coff_max, 2); + shl(reg_soff_max, 2); + shl(reg_mb_stride_Bc, 2); + + mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); + mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]); + + uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]); + uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]); + uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); + + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]); + mov(ptr[rsp + stack_off_N_nthr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]); + mov(ptr[rsp + stack_off_N_ithr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]); + mov(ptr[rsp + stack_off_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]); + mov(ptr[rsp + stack_off_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]); + mov(ptr[rsp + stack_off_diff_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]); + mov(ptr[rsp + stack_off_diff_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]); + mov(ptr[rsp + stack_off_ws], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]); + mov(ptr[rsp + stack_off_barrier], reg_tmp); + if (is_spatial_thr_) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]); + mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]); + mov(ptr[rsp + stack_off_s_s], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]); + mov(ptr[rsp + stack_off_s_tail], reg_tmp); + } + if (is_c_padded()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]); + mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp); + } + + if (bdesc_->is_fwd()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } else { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]); + mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } +# undef PARAM_OFF + } + + void prepare_tail_mask_avx512_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + const int mask = (1 << tail) - 1; + + Reg32 regw_tmp = reg_tmp.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + void prepare_tail_mask_avx2_common() { 
+ if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0, 0, 0, 0, 0, 0, 0, 0}; + + mov(reg_tmp, reinterpret_cast(&mask[8 - tail])); + vmovups(vtail_mask, ptr[reg_tmp]); + } + + void prepare_relu() { + with_relu = bdesc_->is_fwd() + ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu() + : bdesc_->fuse_bn_relu(); + with_relu_inf_only = with_relu && bdesc_->is_fwd() + && !(bdesc_->fuse_bn_relu() && bdesc_->is_training()); + + vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta; + if (with_relu) { + uni_vpxor(vzero, vzero, vzero); + if (!bdesc_->is_fwd() && isa == avx2) + prepare_l_relu_mask_avx2(); + } + } + + void prepare_l_relu_mask_avx2() { + Label l_mask_after; + jmp(l_mask_after); + align(32); + L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ + for (int i = 0; i < 8; ++i) dd(1< + void spat_loop(size_t len, size_t blocks, size_t regs, + init_t init, body_t body, fini_t fini) { + size_t factor = regs * blocks; + size_t loop_unroll = len / factor * factor; + size_t loop_tail = len - loop_unroll; + size_t num_active_regs = (len < regs) ? len : regs; + for (size_t i = 0; i < num_active_regs; i++) + init(i); + if (loop_unroll) { + if (is_spatial_thr_) { + mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); + add(reg_soff, ptr[rsp + stack_off_s_s]); + } else { + mov(reg_ctr, loop_unroll); + } + Label label; + L(label); { + for (size_t i = 0; i < factor; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + add(reg_soff, factor * vlen); + sub(reg_ctr, factor); + jnz(label); + } + if (is_spatial_thr_) { + add(reg_soff, ptr[rsp + stack_off_s_tail]); + } + } + + for (size_t i = 0; i < loop_tail; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + if (loop_tail) + add(reg_soff, loop_tail * vlen); + + for (size_t i = 0; i < num_active_regs; i++) + fini(i); + } + + void mean_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, + unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v0 = Vmm(base_reg * 2 + 0); + Vmm v1 = Vmm(base_reg * 2 + 1); + size_t offt = i * vlen; + uni_vmovups(v1, + vmmword[reg_src + reg_soff + offt]); + uni_vaddps(v0, v0, v1); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void var_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 3); + if (base_reg > 0) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(3 * base_reg); + Vmm vtmp0 = Vmm(3 * base_reg + 1); + Vmm vtmp1 = Vmm(3 * base_reg + 2); + size_t offt = i * vlen; + uni_vmovups(vtmp0, + vmmword[reg_src + reg_soff + offt]); + if (isa == sse42) { + movups(vtmp1, vmean); + subps(vtmp1, vtmp0); + } else { + vsubps(vtmp1, vmean, vtmp0); + } + uni_vfmadd231ps(v, vtmp1, vtmp1); + + 
mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 3); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void compute_mean_variance() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf; + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + + xor_(reg_soff, reg_soff); + Label mean_spatial; + L(mean_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + mean_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + mean_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(mean_spatial); + } + + Label no_mean_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_mean_reduction); + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label mean_reduction_channels; + L(mean_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label mean_reduction_thrs; + L(mean_reduction_thrs); { + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0)); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(mean_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(mean_ptr(), Vmm(1)); + + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(mean_reduction_channels); + } + } + L(no_mean_reduction); + barrier(); + + xor_(reg_soff, reg_soff); + Label var_spatial; + L(var_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + var_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + var_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(var_spatial); + } + + Label no_var_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_var_reduction); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label var_reduction_channels; + L(var_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label var_reduction_thrs; + L(var_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(var_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(var_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? 
vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(var_reduction_channels); + } + } + L(no_var_reduction); + barrier(); + } + + void forward_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + + if (bdesc_->use_scaleshift()) { + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vbeta, beta_ptr()); + } + + Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone; + Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar; + + if (isa == sse42) { + movups(vbuf, vscale); + divps(vbuf, vsqrtvar); + movups(vdiv, vbuf); + } else { + vdivps(vdiv, vscale, vsqrtvar); + } + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(base_reg); + size_t offt = i * vlen; + uni_vmovups(v, + vmmword[reg_src + reg_soff + offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + uni_vsubps(v, v, vmean); + if (bdesc_->use_scaleshift()) { + uni_vfmadd213ps(v, vgamma, vbeta); + } else { + uni_vmulps(v, v, vsqrtvar); + } + if (with_relu_inf_only) { + uni_vmaxps(v, v, vzero); + } else if (with_relu) { + if (isa == avx512_common) + fwd_process_relu_avx512_common(v, offt); + else + fwd_process_relu_avx2(v, offt, Vmm(3)); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_dst + reg_soff + offt], v); + } else { + uni_vmovups( + vmmword[reg_dst + reg_soff + offt], v); + } + }, + [](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_dst, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void forward() { + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_dst, ptr[rsp + stack_off_dst]); + mov(reg_ws, ptr[rsp + stack_off_ws]); + + xor_(reg_soff, reg_soff); + Label dst_spatial; + L(dst_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + forward_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + add(reg_dst, vlen / 2); + mov(reg_coff, vlen / 2); + + forward_channels(); + + sub(reg_src, vlen / 2); + sub(reg_dst, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jnz(dst_spatial); + } + } + + void backward_sh_channels() { + Label sh_channels; + L(sh_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]); + spat_loop(spat_size, 1, 1, + [=](size_t base_reg) { + if (base_reg > 0) { + for (int i = 0; i < 2; i++) { + Vmm v(base_reg * 5 + i); + uni_vpxor(v, v, v); + } + } + }, + [=](size_t base_reg, size_t i) { + Vmm o0 = Vmm(base_reg * 5 + 0); + Vmm o1 = Vmm(base_reg * 5 + 1); + Vmm t1 = Vmm(base_reg * 5 + 2); + Vmm t2 = Vmm(base_reg * 5 + 3); + Vmm t3 = Vmm(base_reg * 5 + 4); + size_t offt = i * vlen; + uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]); + uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(t2, offt); + else if (isa == avx2) + bwd_process_relu_avx2(t2, offt, t3); + else + 
assert(false); + } + uni_vsubps(t3, vmean, t1, t3); + if (isa == sse42) { + mulps(t3, t2); + subps(o0, t3); + } else { + vfnmadd231ps(o0, t3, t2); + } + uni_vaddps(o1, o1, t2); + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt + + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b0 = Vmm(0); + Vmm b1 = Vmm(1); + if (base_reg) { + uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0)); + uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1)); + } + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(sh_channels); + } + } + + void backward_diff_channels() { + Label diff_channels; + L(diff_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + if (bdesc_->use_scaleshift()) + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr()); + uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); + uni_vdivps(vdiff_beta, vdiff_beta, vchan_size); + uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size); + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v(base_reg * 2 + 0); + Vmm t(base_reg * 2 + 1); + Vmm t1(base_reg * 2 + 2); + size_t offt = i * vlen; + uni_vmovups(v, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(v, offt); + else if (isa == avx2) + bwd_process_relu_avx2(v, offt, t); + else + assert(false); + } + if (!bdesc_->use_global_stats()) { + uni_vsubps(v, v, vdiff_beta); + uni_vmovups(t, vmmword[reg_src + reg_soff + + offt]); + uni_vsubps(t, vmean, t, t1); + uni_vmulps(t, t, vdiff_gamma); + uni_vaddps(v, v, t); + } + uni_vmulps(v, v, vsqrtvar); + if (bdesc_->use_scaleshift()) { + uni_vmulps(v, v, vgamma); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_diff_src + reg_soff + offt], + v); + } else { + uni_vmovups( + vmmword[reg_diff_src + reg_soff + offt], + v); + } + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + + offt + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_diff_src, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(diff_channels); + } + } + + void backward() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf, sh_spatial; + + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? 
vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + L(sh_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_sh_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_sh_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(sh_spatial); + } + + mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]); + + Label no_sh_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + Label sh_reduction_channels; + jne(no_sh_reduction, T_NEAR); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + L(sh_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + mov(reg_ctr, reg_nnthr); + Label sh_reduction_thrs; + L(sh_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]); + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(sh_reduction_thrs); + } + uni_vmulps(Vmm(0), Vmm(0), vsqrtvar); + uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0)); + uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(sh_reduction_channels); + } + } + L(no_sh_reduction); + barrier(); + + mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + Label diff_spatial; + L(diff_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_diff_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_diff_src, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_diff_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_diff_src, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(diff_spatial); + } + } + + jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) { + static_assert(isa == sse42 || isa == avx2 || isa == avx512_common + || isa == avx512_mic, "unsupported isa"); + + const int simd_w = isa == sse42 ? 8 : + cpu_isa_traits::vlen / sizeof(data_t); + is_spatial_thr_ = + bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t)); + + unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 
4 : 1; + + preamble(); + + if (isa == avx512_common) + prepare_tail_mask_avx512_common(); + else if (isa == avx2) + prepare_tail_mask_avx2_common(); + + compute_static_strides(); + sub(rsp, stack_size_required); + load_common_params(); + prepare_relu(); + + if (bdesc_->is_fwd()) { + if (!bdesc_->stats_is_src()) { + compute_mean_variance(); + } + forward(); + } else { + backward(); + } + add(rsp, stack_size_required); + postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +template +struct uni_bnorm_driver_t: public c_compatible { + uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc) + : bdesc_(bdesc), ker_(bdesc_) + { + const int nthrs = mkldnn_get_max_threads(); + const dim_t C_PADDED = get_c_padded(bdesc_); + + size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED + * bdesc_->D() * bdesc_->H() * bdesc_->W(); + l3_size_ = get_cache_size(3, true) * nthrs / 2; + do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0); + } + + ~uni_bnorm_driver_t() {} + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const batch_normalization_pd_t *bdesc) { + int nthrs = mkldnn_get_max_threads(); + dim_t C_PADDED = get_c_padded(bdesc); + + int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED; + int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED; + int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs; + + scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz); + + if (mkldnn_thr_syncable()) { + int n_barriers = C_PADDED / simd_w; + scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers); + } + } + + void exec(int ithr, int nthr, const data_t *src, data_t *diff_src, + data_t *dst, const data_t *diff_dst, const data_t *scale_shift, + data_t *diff_scale_shift, const data_t *mean, const data_t *var, + const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) { + auto sbuf = scratchpad.get(key_bnorm_tmp_stats); + auto pbuf = scratchpad.get(key_bnorm_tmp_diff_ss); + auto rbuf = scratchpad.get(key_bnorm_reduction); + auto barriers = scratchpad.get(key_barrier); + + dim_t N = bdesc_->MB(); + dim_t C = bdesc_->C(); + dim_t C_PADDED = get_c_padded(bdesc_); + dim_t D = bdesc_->D(); + dim_t H = bdesc_->H(); + dim_t W = bdesc_->W(); + dim_t SP = D * H * W; + dim_t img_size = C_PADDED * D * H * W; + const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; + + typename jit_bnorm_t::call_params_t p; + + p.eps = bdesc_->desc()->batch_norm_epsilon; + p.one = 1.0f; + p.spat_size = D * H * W; + p.chan_size = 1.0f * N * p.spat_size; + + dim_t C_blks = C_PADDED / simd_w; + + int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0}; + dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0}; + + dim_t C_blks_per_iter{ 1 }; + int64_t iters{ 1 }; + if (do_blocking_) { + int num_tensors = bdesc_->is_fwd() ? 1 : 2; + size_t working_set_size + = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors; + bnorm_utils::cache_balance(working_set_size, C_blks, + C_blks_per_iter, iters); + } + + bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + true, ithr, nthr, N, do_blocking_ ? 
C_blks_per_iter : C_blks, + SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, + S_ithr, S_nthr, S_s, S_e); + + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1)); + + p.N_ithr = SP_N_ithr; + p.N_nthr = SP_N_nthr; + + int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter; + int global_C_blk_s; + int global_barriers_per_iter = C_nthr; + + for (int64_t it = 0; it < iters; it++) { + if (it == iters - 1 && iters > 1) { + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + + // Update call parameters for JIT, last iteration + p.N_ithr = N_ithr * S_nthr + S_ithr; + p.N_nthr = N_nthr * S_nthr; + } + + global_C_blk_s = do_blocking_ ? + (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s : + C_blk_s; + + int C_blks_thr = C_blk_e - C_blk_s; + int N_thr = N_e - N_s; + + size_t coff_base = global_C_blk_s * simd_w; + size_t soff_base + = global_C_blk_s * p.spat_size * simd_w + N_s * img_size; + + p.spat_size_loc = S_e - S_s; + p.S_s = S_s * vlen; + p.S_tail = (p.spat_size - S_e) * vlen; + p.coff_max = C_blks_thr * simd_w; + p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base; + p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base; + p.scale_shift = scale_shift + coff_base; + p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_) + ? pbuf : diff_scale_shift) + coff_base; + + p.soff_max = N_thr * img_size; + p.src = src + soff_base; + p.dst = dst + soff_base; + p.diff_src = diff_src + soff_base; + p.diff_dst = diff_dst + soff_base; + p.ws = ws + soff_base / 8; + + p.mb_stride_Bc = img_size - p.coff_max * p.spat_size; + + // use SP_N_nthr which is the same as p.N_nthr except maybe for + // the last iteration. + p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr + + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w; + // rbuf1 and rbuf2 have to be disjoint + p.rbuf2 = p.rbuf1 + C_PADDED * nthr; + p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C; + + size_t iter_bariers + = do_blocking_ ? it * global_barriers_per_iter : 0; + p.barrier = barriers + C_ithr + iter_bariers; + if (p.soff_max != 0 && p.coff_max != 0) + ker_(&p); + } + } + + void init_barriers(const memory_tracking::grantor_t &scratchpad) { + auto barriers = scratchpad.get(key_barrier); + if (barriers) { + const int n_barriers = get_c_padded(bdesc_) / simd_w; + for (int i = 0; i < n_barriers; ++i) + barrier::ctx_init(&barriers[i]); + } + } + +private: + enum { + simd_w = isa == sse42 ? 
8 : cpu_isa_traits::vlen / sizeof(data_t) + }; + + static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) { + return true + && !bdesc->stats_is_src() + && bdesc->desc()->prop_kind == prop_kind::forward_inference; + } + + static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc) + { + return false + || (bdesc->is_bwd() && !bdesc->use_scaleshift()) + || bdesc->desc()->prop_kind == prop_kind::backward_data; + } + + static dim_t get_c_padded(const batch_normalization_pd_t *bdesc) + { return bdesc->src_md()->padded_dims[1]; } + + const batch_normalization_pd_t *bdesc_; + bool do_blocking_; + size_t l3_size_; + + jit_bnorm_t ker_; +}; + +} + +using namespace data_type; +using namespace format_tag; +using namespace utils; + +/* fwd */ + +template +status_t jit_uni_batch_normalization_fwd_t::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + } + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t::init_scratchpad(scratchpad, this); + + return status::success; +} + +template +jit_uni_batch_normalization_fwd_t::jit_uni_batch_normalization_fwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } + +template +status_t jit_uni_batch_normalization_fwd_t::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + auto var = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, + scale_shift, nullptr, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template +jit_uni_batch_normalization_fwd_t::~jit_uni_batch_normalization_fwd_t() +{ delete bnorm_driver_; } + +/* bwd */ + +template +status_t jit_uni_batch_normalization_bwd_t::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c + : one_of(isa, sse42, avx2) ? 
nCdhw8c : nCdhw16c; + + bool ok = true + && mayiuse(isa) + && is_bwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + if (fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + /* TODO: extra checks required */ + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t::init_scratchpad(scratchpad, this); + + return status::success; +} + +template +jit_uni_batch_normalization_bwd_t::jit_uni_batch_normalization_bwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } + +template +status_t jit_uni_batch_normalization_bwd_t::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, + scale_shift, diff_scale_shift, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template +jit_uni_batch_normalization_bwd_t::~jit_uni_batch_normalization_bwd_t() +{ delete bnorm_driver_; } + +/* struct instantiation */ +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp new file mode 100644 index 000000000000..96410ec84ea7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp @@ -0,0 +1,100 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP +#define JIT_UNI_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_isa_traits.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { template struct uni_bnorm_driver_t; } + +template +struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_fwd_t); + + status_t init(); + }; + + typedef typename prec_traits::type data_t; + + jit_uni_batch_normalization_fwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t *bnorm_driver_; +}; + +template +struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_bwd_t); + + status_t init(); + }; + + typedef typename prec_traits::type data_t; + + jit_uni_batch_normalization_bwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_bwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t *bnorm_driver_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp new file mode 100644 index 000000000000..b7dc5f85c5ad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp @@ -0,0 +1,1302 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +template +void jit_uni_dw_conv_fwd_kernel_f32::load_src(int ur_ch_blocks, int ur_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + int b_off = ch*jcp.ch_block + i*4; + if (this->jcp.with_bias) + uni_vmovups(vmm_acc, + vmmword[reg_bias + b_off*sizeof(float)]); + else + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + + int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block + + ow*jcp.ch_block + i*4; + if (this->jcp.with_sum) + uni_vaddps(vmm_acc, vmm_acc, + vmmword[reg_output + o_off*sizeof(float)]); + } + } + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(iter_kw, reg_kw); + mov(aux1_reg_input, aux_reg_input); + mov(aux1_reg_kernel, aux_reg_kernel); + + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + i*4; + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux1_reg_kernel, ch_blk*sizeof(float)); + add(aux1_reg_input, ch_blk*dilate_w*sizeof(float)); + + dec(iter_kw); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_filter_unrolled( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + int repeats = isa == sse42 ? 
2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int kw = 0; kw < jcp.kw; kw++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4; + + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + } + + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_activation( + int ur_ch_blocks, int ur_w) { + if (this->jcp.with_eltwise) { + int repeats = isa == sse42 ? 2 : 1; + eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4); + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::store_dst( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4; + Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst); + } + } + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::loop_body(int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter_unrolled(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? 
ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template +bool jit_uni_dw_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +template +status_t jit_uni_dw_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.oc / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 
3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +template struct jit_uni_dw_conv_fwd_kernel_f32; +template struct jit_uni_dw_conv_fwd_kernel_f32; +template struct jit_uni_dw_conv_fwd_kernel_f32; + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::load_ddst( + int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_str_w) { + int kw = jcp.kw; + int kh = jcp.kh; + int ow = jcp.ow; + int oh = jcp.oh; + + int ch_blk = jcp.ch_block; + int stride_h = jcp.stride_h; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(aux1_reg_ddst, aux_reg_ddst); + mov(aux1_reg_kernel, aux_reg_kernel); + + mov(iter_kw, reg_kw); + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*kh*kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int w = 0; w < ur_str_w; w++) { + int ddst_off = (ch*oh*ow + w)*ch_blk + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_ddst + + ddst_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + + add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float)); + sub(aux1_reg_ddst, ch_blk*sizeof(float)); + + sub(iter_kw, stride_w); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + + add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float)); + sub(aux_reg_ddst, ow*ch_blk*sizeof(float)); + + sub(iter_kh, stride_h); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::store_dsrc( + int ur_ch_blocks, int ur_str_w) { + int ch_blk = jcp.ch_block; + int iw = jcp.iw; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + + int repeats = isa == sse42 ? 
2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4; + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + + uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::loop_body( + int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_str_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_str_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); +} + +template +status_t jit_uni_dw_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 
16 : 8; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1]; + + jcp.ih = diff_src_d.dims()[2]; + jcp.iw = diff_src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.ic / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 
3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template +void jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +template struct jit_uni_dw_conv_bwd_data_kernel_f32; +template struct jit_uni_dw_conv_bwd_data_kernel_f32; +template struct jit_uni_dw_conv_bwd_data_kernel_f32; + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_filter() { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < jcp.kw; ++i) { + Vmm vmm_acc = get_acc_reg(r * jcp.kw + i); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (reg_set + i) * simd_w; + Vmm vmm_acc = get_acc_reg(reg_set + i); + uni_vmovups(vmm_acc, + vmmword[reg_tmp_filter + off_filter * sizeof(float)]); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vpxor(vmm_bias, vmm_bias, vmm_bias); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int iw_block = ow_block * jcp.stride_w; + const int right_border = jcp.iw - iw_block; + + const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); + + /* preamble count for number of cascaded LOAD + FMA operation */ + const int input_overlap = nstl::max(jcp.kw - l_pad, 0); + + /* LOAD initial input registers, then cascade LOADs and FMAs*/ + for (int r = 0; r < reg_repeats; ++r) { + for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { + int off_output = (i_ur * reg_repeats + r) * simd_w; + Vmm vmm_output = get_output_reg(r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + if (i_ur == 0) { + for (int c = 0; c < input_overlap; ++c) { + int off_input + = ((c - pad_offset) * reg_repeats + r) * simd_w; + Vmm vmm_input + = get_input_reg((c % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } else { + for (int c = 0; c < cascade_input; ++c) { + int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; + int off_input + = ((overlap + c - pad_offset) * reg_repeats + r) + * simd_w; + Vmm vmm_input = get_input_reg( + ((overlap + c) % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } + + for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { + int io_overlap = i_kw + (i_ur * jcp.stride_w); + + /* Don't apply FMAs that fall into the padded region */ + if (io_overlap - l_pad < 0 + || io_overlap - jcp.l_pad >= right_border) + continue; + + Vmm vmm_input = get_input_reg( + ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r); + Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r); + Vmm vmm_aux = isa == sse42 ? 
get_aux_reg() : vmm_input; + if (isa == sse42) + uni_vmovups(vmm_aux, vmm_input); + uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output); + } + } + } +} + +template +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_step_unroll( + const int unroll_w) { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < unroll_w; ++i) { + Vmm vmm_bias = get_bias_reg(r); + int off_output = (i * reg_repeats + r) * simd_w; + if (isa == sse42) { + /* Need to support unaligned address loads for SSE42*/ + Vmm vmm_output = get_output_reg(1 + r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + uni_vaddps(vmm_bias, vmm_bias, vmm_output); + } else { + uni_vaddps(vmm_bias, vmm_bias, + vmmword[reg_tmp_output + off_output * sizeof(float)]); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (i + reg_set) * simd_w; + Vmm vmm_acc = get_acc_reg(i + reg_set); + uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)], + vmm_acc); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_loop( + const int block_size) { + Label oh_label; + Label ow_blk_label; + + const int unroll_w = nstl::min(block_size, jcp.ow); + const int unroll_w_trips = jcp.ow / unroll_w; + const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0; + + const int ch_offset = jcp.ch_block; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + L(oh_label); + { + + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + { + + compute_bias_step_unroll(unroll_w); + add(reg_tmp_output, unroll_w * ch_offset * sizeof(float)); + + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + if (tail_w > 0) { + compute_bias_step_unroll(tail_w); + add(reg_tmp_output, tail_w * ch_offset * sizeof(float)); + } + + inc(reg_oh); + cmp(reg_oh, reg_oh_worksize); + jl(oh_label, T_NEAR); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_zero_filter() { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_zeroing_label; + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_FILTER); + test(reg_exec_flags, reg_exec_flags); + je(skip_zeroing_label); + + zero_filter(); + + mov(reg_tmp_filter, reg_filter_baddr); + mov(reg_kh, jcp.kh); + L(kh_loop_label); + { + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float)); + + L(skip_zeroing_label); +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_loop_label; + + cmp(reg_kh_count, 0); + je(skip_loop_label, T_NEAR); + + mov(reg_kh, reg_kh_count); + L(kh_loop_label); + { + load_filter(); + 
compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block); + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + Label kh_comeback_label; + mov(reg_kh, reg_kh_count); + L(kh_comeback_label); + { + sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_comeback_label, T_NEAR); + } + + L(skip_loop_label); +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ? + jcp.ih / jcp.stride_h - 1 : + jcp.oh - jcp.b_pad - 1; + const int ch_offset = jcp.ch_block; + const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + + Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label, + end_h_loop_label; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + mov(reg_kh_count, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + mov(reg_tmp_input, reg_input_baddr); + mov(reg_tmp_filter, reg_filter_baddr); + + L(h_loop_label); + { + + compute_h_step(unroll_w, l_pad, pad_offset, ow_block); + + add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float)); + + /* If within the top_pad region */ + if (jcp.t_pad > 0) { + /* Skip t_pad area if no longer in initial h_block */ + cmp(reg_oh, jcp.t_pad); + jg(skip_tpad_label, T_NEAR); + + cmp(reg_kh_count, jcp.kh); + jge(skip_tpad_label, T_NEAR); + + add(reg_kh_count, t_overlap_off); + sub(reg_tmp_filter, + t_overlap_off * jcp.kw * ch_offset * sizeof(float)); + + /* kernel has moved beyond padding (adjust for stride effects) */ + if (jcp.t_pad % jcp.stride_h != 0) { + int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h; + add(reg_tmp_input, + inp_corr * jcp.iw * ch_offset * sizeof(float)); + } + jmp(tpad_loop_label, T_NEAR); + } + + L(skip_tpad_label); + + cmp(reg_oh, io_overlap); + jl(skip_bpad_label, T_NEAR); + sub(reg_kh_count, b_overlap_off); + + L(skip_bpad_label); + add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float)); + + L(tpad_loop_label); + + cmp(reg_oh, jcp.ih / jcp.stride_h); + jge(end_h_loop_label, T_NEAR); + + inc(reg_oh); + + cmp(reg_oh, reg_oh_worksize); + jl(h_loop_label, T_NEAR); + } + L(end_h_loop_label); +} + +template +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_block_unroll() { + + const int ch_offset = jcp.ch_block; + int ow = jcp.ow; + int pad_offset = 0; + int l_pad = jcp.l_pad; + + /* Calculate effective padding */ + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + + /* Is this strictly defined by: + * -code-size (?) + * -address size (?) 
*/ + const int max_unroll_w = 30; + const int block_size = 15; + + int unroll_w_tail = 0; + int unroll_w = 0; + int unroll_w_trips = 0; + + if (jcp.ow > max_unroll_w) { + unroll_w = nstl::min(block_size, jcp.ow); + unroll_w_trips = ow / unroll_w; + /* calculate tail */ + unroll_w_tail = ow % unroll_w; + /* Perform some rebalancing if tail too small*/ + if ((unroll_w_tail == 0 && r_pad != 0) + || (r_pad > 0 && r_pad >= unroll_w_tail)) { + if (unroll_w_trips > 1) { + unroll_w_tail += unroll_w; + unroll_w_trips--; + } else { + /* Idealy, this case shouldn't happen */ + unroll_w_tail += (unroll_w - unroll_w / 2); + unroll_w = unroll_w / 2; + } + } + } else { + unroll_w = jcp.ow; + unroll_w_trips = nstl::max(1, ow / unroll_w); + } + if (jcp.with_bias) { + Label skip_load_bias; + mov(reg_bias_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]); + + zero_bias(); + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_BIAS); + test(reg_exec_flags, reg_exec_flags); + jne(skip_load_bias); + + load_bias(); + + L(skip_load_bias); + compute_bias_loop(block_size); + + store_bias(); + } + + /* Pass filter address, then offset for h_padding. */ + compute_zero_filter(); + mov(reg_kh_offset, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]); + add(reg_filter_baddr, reg_kh_offset); + + /* compute left padded block */ + if (l_pad) { + compute_h_loop(unroll_w, l_pad, 0, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + unroll_w_trips--; + pad_offset = l_pad; + l_pad = 0; + } + + /* compute middle block */ + Label ow_blk_label; + + /* Insert loop for 'ow' block when middle block needs to execute more + * than once */ + bool do_ow_blk_loop = unroll_w_trips > 1; + if (do_ow_blk_loop) { + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + } + if (unroll_w_trips > 0) { + compute_h_loop(unroll_w, l_pad, pad_offset, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + } + if (do_ow_blk_loop) { + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + /* compute right padded block */ + if (unroll_w_tail) { + compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail); + } +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::generate() { + preamble(); + + mov(reg_input_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]); + mov(reg_output_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]); + mov(reg_filter_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]); + + compute_ow_block_unroll(); + + this->postamble(); +} + +template +status_t jit_uni_dw_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads) { + if (!mayiuse(isa)) + return status::unimplemented; + + jcp.ngroups = diff_weights_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic); + + if (!jcp.is_depthwise) + return status::unimplemented; + + jcp.ch_block = isa == avx512_common ? 
16 : 8; + + jcp.mb = src_d.dims()[0]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = diff_weights_d.dims()[3]; + jcp.kw = diff_weights_d.dims()[4]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.kw <= 3 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ch = jcp.ngroups / jcp.ch_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_hpad = (jcp.kh - 1 + 1) / 2; + const int max_wpad = (jcp.kw - 1 + 1) / 2; + const bool boundaries_ok = true && jcp.t_pad <= max_hpad + && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad + && jcp.r_pad <= max_wpad; + if (!boundaries_ok) + return status::unimplemented; + + balance(jcp, nthreads); + + return status::success; +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + /* Notes: if splitting thread work on 'mb', then a reduction has to take + * place. Hence, book a per-thread, local weights-buffer for the + * reduction */ + if (jcp.nthr_mb > 1) { + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * wei_size * (jcp.nthr_mb - 1)); + + if (jcp.with_bias) + scratchpad.book(key_conv_bia_reduction, + sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1)); + } +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::balance(jit_conv_conf_t &jcp, + int nthreads) { + jcp.nthr = nthreads; + jcp.nthr_g = jcp.nthr_mb = 1; + + /* Basic-Heuristics for parallel strategy: + * 1) Tries to parallel on the number of Groups (g) where tasks are + * independent. Otherwise, + * 2) Tries to split the work across g and MiniBatch (mb). + * Parallelizing on mb requires computing a reduction for weights. + * + * NOTE: because of 'task partitioning' scheme, there will be unbalanced + * per-thread load when the number of threads is high (e.g. > 16). 
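The applicability tests in init_conf() above amount to a handful of integer identities; a scalar restatement, assuming plain ints rather than the jcp struct:

// Mirrors the args_ok / boundaries_ok checks above: output sizes must match the
// padded-input formula, and each pad must fit under half the kernel extent.
static bool dw_conv_shapes_ok(int ih, int iw, int oh, int ow,
        int kh, int kw, int stride_h, int stride_w,
        int t_pad, int b_pad, int l_pad, int r_pad) {
    const int ihp = ih + t_pad + b_pad;  // padded input height
    const int iwp = iw + l_pad + r_pad;  // padded input width
    const bool dims_ok = oh == (ihp - kh) / stride_h + 1
            && ow == (iwp - kw) / stride_w + 1;
    const int max_hpad = kh / 2, max_wpad = kw / 2;  // (k - 1 + 1) / 2 in the source
    const bool pads_ok = t_pad <= max_hpad && b_pad <= max_hpad
            && l_pad <= max_wpad && r_pad <= max_wpad;
    return dims_ok && pads_ok;
}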
+ */ + jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr); + jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb); + + jcp.nthr = jcp.nthr_g * jcp.nthr_mb; +} + +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp new file mode 100644 index 000000000000..9c08fc4a0980 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp @@ -0,0 +1,253 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP +#define JIT_UNI_DW_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) + + jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_uni_dw_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + const Xbyak::AddressFrame &vmmword = (isa == sse42) + ? xword : (isa == avx2) ? 
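Written out as plain code, the balance() heuristic above reduces to two min/max expressions; a small sketch with illustrative names:

#include <algorithm>

// Prefer splitting over channel blocks (independent work), then spread any
// leftover threads over the minibatch, which later requires a weights reduction.
static void balance_dw_bwd_w(int nthreads, int nb_ch, int mb,
        int &nthr_g, int &nthr_mb, int &nthr) {
    nthr_g = std::min(nb_ch, nthreads);
    nthr_mb = std::min(std::max(1, nthreads / nthr_g), mb);
    nthr = nthr_g * nthr_mb;  // e.g. nthreads=16, nb_ch=4, mb=8 -> 4 x 4 = 16 threads
}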
yword : zword; + const int vlen = cpu_isa_traits::vlen; + + // dw convolution + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t aux1_reg_input = r10; + reg64_t reg_kernel = r11; + reg64_t aux_reg_kernel = r12; + reg64_t aux1_reg_kernel = r13; + reg64_t reg_output = r14; + reg64_t reg_bias = r15; + reg64_t reg_kh = rax; + reg64_t reg_kw = rbx; + reg64_t iter_kh = rdx; + reg64_t iter_kw = rsi; + reg64_t reg_ur_w = rbp; + reg64_t reg_ch_blocks = aux1_reg_input; + reg64_t imm_addr64 = aux1_reg_input; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + inline void load_src(int ur_ch_blocks, int ur_w); + inline void apply_filter(int ur_ch_blocks, int ur_w); + inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w); + inline void apply_activation(int ur_ch_blocks, int ur_w); + inline void store_dst(int ur_ch_blocks, int ur_w); + inline void loop_body(int ur_ch_blocks); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate(); +}; + +template +struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) + + jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t aux1_reg_ddst = abi_not_param1; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t aux1_reg_kernel = rbp; + reg64_t reg_dsrc = rsi; + + reg64_t reg_ur_str_w = r9; + reg64_t reg_ch_blocks = rbx; + + reg64_t iter_kh = r11; + reg64_t iter_kw = r12; + reg64_t reg_kh = r13; + reg64_t reg_kw = r14; + + inline void loop_body(int ur_ch_blocks); + inline void load_ddst(int ur_ch_blocks, int ur_str_w); + inline void apply_filter(int ur_ch_blocks, int ur_str_w); + inline void store_dsrc(int ur_ch_blocks, int ur_str_w); + + void generate(); +}; + +template +struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32) + + jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + static void balance(jit_conv_conf_t &jcp, int nthreads); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_dw_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const 
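For reference, the quantity the forward kernel declared above accumulates per output element is a per-channel k x k dot product: depthwise means every channel is convolved only with its own filter slice (groups == channels, as init_conf enforces), plus an optional per-channel bias. A plain scalar sketch with hypothetical pointer/stride arguments:

// One output element of a depthwise convolution: output channel c depends only
// on input channel c and the c-th kh x kw filter slice.
static float dw_conv_point_ref(const float *src_row0, const float *wei,
        float bias, int src_row_stride, int kh, int kw) {
    float acc = bias;
    for (int r = 0; r < kh; ++r)
        for (int c = 0; c < kw; ++c)
            acc += src_row0[r * src_row_stride + c] * wei[r * kw + c];
    return acc;
}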
Xbyak::Reg64; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int reg_repeats = (isa == sse42) ? 2 : 1; + + const Xbyak::AddressFrame &vmmword + = (isa == sse42) ? xword : (isa == avx2) ? yword : zword; + + /* XXX: offset between input and accummulators is 3, therefore, assume 'kw' + * is no larger than 3*/ + inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); } + inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); } + inline Vmm get_aux_reg() { return Vmm(0); } + + reg64_t reg_tmp_input = r9; + reg64_t reg_tmp_output = r10; + reg64_t reg_tmp_filter = r13; + reg64_t reg_kh_offset = rax; + + /* parameter passed by driver into kernel */ + Xbyak::Reg8 reg_exec_flags = bl; + + reg64_t reg_oh_worksize = r14; + reg64_t reg_oh = rax; + + reg64_t iter_ow_blk = r11; + + reg64_t reg_kh = rsi; + reg64_t reg_kh_count = rdx; + + /* Base addresses for convolution parameters. */ + reg64_t reg_input_baddr = r15; + reg64_t reg_output_baddr = r12; + reg64_t reg_filter_baddr = abi_not_param1; + reg64_t reg_bias_baddr = r13; + + /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs + */ + inline void compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ + inline void compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block); + inline void compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* Write 'width' micro-kernel JITs; depending on the padding and convolution + * size, write a micro-kernel for the left ow-block, middle ow-block(s), and + * right ow-block.*/ + inline void compute_ow_block_unroll(); + + inline void compute_zero_filter(); + inline void load_filter(); + inline void zero_filter(); + inline void load_bias(); + inline void zero_bias(); + inline void compute_bias_step_unroll(const int unroll_w); + inline void compute_bias_loop(const int block_size); + inline void store_filter(); + inline void store_bias(); + + void generate(); +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp new file mode 100644 index 000000000000..58449601a343 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp @@ -0,0 +1,427 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "jit_uni_dw_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template +void _jit_uni_dw_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + if (pd()->wants_padded_bias()) { + auto padded_bias = this->scratchpad(ctx).template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + int dil_h = jcp.dilate_h + 1; + int dil_w = jcp.dilate_w + 1; + int str_h = jcp.stride_h; + int str_w = jcp.stride_w; + + auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh, + int kh_padding, int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w)); + const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w + + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw; + + const int iw = nstl::max((ow*str_w - jcp.l_pad + + div_up(i_l_overflow, dil_w)*dil_w), 0); + const int kw = div_up(i_l_overflow, dil_w); + + const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w) + - div_up(i_r_overflow, dil_w); + + par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)]; + + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)]; + if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)]; + + par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); + par_conv.kw_padding = (size_t)nstl::max(0, kw_padding); + + par_conv.ur_w = (size_t)ur_w_step; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.oh, + [&](int n, int chb, int oh) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h)); + const int i_b_overflow = nstl::max(jcp.ih, + (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih; + + const int ih = nstl::max((int)(oh*str_h - jcp.t_pad + + div_up(i_t_overflow, dil_h)*dil_h), 0); + const int kh = div_up(i_t_overflow, dil_h); + const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) + - div_up(i_b_overflow, dil_h); + + // left border + int ow = 0; + int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow); + int ur_w_step = 1; + for (; ow < l_border; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1) + / jcp.stride_w - ow + 1; + if (ur_w_step > 0) { + 
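The l/r-overflow arithmetic inside kernel_params() above is the subtle part of the forward driver; a hedged scalar sketch of the same clipping for one output column (div_up, the struct, and the parameter names are ad hoc; dil_w stands for dilate_w + 1 as in the code):

#include <algorithm>

static inline int div_up(int a, int b) { return (a + b - 1) / b; }

struct kw_window { int iw_start, kw_start, kw_len; };

// How many kernel taps survive the left/right borders for output column ow,
// and where the corresponding input and filter reads start.
static kw_window clip_kw(int ow, int iw, int kw, int str_w, int dil_w, int l_pad) {
    const int i_l_overflow = std::max(0, l_pad - ow * str_w);
    const int i_r_overflow = std::max(iw,
            ow * str_w + (kw - 1) * dil_w - l_pad + 1) - iw;
    kw_window w;
    w.iw_start = std::max(ow * str_w - l_pad + div_up(i_l_overflow, dil_w) * dil_w, 0);
    w.kw_start = div_up(i_l_overflow, dil_w);
    w.kw_len = kw - div_up(i_l_overflow, dil_w) - div_up(i_r_overflow, dil_w);
    return w;
}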
jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + ow += ur_w_step; + } + + // right border + ur_w_step = 1; + for (; ow < jcp.ow; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template struct _jit_uni_dw_convolution_fwd_t; +template struct _jit_uni_dw_convolution_fwd_t; +template struct _jit_uni_dw_convolution_fwd_t; + +template +void _jit_uni_dw_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, + int i_t_overflow, int i_b_overflow, int stride_off_h, + int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); + const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) + - jcp.r_pad)); + + int ow = iw + jcp.l_pad - i_r_overflow; + int stride_off_w = ow % jcp.stride_w; + ow /= jcp.stride_w; + + par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)]; + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow + + stride_off_h, i_r_overflow + stride_off_w)]; + + par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow + - stride_off_h); + par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow + - stride_off_w); + + par_conv.ur_str_w = ur_str_w; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.ih, + [&](int n, int chb, int ih) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih + - jcp.t_pad)); + const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1 + - (jcp.ih - 1 - ih) - jcp.b_pad)); + + int oh = ih + jcp.t_pad - i_b_overflow; + int stride_off_h = oh % jcp.stride_h; + oh /= jcp.stride_h; + + for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { + // left border + int iw = i_str_w; + int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); + int ur_str_w = 1; + for (; iw < l_border; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw) + / jcp.stride_w, jcp.iw); + if (ur_str_w > 0) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + iw += ur_str_w * jcp.stride_w; + } + + // right border + ur_str_w = 1; + for (; iw < jcp.iw; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + 
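Likewise for the backward-data driver above, the iw -> ow mapping with its stride phase can be written as scalar code (struct and names are illustrative):

#include <algorithm>

struct bwd_d_window { int ow, kw_off, kw_len; };

// For input column iw, find the contributing output column, the filter offset
// selected by the stride phase, and the number of surviving kernel taps.
static bwd_d_window map_iw(int iw, int iw_total, int kw, int stride_w,
        int l_pad, int r_pad) {
    const int i_l_overflow = std::max(0, kw - 1 - iw - l_pad);
    const int i_r_overflow = std::max(0, kw - 1 - (iw_total - 1 - iw) - r_pad);
    int ow = iw + l_pad - i_r_overflow;
    const int stride_off_w = ow % stride_w;
    ow /= stride_w;
    return { ow, i_r_overflow + stride_off_w,
             std::max(0, kw - i_l_overflow - i_r_overflow - stride_off_w) };
}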
kernel_->jit_ker(&par_conv); + } + } + }); +} + +template struct _jit_uni_dw_convolution_bwd_data_t; +template struct _jit_uni_dw_convolution_bwd_data_t; +template struct _jit_uni_dw_convolution_bwd_data_t; + +template +_jit_uni_dw_convolution_bwd_weights_t:: +_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr) +{ + kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32(pd()->jcp_); + if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction()) + acc_ker_ = new cpu_accumulator_1d_t(); +} + +template +void _jit_uni_dw_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto diff_wei_reduction_buf = + scratchpad(ctx).template get(key_conv_wei_reduction); + auto diff_bia_reduction_buf = + scratchpad(ctx).template get(key_conv_bia_reduction); + + const auto &jcp = kernel_->jcp; + + /* Used when executing a parallel reduction */ + simple_barrier::ctx_t reduction_bctx; + simple_barrier::ctx_init(&reduction_bctx); + + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; + + const int ch_block = jcp.ch_block; + + auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params, + const int batch, const int group, const int oh_start, + const int work_size, const unsigned char exec_flag, + const size_t kh_padding, const size_t filter_off) { + const int tpad_underflow_off = jcp.t_pad - filter_off; + + conv_params->exec_flags = exec_flag; + conv_params->kh_count = jcp.kh - kh_padding; + + const int oh_s = oh_start; + const int oh_e = oh_start + work_size; + const int ih_s = oh_s * jcp.stride_h; + + conv_params->filter_pad_off + = filter_off * jcp.kw * ch_block * sizeof(float); + conv_params->oh_index = oh_s; + conv_params->oh_count = oh_e; + + size_t diff_dst_off + = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh + + oh_start) + * jcp.ow; + + size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih + + ih_s - tpad_underflow_off) * jcp.iw; + + conv_params->output = &diff_dst[diff_dst_off * ch_block]; + conv_params->input = &src[src_off * ch_block]; + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + + auto conv_params = jit_dw_conv_call_s(); + const int h_block_size = 15; + + /* assign iteration space to thread */ + const int ithr_g = ithr % jcp.nthr_g; + const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; + + /* split dimensions */ + int g_start{ 0 }, g_end{ 0 }; + balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end); + + int mb_start{ 0 }, mb_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); + + auto diff_wei = ithr_mb == 0 + ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size; + auto diff_bia = ithr_mb == 0 + ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size; + + for (int g = g_start; g < g_end; ++g) { + unsigned char zero_filter_flag = FLAG_ZERO_FILTER; + unsigned char zero_bias_flag = jcp.with_bias ? 
FLAG_ZERO_BIAS : 0; + + size_t diff_wei_off = g * jcp.kh * jcp.kw; + conv_params.filter = &diff_wei[diff_wei_off * ch_block]; + + if (jcp.with_bias) + conv_params.bias = &diff_bia[g * ch_block]; + + for (int mb = mb_start; mb < mb_end; ++mb) { + int oh = 0; + while (oh < jcp.oh) { + const int h_work = nstl::min(h_block_size, jcp.oh - oh); + auto kh_t_padding = nstl::max(0, jcp.t_pad - oh); + auto kh_b_padding + = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ? + jcp.b_pad - (h_work - 1) : + 0; + + set_kernel_params(&conv_params, mb, g, oh, h_work, + zero_filter_flag | zero_bias_flag, + kh_t_padding + kh_b_padding, kh_t_padding); + kernel_->jit_ker(&conv_params); + + zero_bias_flag &= ~FLAG_ZERO_BIAS; + zero_filter_flag &= ~FLAG_ZERO_FILTER; + oh += h_work; + } + } + } + + if (do_parallel_reduction() && jcp.nthr_mb > 1) { + size_t reduct_start{ 0 }, reduct_end{ 0 }; + balance211(wei_size, nthr, ithr, reduct_start, reduct_end); + + const int acc_size = reduct_end - reduct_start; + const size_t reduct_off = reduct_start; + auto *acc_data = diff_weights + reduct_off; + + simple_barrier::barrier(&reduction_bctx, nthr); + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + auto *src_data = diff_wei_reduction_buf + + (thr_mb - 1) * wei_size + reduct_off; + acc_ker_->accumulate(acc_data, src_data, acc_size); + } + } + }); + + if (jcp.nthr_mb <= 1) return; + + /* Apply single-threaded 'mb' reduction */ + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + size_t mb_accum_offset = (thr_mb - 1) * wei_size; + size_t b_accum_offset = (thr_mb - 1) * bias_size; + + for (int g = 0; g < jcp.nb_ch; ++g) { + /* Reduction on Bias */ + if (jcp.with_bias) { + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + size_t bias_offset = g * ch_block + g_block; + diff_bias[bias_offset] += diff_bia_reduction_buf[ + b_accum_offset + bias_offset]; + } + } + + if (do_parallel_reduction()) continue; + + for (int kh = 0; kh < jcp.kh; ++kh) + for (int kw = 0; kw < jcp.kw; ++kw) + { + size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw; + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + const size_t off = wei_offset * ch_block + g_block; + diff_weights[off] += + diff_wei_reduction_buf[mb_accum_offset + off]; + } + } + } + } +} + +template struct _jit_uni_dw_convolution_bwd_weights_t; +template struct _jit_uni_dw_convolution_bwd_weights_t; +template struct _jit_uni_dw_convolution_bwd_weights_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp new file mode 100644 index 000000000000..ca53749ec2e7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
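The serial tail of the function above folds every extra minibatch-thread's private copy back into the real diff_weights. Stripped of the per-group/per-tap indexing (which together covers exactly wei_size = ngroups * kh * kw elements), it is a plain strided accumulation; a sketch:

#include <cstddef>

// Sum the (nthr_mb - 1) private weight copies held in the scratch buffer into
// the primary diff_weights array.
static void reduce_mb_copies(float *diff_weights, const float *reduction_buf,
        size_t wei_size, int nthr_mb) {
    for (int thr = 1; thr < nthr_mb; ++thr) {
        const float *src = reduction_buf + (size_t)(thr - 1) * wei_size;
        for (size_t i = 0; i < wei_size; ++i)
            diff_weights[i] += src[i];
    }
}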
+*******************************************************************************/ + +#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP +#define CPU_JIT_UNI_DW_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_fwd_kernel_f32::init_conf( + jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32(pd()->jcp_); } + + ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_fwd_kernel_f32 *kernel_; +}; + +using jit_avx512_common_dw_convolution_fwd_t = + _jit_uni_dw_convolution_fwd_t; +using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; +using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; + +template +struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_bwd_data_kernel_f32:: + init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + 
*diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_data_kernel_f32 *kernel_; +}; + +using jit_avx512_common_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; +using jit_avx2_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; +using jit_sse42_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; + +template +struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const int max_threads = mkldnn_in_parallel() + ? 1 : mkldnn_get_max_threads(); + + status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md(), max_threads); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? 
Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); + ~_jit_uni_dw_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + bool do_parallel_reduction() const { return false; } + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_weights_kernel_f32 *kernel_; + cpu_accumulator_1d_t *acc_ker_; +}; + +using jit_avx512_common_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; +using jit_avx2_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; +using jit_sse42_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp new file mode 100644 index 000000000000..2af643587183 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp @@ -0,0 +1,1142 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_eltwise.hpp" + +#define GET_OFF(field) offsetof(jit_args, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template +void jit_uni_eltwise_injector_f32::injector_preamble(size_t start_idx, + size_t end_idx) { + preserved_vecs_count = 0; + vecs_to_preserve = (size_t)aux_vecs_count(alg_); + start_idx_tail = start_idx; + + // For sse42 mask register has to be Xmm(0) + if (isa == sse42 && vecs_to_preserve > 0) { + size_t idx = 0; + assert(idx < start_idx); + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) { + if (preserved_vecs_count >= vecs_to_preserve) break; + if (start_idx <= idx && idx < end_idx) continue; + + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; + for (size_t i = 0; i < preserved_vecs_count_tail; i++) { + preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++; + } + + assert(preserved_vecs_count == vecs_to_preserve); + + if (save_state_) { + h->push(p_table); + + if (preserved_vecs_count) + h->sub(h->rsp, preserved_vecs_count * vlen); + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[i])); + + load_table_addr(); + } + + assign_regs(); +} + +template +void jit_uni_eltwise_injector_f32::injector_preamble_tail(size_t start_idx) +{ + size_t tail_vecs_to_preserve = start_idx_tail - start_idx; + if (tail_vecs_to_preserve == 0) return; + + const int idx_off = vecs_to_preserve - tail_vecs_to_preserve; + + if (save_state_) { + if (idx_off) + h->add(h->rsp, idx_off * vlen); + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), + h->ptr[h->rsp + i * vlen]); + } + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; + + if (save_state_) { + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[idx_off + i])); + + if (idx_off) + h->sub(h->rsp, idx_off * vlen); + } + + assign_regs(); +} + +template +void jit_uni_eltwise_injector_f32::injector_postamble() { + if (!save_state_) return; + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[i]), + h->ptr[h->rsp + i * vlen]); + + if (preserved_vecs_count) + h->add(h->rsp, preserved_vecs_count * vlen); + + h->pop(p_table); +} + +template +void jit_uni_eltwise_injector_f32::assign_regs() { + vmm_mask = Vmm(preserved_vec_idxs[0]); + vmm_aux0 = Vmm(preserved_vec_idxs[0]); + vmm_aux1 = Vmm(preserved_vec_idxs[1]); + vmm_aux2 = Vmm(preserved_vec_idxs[2]); + vmm_aux3 = Vmm(preserved_vec_idxs[3]); + vmm_aux4 = Vmm(preserved_vec_idxs[4]); +} + +template +void jit_uni_eltwise_injector_f32::exp_compute_vector(const Vmm &vmm_src) { + h->uni_vminps(vmm_src, vmm_src, table_val(10)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(11)); + h->uni_vmovups(vmm_aux0, vmm_src); + //calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux1, vmm_aux1); + + 
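The index selection in injector_preamble() above picks scratch vector registers from outside the live [start_idx, end_idx) range, reserving Xmm(0) first on SSE4.2 (blendvps uses it as an implicit mask); a small sketch of just that selection rule, with ad hoc types:

#include <cstddef>
#include <vector>

// Choose `need` scratch vector-register indices while avoiding the live range
// [start_idx, end_idx); any shortfall spills into the head of the live range
// (those are the ones handled by the tail pass). Illustrative only.
static std::vector<size_t> pick_aux_vecs(size_t need, size_t vecs_count,
        size_t start_idx, size_t end_idx, bool is_sse42) {
    std::vector<size_t> picked;
    if (is_sse42 && need > 0)
        picked.push_back(0);  // Xmm(0): implicit blendvps mask register
    for (size_t idx = picked.size(); idx < vecs_count && picked.size() < need; ++idx) {
        if (idx >= start_idx && idx < end_idx) continue;  // skip live registers
        picked.push_back(idx);
    }
    for (size_t idx = start_idx; picked.size() < need; ++idx)
        picked.push_back(idx);
    return picked;
}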
h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux1, vmm_src, _op_floor); + } + + //keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx + + //x = x - fx * ln2 + h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3)); + + // compute 2^n + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx + + // y = p5 + h->uni_vmovups(vmm_src, table_val(9)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q) + // y = y * 2^n + h->uni_vmulps(vmm_src, vmm_src, vmm_aux1); +} + +template +void jit_uni_eltwise_injector_f32::relu_compute_vector(const Vmm &vmm_src) +{ + const int alpha_off = 0, zero_off = 1; + + h->uni_vmovups(vmm_aux1, vmm_src); + if (isa == sse42) { + h->movups(vmm_mask, vmm_src); + h->mulps(vmm_src, table_val(alpha_off)); + h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us); + h->blendvps(vmm_src, vmm_aux1); + } else if (isa == avx2) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off)); + h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } else if (isa == avx512_common) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } +} + +template +void jit_uni_eltwise_injector_f32::relu_zero_ns_compute_vector( + const Vmm &vmm_src) { + const int zero_off = 1; + h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off)); +} + +template +void jit_uni_eltwise_injector_f32::elu_compute_vector(const Vmm &vmm_src) { + const int alpha_off = 23, zero_off = 24; + + // compute exponent + h->uni_vmovups(vmm_aux2, vmm_src); + exp_compute_vector(vmm_src); + + // alpha * (exp(x) - 1) + h->uni_vsubps(vmm_src, vmm_src, table_val(0)); + h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off)); + + // combine with mask + if (isa == sse42) { + h->pxor(vmm_mask, vmm_mask); + h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os); + h->blendvps(vmm_src, vmm_aux2); + } else if (isa == avx2) { + h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask); + } else if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2); + } +} + +template +void jit_uni_eltwise_injector_f32::tanh_compute_vector(const Vmm &vmm_src) +{ + // # comes from Taylor expansion error bound + // > linear_sat_point = single(sqrt(3) * 1b-12); + // # comes from the exp formula cancellation + // > exp_bound_point = (single(log(3)/2)); + // # comes from rounding accuracy in float + // > one_sat_point = round(atanh(1 - 1b-25), single, RU); + // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... 
|], + // [linear_sat_point, exp_bound_point], relative, floating); + // > err_bound = D(sup(supnorm(P, tanh(x), + // [linear_sat_point, exp_bound_point], relative, theta))); + // 0x1.fffd6f00b9539p-25 + // > P; + // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 * + // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5 + // + x^0x1p1 * 0x1.09fa1p-6)))) + + // register mapping + // vmm_src contains input + // vmm_aux0 contains mask of currently valid results. + // 1 is need computation, 0 is already computed + // vmm_aux1 contains current output + // vmm_aux2, vmm_aux3 contains auxiliary values + // vmm_aux4 contains the original sign of inputs + + Label end_tanh_label; + + auto test_exit =[&](Xbyak::Address threshold){ + // is not necessary for >AVX, but should not matter on perf + h->uni_vmovups(vmm_aux0, vmm_src); + if (isa == avx512_common){ + h->vcmpps(k_mask, vmm_aux0, threshold, 0x5); + h->kortestw(k_mask, k_mask); + } else { + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold); + h->uni_vtestps(vmm_aux0, vmm_aux0); + } + h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR); + }; + + auto blend_results=[&](Vmm vmm_partial_res){ + if (isa == avx512_common) + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res); + else + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0); + }; + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + // mov is not necessary for >AVX, but should not matter for performance + h->uni_vmovups(vmm_aux4, vmm_src); + h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12)); + h->uni_vandps(vmm_src, vmm_src, table_val(17)); + + // if x < linear_sat_point for all inputs, we just return the input + h->uni_vmovups(vmm_aux1, vmm_src); + test_exit(table_val(13)); + + // if one of the mask is one, we have to compute an better approx + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2); + h->uni_vmovups(vmm_aux3, table_val(22)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18)); + h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src); + + // we blend only the result that need update + blend_results(vmm_aux3); + + // if x < exp_bound_point, we go to return point + test_exit(table_val(14)); + + // if not we use a better approx 1 - 2 / (1 + exp(2x)) + // compute 2x + h->uni_vmovups(vmm_aux3, vmm_src); + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3); + + // Compute exp(2x) + // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them + // vmm_src is not more read afterwards, so we do not have to save it + auto stack_size = 3 * vlen + (isa == avx512_common) * 4; + h->sub(h->rsp, stack_size); + h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0); + h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1); + h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src); + if (isa == avx512_common) + h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask); + + exp_compute_vector(vmm_aux3); + + h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]); + h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]); + h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]); + if (isa == avx512_common) + h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]); + h->add(h->rsp, stack_size); + + // 1 + exp(2x) + h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0)); + + // 1 - 2 / (1 + exp(2x)) + h->uni_vmovups(vmm_aux2, table_val(16)); + h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3); + 
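The tanh kernel above splits |x| into four regions: identity, an odd polynomial, the formula 1 - 2/(1 + exp(2x)), and saturation to 1, then restores the sign at the end. A scalar restatement, reusing the hex-float coefficients from the comment above (the bounds come from the same comment):

#include <cmath>

static float tanh_ref(float x) {
    const float sign = x < 0.0f ? -1.0f : 1.0f;
    const float ax = std::fabs(x);
    const float linear_bound = std::sqrt(3.0f) * 0x1p-12f; // tanh(x) ~ x below this
    const float poly_bound = std::log(3.0f) / 2.0f;        // polynomial valid below this
    const float one_bound = 9.010913f;                      // tanh(x) ~ 1 above this
    float r;
    if (ax < linear_bound) {
        r = ax;
    } else if (ax < poly_bound) {
        const float x2 = ax * ax;
        r = ax * (0x1.fffffep-1f + x2 * (-0x1.55539ep-2f + x2 * (0x1.10be3ep-3f
                + x2 * (-0x1.ae57b4p-5f + x2 * 0x1.09fa1p-6f))));
    } else if (ax < one_bound) {
        r = 1.0f - 2.0f / (1.0f + std::exp(2.0f * ax));  // reuses the exp scheme above
    } else {
        r = 1.0f;
    }
    return sign * r;
}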
h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0)); + + // we blend only the result that need update + blend_results(vmm_aux2); + + // finally, we saturate to 1 if needed + // TODO: maybe move that up if most inputs saturate in practice + if (isa == avx512_common) + h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5); + else { + h->uni_vmovups(vmm_aux0, vmm_src); + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15)); + } + h->uni_vmovups(vmm_aux2, table_val(0)); + blend_results(vmm_aux2); + + h->L(end_tanh_label); + { + // we apply the sign of x to the result and we are done + h->uni_vmovups(vmm_src, vmm_aux1); + h->uni_vpxor(vmm_src, vmm_src, vmm_aux4); + } +} + +template +void jit_uni_eltwise_injector_f32::square_compute_vector( + const Vmm &vmm_src) { + h->uni_vmulps(vmm_src, vmm_src, vmm_src); +} + +template +void jit_uni_eltwise_injector_f32::abs_compute_vector(const Vmm &vmm_src) { + // compute abs(x) = _mm_and_ps(x, 01111..111)); + h->uni_vandps(vmm_src, vmm_src, table_val(0)); +} + +template +void jit_uni_eltwise_injector_f32::sqrt_compute_vector(const Vmm &vmm_src) +{ + if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } else { + h->uni_vmovups(vmm_mask, vmm_src); + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0)); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } +} + +template +void jit_uni_eltwise_injector_f32::linear_compute_vector( + const Vmm &vmm_src) { + // compute x = alpha * x + beta; + h->uni_vmovups(vmm_aux0, table_val(0)); + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1)); +} + +template +void jit_uni_eltwise_injector_f32::bounded_relu_compute_vector( + const Vmm &vmm_src) { + // compute bounded relu */ + h->uni_vmaxps(vmm_src, vmm_src, table_val(1)); + h->uni_vminps(vmm_src, vmm_src, table_val(0)); +} + +template +void jit_uni_eltwise_injector_f32::soft_relu_compute_vector( + const Vmm &vmm_src) { + // duplicate src + h->uni_vmovups(vmm_aux2, vmm_src); + + h->uni_vminps(vmm_src, vmm_src, table_val(24)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(25)); + h->uni_vmovups(vmm_aux1, vmm_src); + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux0, vmm_aux0); + + h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); + } + + // keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx + // calculation fx * ln2 + h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3)); + // x = x - fx * ln2 + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); + // y = p5 + h->uni_vmovups(vmm_aux3, table_val(22)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17)); + + // compute 2^(-n) + if (isa == avx512_common) { + 
h->vmulps(vmm_aux1, vmm_src, table_val(23)); + h->vcvtps2dq(vmm_aux1, vmm_aux1); + } else { + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23)); + } + + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx + // calculate ln(1 + y) + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); + // x = y; y is free; keep x for further computations + h->uni_vmovups(vmm_src, vmm_aux3); + // frexp() + h->uni_vpsrld(vmm_src, vmm_src, 23); + h->uni_vcvtdq2ps(vmm_src, vmm_src); + // got n. where n is x = 2^n * y. y = 0.5 .. 1 + h->uni_vsubps(vmm_src, vmm_src, table_val(5)); + + h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6)); + // got y. (mantisa) 0.5 < y < 1 + h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7)); + // y = y - 1 + h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0)); + // y = p8 + h->uni_vmovups(vmm_aux1, table_val(16)); + // y = y * x + p7 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15)); + // y = y * x + p6 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14)); + // y = y * x + p5 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9)); + // y = y * x + p0 ; p0 = 0 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8)); + //calculate ln(2) * n + h->uni_vmulps(vmm_src, vmm_src, table_val(3)); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0); + + // get vmm_mask = src > max logf + h->uni_vmovups(vmm_mask, vmm_aux2); + if (isa == avx512_common) { + // y = (x < max log f) ? soft_relu(x) : x + h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us); + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2); + } else { + // y = (x < max log f) ? soft_relu(x) : x + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24)); + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask); + } + + h->uni_vmovups(vmm_src, vmm_aux1); +} + +template +void jit_uni_eltwise_injector_f32::logistic_compute_vector( + const Vmm &vmm_src) { + // we store the original sign and make x negative + // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required + // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it. 
+ h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12)); + h->uni_vorps(vmm_src, vmm_src, table_val(12)); + + exp_compute_vector(vmm_src); + // dup exp(x) + h->uni_vmovups(vmm_aux1, vmm_src); + // (exp(x) + 1) + h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0)); + // y = exp(x) / (exp(x) + 1) + h->uni_vdivps(vmm_src, vmm_src, vmm_aux1); + + // Now we have to apply the "symmetry" based on original sign + h->uni_vmovups(vmm_aux3, table_val(0)); + h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); + if (isa == avx512_common) { + h->vptestmd(k_mask, vmm_aux2, vmm_aux2); + h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src); + } else { + h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2 + h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0); + } + h->uni_vmovups(vmm_src, vmm_aux3); +} + +template +void jit_uni_eltwise_injector_f32::relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::elu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + // exp(x) polynom + 0x3f800001, // [5] p0 = 1.0000001f + 0x3efffe85, // [6] p2 = 0.4999887f + 0x3e2aaa3e, // [7] p3 = 0.16666505f + 0x3d2bb1b1, // [8] p4 = 0.041917507f + 0x3c091ec1, // [9] p5 = 0.008369149f + 0x42b0c0a5, //[10] max logf = 88.3762589f + 0xc1766666, //[11] min logf = -14.5f + // tanh(x) constants, + 0x80000000, //[12] mask to extract sign + 0x39ddb3d7, //[13] arg below which tanh(x) = x + 0x3f0c9f54, //[14] arg below which pol approx is valid + 0x41102cb4, //[15] arg after which tanh(x) = 1 + 0xc0000000, //[16] -2.0f + 0x7fffffff, //[17] mask to make positive + // tanh pol approx + 0x3f7fffff, //[18] p0 + 0xbeaaa9cf, //[19] p1 + 0x3e085f1f, //[20] p2 + 0xbd572bda, //[21] p3 + 0x3c84fd08, //[22] p4 + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]); + } + + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::soft_relu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + 0x42fc0000, // [5] 126 + 0x807fffff, // [6] and with (to get 0.5 * mantissa) + 0x3f000000, // [7] or with (to get 0.5 * mantissa) + // ln(1 + x) polynomial + 0xb2b4637d, // [8] p0 = 0.0000000244f + 0x3f7fff8e, // [9] p1 = 0.9999976971f + 0xbf001759, //[10] p2 = -0.5002478215f + 0x3ea70608, //[11] p3 = 0.3272714505f + 0xbea3d7bf, //[12] p4 = -0.3153830071f + 0xbe361d04, //[13] p5 = -0.1701777461f + 0xbfa8f1e6, //[14] p6 = -1.3254635147f + 0xbfe1e812, //[15] p7 = -1.7971917960f + 0xbfc4d30e, //[16] p8 = -1.5652673123f + // exp(x) polynomial + 0x3f800001, //[17] p0 = 1.0000001f + 0x3f800000, //[18] p1 = 1.0f + 0x3efffe85, //[19] p2 = 0.4999887f + 0x3e2aaa3e, //[20] p3 = 0.16666505f + 0x3d2bb1b1, //[21] p4 = 0.041917507f + 0x3c091ec1, //[22] p5 = 0.008369149f + 0xbf800000, //[23] is required for sign changing + 0x42b0c0a5, //[24] max logf = 88.3762589f + 0xc1766666 //[25] min logf = -14.5f + }; + + for (size_t i = 0; i < sizeof(cvals) / 
sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) { + h->dd(cvals[i]); + } + } +} + +template +void jit_uni_eltwise_injector_f32::abs_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff); +} + +template +void jit_uni_eltwise_injector_f32::sqrt_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::linear_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_)); +} + +template +void jit_uni_eltwise_injector_f32::bounded_relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +int jit_uni_eltwise_injector_f32::aux_vecs_count(alg_kind_t alg_) { + switch (alg_) { + case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2; + case alg_kind::eltwise_elu: return 4; + case alg_kind::eltwise_tanh: return 5; + case alg_kind::eltwise_square: return 0; + case alg_kind::eltwise_abs: return 0; + case alg_kind::eltwise_sqrt: return 2; + case alg_kind::eltwise_linear: return 1; + case alg_kind::eltwise_bounded_relu: return 0; + case alg_kind::eltwise_soft_relu: return 4; + case alg_kind::eltwise_logistic: return 4; + default: assert(!"unsupported eltwise algorithm"); + } + + return 0; +} + +template +void jit_uni_eltwise_injector_f32::compute_body(size_t start_idx, + size_t end_idx) { + using namespace alg_kind; + for (size_t idx = start_idx; idx < end_idx; idx++) { + switch (alg_) { + case eltwise_relu: + if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx)); + else relu_compute_vector(Vmm(idx)); + break; + case eltwise_elu: elu_compute_vector(Vmm(idx)); break; + case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break; + case eltwise_square: square_compute_vector(Vmm(idx)); break; + case eltwise_abs: abs_compute_vector(Vmm(idx)); break; + case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break; + case eltwise_linear: linear_compute_vector(Vmm(idx)); break; + case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break; + case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break; + case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template +void jit_uni_eltwise_injector_f32::compute_vector_range(size_t start_idx, + size_t end_idx) { + assert(start_idx < end_idx && end_idx <= vecs_count); + + injector_preamble(start_idx, end_idx); + compute_body(start_idx_tail, end_idx); + injector_preamble_tail(start_idx); + compute_body(start_idx, start_idx_tail); + injector_postamble(); +} + +template +void jit_uni_eltwise_injector_f32::prepare_table(bool gen_table) { + using namespace alg_kind; + + h->align(64); + h->L(l_table); + + if (gen_table) { + switch (alg_) { + case eltwise_relu: relu_prepare_table(); break; + case eltwise_elu: + case eltwise_tanh: + case eltwise_logistic: + elu_prepare_table(); break; + case eltwise_soft_relu: soft_relu_prepare_table(); break; + case eltwise_abs: abs_prepare_table(); break; + case eltwise_sqrt: sqrt_prepare_table(); break; + case eltwise_linear: linear_prepare_table(); break; + case eltwise_bounded_relu: bounded_relu_prepare_table(); break; + case eltwise_square: break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template struct jit_uni_eltwise_injector_f32; +template struct 
jit_uni_eltwise_injector_f32; +template struct jit_uni_eltwise_injector_f32; + + +struct jit_args { + const float *from; + const float *for_comparison; + const float *to; + size_t work_amount; +}; + +struct jit_uni_eltwise_kernel_f32 : public c_compatible { + const eltwise_desc_t &desc_; + + void (*ker_)(const jit_args *); + void operator()(const jit_args *args) { assert(ker_); ker_(args); } + + jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc) + : desc_(desc), ker_(nullptr) {} + virtual ~jit_uni_eltwise_kernel_f32() {} + +protected: + bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; } +}; + +/* jit kernels */ +namespace { + +template +struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32) + + void compute_step(bool vectorize, const int uf, const int shift) { + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + uni_vmovups(Vmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } else { + movss(Xmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + movss(Xmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } + } + + if (isa == sse42) { + for (int i = 0; i < uf; i++) { + movups(Vmm(2 * uf + i + 1), Vmm(i + 1)); + mulps(Vmm(2 * uf + i + 1), vmm_ns); + + Vmm mask = Vmm(0); + if (is_bwd()) { + movups(mask, Vmm(uf + i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } else { + movups(mask, Vmm(i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } + blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1)); + } + } else { + for (int i = 0; i < uf; i++) { + vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns); + if (isa == avx2) { + if (is_bwd()) + vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero); + else + vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero); + + vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1), + Vmm(i + 1), vmm_mask); + + } else { + if (is_bwd()) + vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us); + else + vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us); + vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1), + Vmm(i + 1)); + } + } + } + + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1)); + } else { + movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1)); + } + } + } + + jit_uni_relu_kernel_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + assert(desc.alg_kind == alg_kind::eltwise_relu); + assert(isa == sse42 || isa == avx2 || isa == avx512_common); + + Reg64 param = abi_param1; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int loop_dec[] = {simd_w, 1}; + const int uf[] = {1, 1}; + const int shift[] = {cpu_isa_traits::vlen, sizeof(float)}; + const bool loop_vectorize[] = {true, false}; + + this->preamble(); + + mov(reg_from, ptr[param + GET_OFF(from)]); + if (is_bwd()) + mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + + mov(imm_addr64, float2int(desc.alpha)); + movq(xmm_ns, imm_addr64); + uni_vbroadcastss(vmm_ns, xmm_ns); + + uni_vpxor(vmm_zero, vmm_zero, vmm_zero); + + Label loop_label[3]; + + for (int id = 0; id < 2; id++) { + L(loop_label[id]); + cmp(reg_work_amount, uf[id] * loop_dec[id] - 1); + jle(loop_label[id + 1], T_NEAR); + + compute_step(loop_vectorize[id], uf[id], shift[id]); + + add(reg_from, uf[id] * shift[id]); + add(reg_to, uf[id] * shift[id]); + 
if (is_bwd()) + add(reg_for_comparison, uf[id] * shift[id]); + + sub(reg_work_amount, uf[id] * loop_dec[id]); + jmp(loop_label[id]); + } + + L(loop_label[2]); + this->postamble(); + + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using Vmm = typename utils::conditional3::type; + + Reg64 reg_from = rax; + Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_ns = Xmm(14); + + Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14); + Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15); + + Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12); + Opmask k_mask = Opmask(1); +}; + +template +struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32, + public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32) + + jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1)); + + using namespace alg_kind; + + assert(is_bwd() == false); + assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + + preamble(); + + Reg64 param = abi_param1; + mov(reg_from, ptr[param + GET_OFF(from)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + eltwise_injector_->load_table_addr(); + + Label reminder_loop_start, reminder_loop_end; + Label vectorized_loop_start, vectorized_loop_end; + + cmp(reg_work_amount, simd_w); + jl(reminder_loop_start, T_NEAR); + + L(vectorized_loop_start); + + uni_vmovups(vmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(vmm_src.getIdx()); + uni_vmovups(ptr[reg_to], vmm_src); + + add(reg_from, vlen); + add(reg_to, vlen); + + sub(reg_work_amount, simd_w); + cmp(reg_work_amount, simd_w); + jge(vectorized_loop_start, T_NEAR); + + L(vectorized_loop_end); + + L(reminder_loop_start); + + cmp(reg_work_amount, 0); + jle(reminder_loop_end, T_NEAR); + + movss(xmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(xmm_src.getIdx()); + movss(ptr[reg_to], xmm_src); + + add(reg_from, sizeof(float)); + add(reg_to, sizeof(float)); + + dec(reg_work_amount); + jmp(reminder_loop_start, T_NEAR); + + L(reminder_loop_end); + + postamble(); + + eltwise_injector_->prepare_table(); + + ker_ = (decltype(ker_))this->getCode(); + } + + ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; } + +private: + using Vmm = typename utils::conditional3::type; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int vlen = cpu_isa_traits::vlen; + + Reg64 reg_from = rax; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_src = Xmm(1); + Vmm vmm_src = Vmm(1); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; +}; + +} /* namespace */ + +template +status_t jit_uni_eltwise_fwd_t::pd_t::init() { + using namespace alg_kind; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && utils::everyone_is(data_type::f32, desc()->data_desc.data_type) + && !has_zero_dim_memory() + && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh, + eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, + eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, + eltwise_logistic) + && memory_desc_wrapper(src_md()).is_dense(true) + && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false), + 
math::eltwise_fwd_preserves_zero(desc()->alg_kind, true)) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template +jit_uni_eltwise_fwd_t::jit_uni_eltwise_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32(desc); break; + default: + kernel_ = new jit_uni_kernel_fwd_f32(desc); + } +} + +template +jit_uni_eltwise_fwd_t::~jit_uni_eltwise_fwd_t() +{ delete kernel_; } + +template +void jit_uni_eltwise_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const size_t nelems = data_d.nelems(true); + + src += data_d.offset0(); + dst += data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &src[start]; + arg.for_comparison = &src[start]; + arg.to = &dst[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template +status_t jit_uni_eltwise_bwd_t::pd_t::init() { + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu) + && src_md()->data_type == data_type::f32 + && !has_zero_dim_memory() + && mayiuse(isa) + && memory_desc_wrapper(src_md()).is_dense() + && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md()) + && attr()->has_default_values(); + + return ok ? 
status::success : status::unimplemented; +} + +template +jit_uni_eltwise_bwd_t::jit_uni_eltwise_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32(desc); break; + default: assert(!"unknown eltwise alg_kind"); + } +} + +template +jit_uni_eltwise_bwd_t::~jit_uni_eltwise_bwd_t() +{ delete kernel_; } + +template +void jit_uni_eltwise_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const size_t nelems = data_d.nelems(); + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &diff_dst[start]; + arg.to = &diff_src[start]; + arg.for_comparison = &src[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp new file mode 100644 index 000000000000..45436b9f46e4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_JIT_UNI_ELTWISE_HPP +#define CPU_JIT_UNI_ELTWISE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_eltwise_injector_f32 { + using Vmm = typename utils::conditional3::type; + + jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, + float alpha, float beta, bool save_state = true, + Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : alg_(alg), alpha_(alpha), beta_(beta), h(host) + , save_state_(save_state), p_table(p_table), k_mask(k_mask) + { + using namespace alg_kind; + assert(utils::one_of(isa, sse42, avx2, avx512_common)); + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + } + + // note that eltwise.scale is ignored + jit_uni_eltwise_injector_f32(jit_generator *host, + const post_ops_t::entry_t::eltwise_t &eltwise, + bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, + eltwise.beta, save_state, p_table, k_mask) {} + + void compute_vector_range(size_t start_idx, size_t end_idx); + void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } + void prepare_table(bool gen_table=true); + void load_table_addr() { h->mov(p_table, l_table); } + + const alg_kind_t alg_; + const float alpha_; + const float beta_; + + jit_generator * const h; + + const bool save_state_; + const Xbyak::Reg64 p_table; + const Xbyak::Opmask k_mask; + Xbyak::Label l_table; + +private: + // if only the injector was inherited from jit_generator... + enum { + _cmp_le_os = jit_generator::_cmp_le_os, + _cmp_nle_us = jit_generator::_cmp_nle_us, + _op_floor = jit_generator::_op_floor, + }; + + size_t vlen = cpu_isa_traits::vlen; + + const static size_t preserved_vecs_max = 5; + + size_t vecs_to_preserve = 0; + size_t vecs_count = isa == avx512_common ? 
32 : 16; + size_t preserved_vecs_count = 0; + size_t preserved_vec_idxs[preserved_vecs_max] = {0}; + size_t start_idx_tail = 0; + + Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; + + Xbyak::Address table_val(int index) + { return h->ptr[p_table + index * vlen]; } + + int aux_vecs_count(alg_kind_t alg); + + void compute_body(size_t start_idx, size_t end_idx); + void injector_preamble(size_t start_idx, size_t end_idx); + void injector_preamble_tail(size_t start_idx); + void injector_postamble(); + void assign_regs(); + + void exp_compute_vector(const Vmm &vmm_src); + void relu_compute_vector(const Vmm &vmm_src); + void relu_zero_ns_compute_vector(const Vmm &vmm_src); + void elu_compute_vector(const Vmm &vmm_src); + void tanh_compute_vector(const Vmm &vmm_src); + void square_compute_vector(const Vmm &vmm_src); + void abs_compute_vector(const Vmm &vmm_src); + void sqrt_compute_vector(const Vmm &vmm_src); + void linear_compute_vector(const Vmm &vmm_src); + void bounded_relu_compute_vector(const Vmm &vmm_src); + void soft_relu_compute_vector(const Vmm &vmm_src); + void logistic_compute_vector(const Vmm &vmm_src); + + void relu_prepare_table(); + void elu_prepare_table(); + void soft_relu_prepare_table(); + void abs_prepare_table(); + void sqrt_prepare_table(); + void linear_prepare_table(); + void bounded_relu_prepare_table(); +}; + +struct jit_uni_eltwise_kernel_f32; + +template +struct jit_uni_eltwise_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_fwd_t); + + status_t init(); + }; + + jit_uni_eltwise_fwd_t(const pd_t *apd); + ~jit_uni_eltwise_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +template +struct jit_uni_eltwise_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_bwd_t); + + status_t init(); + }; + + jit_uni_eltwise_bwd_t(const pd_t *apd); + ~jit_uni_eltwise_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp new file mode 100644 index 000000000000..a3ca6273a0b1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp @@ -0,0 +1,949 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_uni_i8i8_pooling.hpp" + +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::types; +using namespace alg_kind; + +template +struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t) + + struct call_params_t { + const char *src_i8; + const char *dst_i8; + size_t kw_range; + size_t kh_range; + float idivider; + }; + + using Vmm = typename cpu_isa_traits::Vmm; + Xmm xreg(int idx) const { return Xmm(idx); } + Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); } + + // In case of avx2 with data type i8 we need to use + // maskmovdqu instruction which has its destination hardcoded in rdi. + // Windows ABI: abi_param1 is rcx - nothing to do else + // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1 + Reg64 reg_param = rcx; // Our "unified abi_param1" + Reg64 reg_ptr_src_i8 = r8; + Reg64 reg_ptr_dst_i8 = r9; + Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi + + Reg64 ki = r10; + Reg64 kj = r11; + Reg64 reg_kw = r12; + Reg64 reg_kh = r13; + Reg64 c_iter = r14; + + Reg64 aux_reg_src_h = rax; + Reg64 aux_reg_src_w = rbx; + + Reg64 reg_tmp = rdx; + + Reg64 reg_mask = r15; + + Opmask k_cmp_mask = Opmask(7); + + Opmask mask(int idx) { + return Opmask(6 - idx); + } + + // ref to any of XYZ-regs via xreg/yreg/vreg functions + Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp + Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type + Vmm vreg_zeros = vreg(1); + + // only in case of == avx2 + Vmm vreg_mask = vreg(2); // full byte-mask + Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask) + Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately) + Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations + Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails + + enum:int {vidx_base = isa == avx2 ? 4 : 2}; + Vmm base_vr(int idx) const { return vreg(vidx_base + idx); } + + size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); } + size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); } + + /* max pooling */ + Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1] + Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 
2*ur_c-1] + + /* avg pooling */ + // s32 used for processing of s8/u8 data + // thus we need to take into account ratio of sizes s32/i8 = 4 + static constexpr data_type_t avg_proc_dt = data_type::s32; + enum:int { + s32_to_i8_ratio = sizeof(typename prec_traits::type) + / sizeof(typename prec_traits::type), + max_num_ll = s32_to_i8_ratio + }; + Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3] + Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7] + Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11] + + void (*ker_)(const call_params_t *); + jit_pool_conf_t jpp; + + void init_tmp_reg(); + void init_mask(); + + void load_vreg_mask_q(int ll) {}; + + void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src(int jj, int ll, int c_tail); + + void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst(int jj, int ll, int c_tail); + + void compute_avg_step(int ur_c, int c_tail); + void compute_max_op(const int jj); + void compute_max_step(int ur_c, int c_tail); + void compute_step(int ur_c, int c_tail); + + void compute_c_block(); + void generate(); + + static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd); + + jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_) + : jpp(jpp_) { + generate(); + ker_ = reinterpret_cast(const_cast( + getCode())); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_vreg_mask_q(int ll) { + + // extract ll-th part of mask (ll-th QWORD) + vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD + + // Move mask from ll-th pos to 0-th pos + if (ll>0) + vpermq(vreg_mask_q, vreg_mask_q, ll); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) { + vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast(msk)); + } else { + vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask); + } + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) + vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + else + vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto load_i8 = [&](bool is_signed, const Vmm& vr_src) { + + // Need to use mask of tail? 
+ if (masked) { + + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + + // Load by mask from mem into register vr_src + vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q); + + // Conversion s8/u8 -> s32 + if (is_signed) + vpmovsxbd(vr_src, vr_src); + else + vpmovzxbd(vr_src, vr_src); + } else { + + // Load from mem into vr_src with conversion + if (is_signed) + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + else + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + } + }; + + switch (jpp.src_dt) { + case s32: + if (masked) + vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset], + static_cast(msk)); + else + vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]); + break; + case s8: + load_i8(true, vreg_src_s32(jj, ll)); + break; + case u8: + load_i8(false, vreg_src_s32(jj, ll)); + break; + default: assert(!"unsupported src data type"); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_src = masked ? + vreg_src_s32(jj, ll) | mask(ll) : + vreg_src_s32(jj, ll); + + switch (jpp.src_dt) { + case s32: + vmovups(vr_src, ptr[aux_reg_src_w + offset]); + break; + case s8: + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + case u8: + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + default: assert(!"unsupported src data type"); + } +}; + +template +void jit_uni_i8i8_pooling_fwd_ker_t::load_src(int jj, int ll, int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch (jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_max_op(jj, ll, offset, masked, jpp.tail[0]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + int c_block = jpp.c_block; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj)); + break; + case s8: + case u8: { + // Store low half by mask (bytes 0...15) + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vreg_dst(jj), xreg_mask_lo); + + // Do we need to store high half (bytes 16...31) ? 
+ const uint64_t low_mask = (1ULL << (c_block/2))-1; + if (msk & ~low_mask) { + vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1); + add(reg_ptr_maskmovdqu_dst, c_block / 2); + maskmovdqu(vreg_dst(jj), xreg_mask_hi); + } + } break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + case s8: + case u8: + vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk){ + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) { + + // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16} + // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0} + if (is_signed) + vpackssdw(vr_dst, vr_dst, vreg_zeros); + else + vpackusdw(vr_dst, vr_dst, vreg_zeros); + + // Permute qwords to restore original order + // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0} + vpermq(vr_dst, vr_dst, 0x58); + + // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8} + // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx} + if (is_signed) + vpacksswb(vr_dst, vr_dst, vreg_zeros); + else + vpackuswb(vr_dst, vr_dst, vreg_zeros); + + }; + + auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) { + + // Conversion s32 -> s8/u8 + s32_to_i8(is_signed, vr_dst); + + // Need to use mask of tail? + if (is_masked) { + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + } + + // store 8 bytes + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vr_dst, xreg_mask_q); + }; + + switch (jpp.dst_dt) { + case s32: + if (masked) { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll)); + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll)); + break; + case s8: + store_i8(true, masked, vreg_dst_s32(jj, ll)); + break; + case u8: + store_i8(false, masked, vreg_dst_s32(jj, ll)); + break; + default: assert(!"unsuppotred dst data_type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_dst = masked ? 
+ vreg_dst_s32(jj, ll) | mask(ll) : + vreg_dst_s32(jj, ll); + + switch (jpp.dst_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case s8: + vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case u8: + vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + default: assert(!"unsupported dst data_type"); + } +} + + +template +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst(int jj, int ll, + int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch(jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported pooling algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) +{ + using namespace data_type; + switch (jpp.src_dt) { + case s32: + vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case s8: + vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case u8: + vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + default: assert(!"unsupported src data type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) +{ + using namespace data_type; + + // Compare + switch (jpp.src_dt) { + case s32: + vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case s8: + vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case u8: + vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + default: assert(!"unsupported src data type"); + } + + // move max values into vreg_dst + if (jpp.src_dt == s32) + vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); + else + vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); +} + + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_step(int ur_c, int c_tail) +{ + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + for (int jj = 0; jj < ur_c; jj++) + vmovups(vreg_dst(jj), vreg_tmp); + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + load_src(jj, 0, c_tail); + compute_max_op(jj); + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) + store_dst(jj, 0, c_tail); +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step(int ur_c, int c_tail) +{ + using namespace data_type; + + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt); + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll)); + uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll)); + } + } + } + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, 
aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + load_src(jj, ll, c_tail); + vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), + vreg_src_s32(jj, ll)); + } + } + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll)); + vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp); + vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll)); + store_dst(jj, ll, c_tail); + } + } + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_step(int ur_c, int c_tail) { + switch (jpp.alg) { + case pooling_max: + compute_max_step(ur_c, c_tail); break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + compute_avg_step(ur_c, c_tail); break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_c_block(){ + Label l_main_loop; + + int nb_c = jpp.nb_c; + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + int ur_c_tail = jpp.ur_c_tail; + int c_steps = nb_c / ur_c; + int c_tail = jpp.c_tail; + + xor_(c_iter, c_iter); + if (c_steps > 0) { + L(l_main_loop); { + compute_step(ur_c, 0); + add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt()); + add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt()); + inc(c_iter); + cmp(c_iter, c_steps); + jl(l_main_loop, T_NEAR); + } + } + + if (ur_c_tail != 0) { + compute_step(ur_c_tail, c_tail); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { + using namespace data_type; + using cpu_isa = cpu_isa_traits; + + // AVX2 mask initialization: mask stored in Ymm-regs + auto init = [&](uint64_t bit_mask, bool init_mask_q) { + const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t); + + uint64_t vmask[QW_PER_VREG]; + for (size_t i = 0; i < QW_PER_VREG; i++){ + + uint64_t qw_vmask=0ULL; + const size_t DBITS = 8*sizeof_src_dt(); + const uint64_t VMSK = 1ULL << (DBITS-1); + const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS; + for (size_t j = 0; j < D_PER_QW; j++) { + if (bit_mask & 1) + qw_vmask |= VMSK << DBITS * j; + bit_mask >>= 1; + } + vmask[i] = qw_vmask; + } + + // Put QWORDS with target mask into xmm regs + const int xdst_i[QW_PER_VREG] = { + xreg_mask_lo.getIdx(), + xreg_mask_lo.getIdx(), + xreg_mask_hi.getIdx(), + xreg_mask_hi.getIdx() + }; + const int xsrc_i[QW_PER_VREG] = { + vreg_zeros.getIdx(), // 0-th qword insert in zeros -> {qw0, 0} + xreg_mask_lo.getIdx(), // 1-st and 0-th merge -> {qw0,qw1} + vreg_zeros.getIdx(), + xreg_mask_hi.getIdx() + }; + const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg + + for (size_t i = 0; i < QW_PER_VREG; i++) { + mov(reg_mask, vmask[i]); + vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]); + } + + // Merge Low (xreg_mask_lo alias for vreg_mask.xreg) + // and High (xreg_mask_hi) into full vreg_mask + // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg} + vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1); + + // Keep only low qword of mask in xreg_mask_q + if (init_mask_q) { + mov(reg_mask, vmask[0]); + vpinsrq(xreg_mask_q, 
Xmm(vreg_zeros.getIdx()), reg_mask, 0); + } + }; + + uint64_t tail_mask = (1ULL << jpp.c_tail) - 1; + switch (jpp.alg) { + case pooling_max: + // For "max" we need mask only in case of non-zero tail + if (tail_mask) + init(tail_mask, false); + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + // For "avg" we need mask: + // - s32 - in case of the non-zero tail + // - s8/u8 - irrespective of the tail + switch (jpp.src_dt) { + case s32: + if (tail_mask) + init(tail_mask, false); + break; + case s8: + case u8: + init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0); + break; + default: assert(!"unsupported src data type"); + } + break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { + + for (int ll = 0; ll < max_num_ll; ll++) { + mov(reg_mask, jpp.tail[ll]); + kmovq(mask(ll), reg_mask); + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::init_tmp_reg() { + using namespace data_type; + + switch (jpp.alg) { + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]); + movq(xmm_tmp, reg_tmp); + vpbroadcastd(vreg_tmp, xmm_tmp); + break; + case pooling_max: + switch (jpp.src_dt) { + case s32: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + case s8: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + case u8: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + default: assert(!"unsupported src data_type"); + } + + movq(xmm_tmp, reg_tmp); + if (jpp.src_dt == s32) + vpbroadcastd(vreg_tmp, xmm_tmp); + else + vpbroadcastb(vreg_tmp, xmm_tmp); + break; + default: assert(!"unsupported pooling algorithm"); + } + +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::generate() { + preamble(); + +#if !defined(_WIN32) + // Always use rcx as abi_param1 - + // see the note about maskmovdqu near reg_param. 
+ mov(rcx, rdi); +#endif + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[reg_param + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src_i8, src_i8); + READ_PARAM(reg_ptr_dst_i8, dst_i8); + READ_PARAM(reg_kw, kw_range); + READ_PARAM(reg_kh, kh_range); + +# undef READ_PARAM + + uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); + + init_mask(); + + init_tmp_reg(); + + compute_c_block(); + + postamble(); +} + +template +status_t jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + if (!mayiuse(isa)) + return status::unimplemented; + + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d(ppd->src_md()); + const memory_desc_wrapper dst_d(ppd->dst_md()); + + jpp.mb = src_d.dims()[0]; + jpp.c = src_d.dims()[1]; + jpp.ih = src_d.dims()[2]; + jpp.iw = src_d.dims()[3]; + jpp.oh = dst_d.dims()[2]; + jpp.ow = dst_d.dims()[3]; + + jpp.stride_h = pd.strides[0]; + jpp.stride_w = pd.strides[1]; + jpp.kh = pd.kernel[0]; + jpp.kw = pd.kernel[1]; + + jpp.t_pad = pd.padding[0][0]; + jpp.l_pad = pd.padding[0][1]; + + jpp.alg = pd.alg_kind; + + jpp.src_dt = pd.src_desc.data_type; + jpp.dst_dt = pd.dst_desc.data_type; + + // data_type items per one vreg on the + // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 + // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 + int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.src_dt); + + jpp.c_block = simd_w; + jpp.c_tail = jpp.c % jpp.c_block; + jpp.nb_c = jpp.c / jpp.c_block; + jpp.ur_c = 1; + jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + + (jpp.c_tail != 0); + + size_t tail_mask = (1ULL << jpp.c_tail) - 1; + + switch (jpp.alg) { + case pooling_max: + jpp.tail[0] = tail_mask; + jpp.tail[1] = 0; + jpp.tail[2] = 0; + jpp.tail[3] = 0; + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32) + // avx2 : 8, avx512 : 16 + const size_t msk_gran = cpu_isa_traits::vlen / data_type_size(avg_proc_dt); + const size_t msk_msk = (1ULL << msk_gran) - 1; + size_t m = tail_mask; + for (size_t ll = 0; ll < max_num_ll; ll++) { + jpp.tail[ll] = m & msk_msk; + m = m >> msk_gran; + } + break; + } + default: return status::unimplemented; + } + + return status::success; +} + +template +status_t jit_uni_i8i8_pooling_fwd_t::pd_t::jit_conf() { + return jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jpp_, this); +} + +template +jit_uni_i8i8_pooling_fwd_t:: +jit_uni_i8i8_pooling_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) +{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t(pd()->jpp_); } + +template +jit_uni_i8i8_pooling_fwd_t:: +~jit_uni_i8i8_pooling_fwd_t() { delete ker_; } + +template +void jit_uni_i8i8_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC); + auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const auto &jpp = pd()->jpp_; + + parallel_nd(jpp.mb, jpp.oh, jpp.ow, + [&](int n, int oh, int ow) { + const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0); + const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0); + + const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h); + const int kh_end = nstl::min(jpp.kh, + jpp.ih + jpp.t_pad - oh * jpp.stride_h); + const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w); + const int kw_end = nstl::min(jpp.kw, + jpp.iw + jpp.l_pad - ow * jpp.stride_w); + + 
auto p = typename jit_uni_i8i8_pooling_fwd_ker_t::call_params_t(); + p.src_i8 = &src_i8[ + src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()]; + p.dst_i8 = &dst_i8[ + dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()]; + p.kw_range = (size_t)(kw_end - kw_start); + p.kh_range = (size_t)(kh_end - kh_start); + p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ? + p.kh_range*p.kw_range : jpp.kw*jpp.kh); + + ker_->ker_(&p); + }); +} + +// Explicit instantiation only for supported values. +// +template struct jit_uni_i8i8_pooling_fwd_ker_t; +template struct jit_uni_i8i8_pooling_fwd_t; + +template struct jit_uni_i8i8_pooling_fwd_ker_t; +template struct jit_uni_i8i8_pooling_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp new file mode 100644 index 000000000000..d757679df5e4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP +#define CPU_JIT_UNI_I8I8_POOLING_HPP + +#include "c_types_map.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_i8i8_pooling_fwd_ker_t; + +template +struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_i8i8_pooling_fwd_t); + + status_t init() { + bool ok = true + && mayiuse(isa) + && ndims() == 4 + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::one_of(src_md()->data_type, data_type::s32, + data_type::s8, data_type::u8) + && src_md()->data_type == dst_md()->data_type + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*dst_md(), format_tag::nhwc); + if (!ok) return status::unimplemented; + + return jit_conf(); + } + + jit_pool_conf_t jpp_; + + protected: + status_t jit_conf(); + }; + + jit_uni_i8i8_pooling_fwd_t(const pd_t *apd); + ~jit_uni_i8i8_pooling_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_i8i8_pooling_fwd_ker_t *ker_; +}; + +} +} +} + +#endif diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp new file mode 100644 index 000000000000..2c5a8e8973b6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp @@ -0,0 +1,305 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" +#include "jit_uni_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +template +jit_uni_lrn_fwd_t::jit_uni_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) + , ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float K = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, 0), A, K, pk); + ker_first_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, -1), A, K, pk); + ker_last_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, +1), A, K, pk); + } else if (dat_tag == nChw8c && ak == lrn_within_channel) { + /* within channel, local_size (x) local_size */ + A /= ls; /* XXX: why? 
*/ + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_within(H, W, ls), A, K, pk); + } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw_across(C, H*W, 0), A, K, pk); + int remind = (H*W) % VECTOR_LENGTH; + if (remind != 0) { + ker_last_ = new jit_uni_lrn_fwd_kernel_f32( + nchw_across(C, H*W, remind), A, K, pk); + } + } else if (true /* XXX: why */) { + ker_ = new jit_uni_lrn_fwd_kernel_f32(nhwc_across(C), A, K, pk); + } +} + +template +jit_uni_lrn_fwd_t::~jit_uni_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +template +void jit_uni_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int HW = pd()->H() * pd()->W(); + const int ls = pd()->desc()->local_size; + + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else if (dat_tag == nChw8c && ak == lrn_within_channel) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + (*ker_)(&args); + }); + } + else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH, + [&](int n, int hw8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH]; + if ((hw8 + 1)*VECTOR_LENGTH > HW) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { // nhwc + parallel_nd(N, HW, [&](int n, int hw) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw * C]; + args.dst = &dst[n*HW*C + hw * C]; + args.scratch = &ws[n*HW*C + hw * C]; + (*ker_)(&args); + }); + } +} + +template +status_t jit_uni_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && is_fwd() + && everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && data_d.dims()[1] >= 2 * VECTOR_LENGTH + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc_.prop_kind == forward_training) ws_md_ = *src_md(); + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && one_of(dat_tag_, nChw8c, nchw, nhwc); + + const int jit_max_local_size = 5; // bigger size triggers too big code size + bool args_ok_within = true + && desc()->alg_kind == lrn_within_channel + && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE + ? 
jit_max_local_size : MAX_LOCAL_SIZE) + && data_d.dims()[2] >= desc()->local_size + && data_d.dims()[3] >= desc()->local_size + && one_of(dat_tag_, nChw8c); + + return args_ok_across || args_ok_within ? success : unimplemented; +} + +template +jit_uni_lrn_bwd_t::jit_uni_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float B = pd()->desc()->lrn_beta; + + int use_h_parallelizm = 0;// XXX + if (C / VECTOR_LENGTH == 1) { + ker_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, 3), A, B, use_h_parallelizm); + } + else { + ker_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, 0), A, B, use_h_parallelizm); + ker_first_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, -1), A, B, use_h_parallelizm); + ker_last_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, +1), A, B, use_h_parallelizm); + } +} + +template +jit_uni_lrn_bwd_t::~jit_uni_lrn_bwd_t() +{ + delete ker_; delete ker_first_; delete ker_last_; +} + +template +void jit_uni_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + int use_h_parallelizm = 0; // XXX + if (use_h_parallelizm) { + parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH + + h*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } +} + +template +status_t jit_uni_lrn_bwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + ws_md_ = *src_md(); + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && utils::one_of(dat_tag_, nChw8c); + + return args_ok_across ? 
success : unimplemented; +} + +template struct jit_uni_lrn_fwd_t; +template struct jit_uni_lrn_fwd_t; +template struct jit_uni_lrn_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp new file mode 100644 index 000000000000..333cd3396d36 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_HPP +#define CPU_JIT_UNI_LRN_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_bwd_kernel_f32; + +template +struct jit_uni_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_fwd_t); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_fwd_t(const pd_t *apd); + ~jit_uni_lrn_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_fwd_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +template +struct jit_uni_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_bwd_t); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_bwd_t(const pd_t *apd); + ~jit_uni_lrn_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_bwd_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp new file mode 100644 index 000000000000..89af47272cae --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp @@ -0,0 +1,1487 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache 
License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +////////////////////////////////////////////////////////////////////////////// +// forward kernel +template +void jit_uni_lrn_fwd_kernel_f32::within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk) +{ + vxorps(ysum, ysum, ysum); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + vmovups(ydst, ptr[src]); + vfmadd231ps(ysum, ydst, ydst); + } + else + { + vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]); + vfmadd231ps(ysum, ytmp, ytmp); + } + } + } + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + vmovaps(ytmp, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ytmp); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75 + vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum + vmovups(ptr[dst], ydst); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template +void jit_uni_lrn_fwd_kernel_f32::within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk) +{ + Xbyak::Xmm xtmp_lo = xmm12; + Xbyak::Xmm xtmp_hi = xmm13; + Xbyak::Xmm xsum_lo = xmm8; + Xbyak::Xmm xsum_hi = xmm9; + Xbyak::Xmm xdst_lo = xmm10; + Xbyak::Xmm xdst_hi = xmm11; + Xbyak::Xmm xsum2_lo = xmm14; + Xbyak::Xmm xsum2_hi = xmm15; + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + addps(xsum_lo, xdst_lo); + addps(xsum_hi, xdst_hi); + } + else + { + movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]); + movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]); + mulps(xtmp_lo, xtmp_lo); + mulps(xtmp_hi, xtmp_hi); + addps(xsum_lo, xtmp_lo); + addps(xsum_hi, xtmp_hi); + } + } + } + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + movaps(xtmp_lo, xsum_lo); + movaps(xtmp_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xtmp_lo); + movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi); + } + movaps(xsum2_lo, xsum_lo); + movaps(xsum2_hi, xsum_hi); + mulps(xsum2_lo, xsum_lo); + mulps(xsum2_hi, xsum_hi); + mulps(xsum_lo, xsum2_lo); + mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3; + + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, 
xsum_hi); // xsum = (xsum*xalpha+xk)^0.75 + + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + divps(xdst_lo, xsum_lo); + divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum + + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 h = r9; + Xbyak::Reg64 w = r10; + Vmm ysum = Vmm(isa == avx2 ? 9 : 9); + Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10); + Vmm ydst = Vmm(isa == avx2 ? 11 : 11); + Vmm ytmp = Vmm(isa == avx2 ? 12 : 12); + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + if (isa == avx2) { + vbroadcastss(yalpha, xalpha); + } else { + shufps(xalpha, xalpha, 0); + } + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + if (isa == avx2) { + vbroadcastss(yk, xk); + } else { + shufps(xk, xk, 0); + } + + int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1; + + for (int i = 0; i < s2; ++i) + { + Label label_t; + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } + else { + within_body_sse42(-i, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + L(label_t); + if (isa == avx2) { + within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_t, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-i, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + mov(h, J.H - J.size + 1); + Label lrn_loop_h; + L(lrn_loop_h); + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + Label lrn_loop_w; + L(lrn_loop_w); + if (isa == avx2) { + within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(lrn_loop_w, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + dec(h); + cmp(h, 0); + jne(lrn_loop_h, T_NEAR); + + for (int i = J.H - S2; i < J.H; ++i) + { + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -j, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk); + } + } + + mov(w, J.W - J.size + 1); + Label label_b; + L(label_b); + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_b, T_NEAR); + + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + 
within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + Xbyak::Xmm xsrc_prev = xmm2; + Xbyak::Ymm ysrc = ymm3; + Xbyak::Ymm yc = ymm3; + Xbyak::Xmm xsrc_next = xmm4; + Xbyak::Ymm ya = ymm5; + Xbyak::Ymm yb = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ye = ymm8; + Xbyak::Ymm ysum = ymm9; + Xbyak::Ymm ysum2 = ymm10; + Xbyak::Ymm ydst = ymm11; + Xbyak::Ymm ybase = ymm12; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + if (J.version == -1) + { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(ysrc, ptr[src]); + if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev); + vmovups(ptr[t + 16], ysrc); + if (J.version != +1) vmovups(ptr[t + 48], xsrc_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vmulps(ysum, yc, yc); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya + vfmadd231ps(ysum, yb, yb); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + + vmovaps(ybase, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = ybase^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = ybase^0.75 + vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum + vmovups(ptr[dst], ydst); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + + Xbyak::Xmm xsrc_lo = xmm2; + Xbyak::Xmm xsrc_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xsum_lo = xc_lo; + Xbyak::Xmm xsum_hi = xc_hi; + Xbyak::Xmm xsrc_prev = xmm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xa_lo = xmm8; + Xbyak::Xmm xa_hi = xmm9; + Xbyak::Xmm xb_lo = xmm10; + Xbyak::Xmm xb_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, 
ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + if (J.version == -1) + { + xorps(xsrc_prev, xsrc_prev); + movups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + xorps(xsrc_next, xsrc_next); + movups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + movups(xsrc_lo, ptr[src]); + movups(xsrc_hi, ptr[src + 4 * sizeof(float)]); + if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) movups(ptr[t + 0], xsrc_prev); + movups(ptr[t + 16], xsrc_lo); + movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi); + if (J.version != +1) movups(ptr[t + 48], xsrc_next); + + movups(xa_lo, ptr[t + 16 - 8]); + movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]); + movups(xb_lo, ptr[t + 16 - 4]); + movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]); + movups(xd_lo, ptr[t + 16 + 4]); + movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[t + 16 + 8]); + movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]); + movaps(xc_lo, xsrc_lo); + movaps(xc_hi, xsrc_hi); + mulps(xsum_lo, xc_lo); + mulps(xsum_hi, xc_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + + movaps(xbase_lo, xsum_lo); + movaps(xbase_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xsum_lo, xsum_lo); + mulps(xsum_hi, xsum_hi); + mulps(xsum_lo, xbase_lo); + mulps(xsum_hi, xbase_hi); // xsum = xbase^3; + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75 + divps(xsrc_lo, xsum_lo); + divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum + movups(ptr[dst], xsrc_lo); + movups(ptr[dst + 4 * sizeof(float)], xsrc_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0, 0 + }; + + Xbyak::Reg64 c = r9; + Xbyak::Ymm ya = ymm2; + Xbyak::Ymm yb = ymm3; + Xbyak::Ymm yc = ymm4; + Xbyak::Ymm yd = ymm5; + Xbyak::Ymm ye = ymm6; + Xbyak::Ymm ysum = ymm7; + Xbyak::Ymm ydst = ymm8; + Xbyak::Ymm ybase = ymm9; + Xbyak::Ymm ymask = ymm10; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != 
prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + vxorps(ysum, ysum, ysum); + + mov(imm_addr64, reinterpret_cast(&mask[0])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ya, ymask, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast(&mask[1])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yb, ymask, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + vmovups(yc, ptr[src]); + vmovups(yd, ptr[src + 4]); + vmovups(ye, ptr[src + 8]); + vfmadd231ps(ysum, yc, yc); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + vmovups(ptr[dst], ydst); + + vxorps(ysum, ysum, ysum); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + vmovups(ya, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); + vmovups(yb, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vmovups(yc, ptr[src]); + vfmadd231ps(ysum, yc, yc); + + mov(imm_addr64, reinterpret_cast(&mask[2])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yd, ymask, ptr[src + 4]); + vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast(&mask[3])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ye, ymask, ptr[src + 8]); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + vmovups(ptr[dst], ydst); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0, 0 + }; + + static uint32_t store[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r9; + + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xa_lo = xmm2; + Xbyak::Xmm xa_hi = xmm3; + Xbyak::Xmm xb_lo = xmm2; + Xbyak::Xmm xb_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xd_lo = xmm6; + Xbyak::Xmm xd_hi = xmm7; + Xbyak::Xmm xe_lo = xmm8; + Xbyak::Xmm xe_hi = xmm9; + Xbyak::Xmm xsum_lo = xmm10; + Xbyak::Xmm xsum_hi = xmm11; + Xbyak::Xmm xmask_lo = xmm12; + Xbyak::Xmm xmask_hi = xmm13; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + 
mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + mov(store_addr, reinterpret_cast(&store[0])); + and_(store_addr, -15); + movups(ptr[store_addr], xalpha); + movups(ptr[store_addr + 4 * sizeof(float)], xk); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + mov(imm_addr64, reinterpret_cast(&mask[0])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + andps(xa_lo, xmask_lo); + andps(xa_hi, xmask_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast(&mask[1])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + andps(xb_lo, xmask_lo); + andps(xb_hi, xmask_hi); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + 
mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + + mov(imm_addr64, reinterpret_cast(&mask[2])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast(&mask[3])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movups(xdst_lo, xsum_lo); + movups(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) {} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) +{ + Xbyak::Ymm ydst = ymm14; + Xbyak::Ymm ybase = ymm15; + + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) + vmaskmovps(ptr[scratch], ymask, ybase); + else + vmovups(ptr[scratch], ybase); + } + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + if (tail != 0) + vmaskmovps(ptr[dst], ymask, ydst); + else + vmovups(ptr[dst], ydst); + + + vfnmadd231ps(ysum, ya, ya); + vmovups(ya, yb); + vmovups(yb, yc); + vmovups(yc, yd); + vmovups(yd, ye); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{ + Xbyak::Xmm xmm_tmp = xmm10; + movaps(xmm_tmp, xtail_lo); + size_t offset = 0; + + if 
(tail > 4) { + movups(ptr[reg_dst], xtail_lo); + movaps(xmm_tmp, xtail_hi); + offset += 4 * sizeof(float); + tail -= 4; + } + movss(ptr[reg_dst + offset], xmm_tmp); + for (int i = 1; i < tail; i++) + { + psrldq(xmm_tmp, 4); + movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp); + } +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) +{ + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xbase_lo = xmm6; + Xbyak::Xmm xbase_hi = xmm7; + Xbyak::Xmm xtmp_lo = xmm8; + Xbyak::Xmm xtmp_hi = xmm9; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + + // store xe + movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi); + + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + // xdst <- xsum*xalpha+xk + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]); + mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]); + addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) { + nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi); + } + else { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + divps(xtmp_lo, xdst_lo); + divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movaps(xdst_lo, xtmp_lo); + movaps(xdst_hi, xtmp_hi); + + if (tail != 0) { + nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi); + } + else { + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + } + + movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]); + movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + subps(xsum_lo, xa_lo); + subps(xsum_hi, xa_hi); + + // xa <- xb + movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]); + movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi); + + // xb <- xc + movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi); + + // xc <- xd + movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]); + movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi); + + // xd <- xe + movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]); + movaps(xe_hi, ptr[store_addr + 11 * 4 * 
sizeof(float)]); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r10; + Xbyak::Ymm ymask = ymm2; + Xbyak::Ymm ye = ymm3; + Xbyak::Ymm ya = ymm4; + Xbyak::Ymm yb = ymm5; + Xbyak::Ymm yc = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ysum = ymm8; + + this->preamble(); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); + vmovups(ymask, ptr[imm_addr64]); + } + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + vxorps(ya, ya, ya); + vxorps(yb, yb, yb); + if (J.tail != 0) + vmaskmovps(yc, ymask, ptr[src + J.HW * 0]); + else + vmovups(yc, ptr[src + J.HW * 0]); + if (J.tail != 0) + vmaskmovps(yd, ymask, ptr[src + J.HW * 4]); + else + vmovups(yd, ptr[src + J.HW * 4]); + + vxorps(ysum, ysum, ysum); + vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + vfmadd231ps(ysum, yd, yd); + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) + vmaskmovps(ye, ymask, ptr[src + J.HW * 8]); + else + vmovups(ye, ptr[src + J.HW * 8]); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vxorps(ye, ye, ye); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0 + }; + + Xbyak::Reg64 c = r10; + + Xbyak::Xmm xmask_lo = xmm2; + Xbyak::Xmm xmask_hi = xmm3; + Xbyak::Xmm xsum_lo = xmm4; + Xbyak::Xmm xsum_hi = xmm5; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + 
sub(rsp, stack_space_needed); + mov(store_addr, rsp); + and_(store_addr, -15); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + // put alpha and k into store (free up regs) + movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha); + movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + } + // init xa, xb + xorps(xa_lo, xa_lo); + xorps(xa_hi, xa_hi); + xorps(xb_lo, xb_lo); + xorps(xb_hi, xb_hi); + + // read xc, xd + if (J.tail != 0) { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + andps(xc_lo, xmask_lo); + andps(xc_hi, xmask_hi); + } + else { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + } + if (J.tail != 0) { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + } + else { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + } + + // put xa, xb, xc, xd into store to free-up regs + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + } + else { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + } + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + xorps(xe_lo, xe_lo); + xorps(xe_hi, xe_hi); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(rsp, stack_space_needed); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +////////////////////////////////////////////////////////////////////////////// +// backward kernel +template +jit_uni_lrn_bwd_kernel_f32::jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2 * A*B) + , 
use_h_parallelizm(use_h_parallel) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r10; + + Xbyak::Xmm xsrc_prev = xmm1; + Xbyak::Xmm xws_prev = xmm2; + Xbyak::Xmm xdiffdst_prev = xmm3; + Xbyak::Ymm ysrc = ymm4; + Xbyak::Ymm yws = ymm5; + Xbyak::Ymm ydiffdst = ymm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xws_next = xmm8; + Xbyak::Xmm xdiffdst_next = xmm9; + Xbyak::Ymm ya = ymm10; + Xbyak::Xmm xa = xmm10; + Xbyak::Ymm yb = ymm11; + Xbyak::Ymm yd = ymm12; + Xbyak::Ymm ye = ymm13; + Xbyak::Ymm ysum = ymm14; + Xbyak::Ymm ydiffsrc = ymm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(diffdst, ptr[this->param1 + 8]); + mov(workspace, ptr[this->param1 + 16]); + mov(diffsrc, ptr[this->param1 + 24]); + + sub(t, 64); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(ynalphabeta, xnalphabeta); + + bool is_single = J.version == 3; + bool is_first = J.version == -1 || J.version == -2; + bool is_last = J.version == +1 || J.version == -2; + + if (is_first || is_single) { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (is_last || is_single) { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W); + Label lrn_loop; + L(lrn_loop); + { + if (!is_first && !is_single) { + vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]); + vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]); + vmulps(xa, xws_prev, xws_prev); + vmulps(xa, xa, xws_prev); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_prev); + vdivps(xsrc_prev, xsrc_prev, xa); + vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev); + } + + vmovups(ysrc, ptr[src]); + vmovups(yws, ptr[workspace]); + vmovups(ydiffdst, ptr[diffdst]); + vmulps(ya, yws, yws); + vmulps(ya, ya, yws); + vsqrtps(ya, ya); + vsqrtps(ya, ya); + vdivps(ydiffsrc, ydiffdst, ya); + vdivps(ysum, ydiffsrc, yws); + vmulps(ysum, ysum, ysrc); + + if (!is_last && !is_single) { + vmovups(xws_next, ptr[workspace + J.H*J.W * 32]); + vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]); + vmulps(xa, xws_next, xws_next); + vmulps(xa, xa, xws_next); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_next); + vdivps(xsrc_next, xsrc_next, xa); + vdivps(xsrc_next, xsrc_next, xws_next); + vmulps(xdiffdst_next, xdiffdst_next, xsrc_next); + } + + if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev); + vmovups(ptr[t + 16], ysum); + if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vaddps(ysum, ysum, ya); + vmulps(ysrc, ysrc, ynalphabeta); + vaddps(ysum, ysum, yb); + + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vaddps(ysum, ysum, yd); + vaddps(ysum, ysum, ye); + + vfmadd231ps(ydiffsrc, ysum, ysrc); + + vmovups(ptr[diffsrc], ydiffsrc); + + add(src, 32); + add(diffsrc, 32); + add(diffdst, 32); + add(workspace, 32); + + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_bwd_kernel_f32; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp new file 
mode 100644 index 000000000000..2b3ed43cd4b7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP +#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 }; + +typedef struct { + const float *src; + float *dst, *scratch; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *scratch; + float *diff_src; +} jit_args_bwd_t; + +struct nchw8c_across { + /* version: + * -1: channels 0..7, + * 1: channels C-8 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct nchw8c_within { + int H, W, size; + nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {} +}; + +struct nchw_across { + int C, HW, tail; + nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {} +}; + +struct nhwc_across { + int C; + nhwc_across(int c) : C(c) {} +}; + +template +struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 dst = r8; + Xbyak::Reg64 scratch = rdx; + Xbyak::Reg64 imm_addr64 = rbx; + Xbyak::Reg64 store_addr = rbp; + + Xbyak::Xmm xalpha = xmm0; + Xbyak::Ymm yalpha = ymm0; + Xbyak::Xmm xk = xmm1; + Xbyak::Ymm yk = ymm1; + + float alpha; + float k; + + int stack_space_needed = 11 * 4 * sizeof(float) + 16; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk); + void within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk); + + + void nchw_body(int tail, int HW, 
prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum); + void nchw_body_sse42(int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); + void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst, + Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi); + + void operator()(jit_args_fwd_t *arg) { ker(arg); } + void(*ker)(jit_args_fwd_t *); +}; + +template +struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 diffsrc = r8; + Xbyak::Reg64 diffdst = r9; + Xbyak::Reg64 workspace = rdx; + Xbyak::Reg64 imm_addr64 = rsi; + + Xbyak::Xmm xnalphabeta = xmm0; + Xbyak::Ymm ynalphabeta = ymm0; + + float nalphabeta; + + int use_h_parallelizm; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32) + + jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void operator()(jit_args_bwd_t *arg) { ker(arg); } + void(*ker)(jit_args_bwd_t *); +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp new file mode 100644 index 000000000000..bf8e609d23db --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp @@ -0,0 +1,699 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "cpu_pooling_pd.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; +using namespace alg_kind; + +#define GET_OFF(field) offsetof(jit_pool_call_s, field) + +template +status_t jit_uni_pool_kernel_f32::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d( + ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); + const memory_desc_wrapper dst_d( + ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); + + bool args_ok = true + && mayiuse(isa) + && utils::one_of(pd.alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding); + if (!args_ok) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + const int ndims = src_d.ndims(); + + jpp.ndims = ndims; + jpp.mb = src_d.dims()[0]; + + jpp.c = utils::rnd_up(src_d.dims()[1], simd_w); + if (jpp.c > src_d.padded_dims()[1]) + return status::unimplemented; + + jpp.id = (ndims == 5) ? 
src_d.dims()[2] : 1; + jpp.ih = src_d.dims()[ndims-2]; + jpp.iw = src_d.dims()[ndims-1]; + jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jpp.oh = dst_d.dims()[ndims-2]; + jpp.ow = dst_d.dims()[ndims-1]; + + jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1; + jpp.stride_h = pd.strides[ndims-4]; + jpp.stride_w = pd.strides[ndims-3]; + jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; + jpp.kh = pd.kernel[ndims-4]; + jpp.kw = pd.kernel[ndims-3]; + + jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0; + jpp.t_pad = pd.padding[0][ndims-4]; + jpp.l_pad = pd.padding[0][ndims-3]; + + jpp.alg = pd.alg_kind; + + jpp.is_training = pd.prop_kind == prop_kind::forward_training; + jpp.is_backward = pd.prop_kind == prop_kind::backward_data; + jpp.ind_dt = ppd->workspace_md() + ? ppd->workspace_md()->data_type : data_type::undef; + + jpp.simple_alg = jpp.is_training + || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); + + jpp.c_block = simd_w; + + jpp.nb_c = jpp.c / jpp.c_block; + if (jpp.alg == pooling_max) { + jpp.ur_w = isa == avx512_common ? 16 : 4; + if (jpp.is_training) + jpp.ur_w = isa == avx512_common ? 9 : 3; + else if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 6 : 3; + } else { + if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 12 : 6; + else + jpp.ur_w = isa == avx512_common ? 24 : 12; + } + if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow; + if (jpp.l_pad > jpp.ur_w) return status::unimplemented; + + jpp.ur_w_tail = jpp.ow % jpp.ur_w; + + return status::success; +} + +template +inline void jit_uni_pool_kernel_f32::maybe_recalculate_divisor(int jj, + int ur_w, int pad_l, int pad_r) { + if (jpp.alg == pooling_avg_exclude_padding) { + int kw = jpp.kw; + int stride_w = jpp.stride_w; + + int non_zero_kw = kw; + non_zero_kw -= nstl::max(0, pad_l - jj*stride_w); + non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w); + + if (non_zero_kw != prev_kw) { + mov(tmp_gpr, float2int((float)non_zero_kw)); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h); + prev_kw = non_zero_kw; + } + } +} + +template +inline void jit_uni_pool_kernel_f32::avg_step(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + if (jpp.is_backward) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + } else { + uni_vpxor(vreg(jj), vreg(jj), vreg(jj)); + } + } + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + if (jpp.is_backward) { + uni_vmovups(vreg(ur_w+jj), + ptr[aux_reg_input + input_offset]); + uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj)); + uni_vmovups(vmmword[aux_reg_input + input_offset], + vreg(ur_w+jj)); + } else { + uni_vaddps(vreg(jj), vreg(jj), + 
ptr[aux_reg_input + input_offset]); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + if (!jpp.is_backward) { + for (int jj = 0; jj < ur_w; jj++) { + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], + vreg(jj)); + } + } +} + +template +inline void jit_uni_pool_kernel_f32::max_step_fwd(int ur_w, int pad_l, + int pad_r) { + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + mov(tmp_gpr, float2int(nstl::numeric_limits::lowest())); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), vmm_tmp); + if (jpp.is_training) + uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj)); + } + if (jpp.is_training) + { + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + } + + if (jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + movups(vmm_mask, vreg(jj)); + cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os); + blendvps(vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + blendvps(vreg(2*ur_w+jj), vmm_k_offset); + } else if (isa == avx) { + vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj), + _cmp_lt_os); + vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj), + vreg(3*ur_w+jj)); + if (jpp.is_training) + vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), + vmm_k_offset, vreg(3*ur_w+jj)); + } else { + vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os); + vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + vblendmps(vreg(2*ur_w+jj) | k_store_mask, + vreg(2*ur_w+jj), vmm_k_offset); + } + } + if (jpp.is_training) { + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + if (jpp.is_training) { + mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, xmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vmmword[reg_output + 
sizeof(float)*jj*c_block], vreg(jj)); + if (jpp.is_training) { + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + + auto x = xreg(2 * ur_w + jj); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + for (int i = 0; i < 4; ++i) + pextrb(ptr[reg_index + step_index + i], x, 4*i); + } else if (isa == avx) { + auto y = yreg(2 * ur_w + jj); + if (jj == 0) { + movd(xmm_tmp, reg_shuf_mask); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (mayiuse(avx2)) { + vpshufb(y, y, vmm_tmp); + movd(ptr[reg_index + step_index], x); + vperm2i128(y, y, y, 0x1u); + movd(ptr[reg_index + step_index + 4], x); + } else { + Xmm t(vmm_mask.getIdx()); + vextractf128(t, y, 0); + vpshufb(t, t, xmm_tmp); + movd(ptr[reg_index + step_index], t); + vextractf128(t, y, 1); + vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0] + movd(ptr[reg_index + step_index + 4], t); + } + } else { + auto v = vreg(2 * ur_w + jj); + vpmovusdb(x, v); + vmovups(ptr[reg_index + step_index], v | k_index_mask); + } + } else { + uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj)); + } + } + } +} + +template +inline void jit_uni_pool_kernel_f32::max_step_bwd(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + movd(xreg(ur_w+jj), ptr[reg_index + step_index]); + pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } else if (isa == avx) { + movq(xreg(ur_w+jj), ptr[reg_index + step_index]); + if (!mayiuse(avx2)) { + avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp); + } else { + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + vmovups(vreg(ur_w+jj) | k_index_mask, + ptr[reg_index + step_index]); + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]); + } + } + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + push(dst_ptr); + } + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + mov(dst_ptr, aux_reg_input); + add(dst_ptr, input_offset); + + movups(vreg(3*ur_w+jj), vreg(ur_w+jj)); + pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset); + addps(vreg(2*ur_w+jj), vreg(jj)); + maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj)); + } else if (isa == avx) { + if (mayiuse(avx2)) { + vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset); + } else { + avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp); + } + 
vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj)); + vmaskmovps(vmmword[aux_reg_input + input_offset], + vreg(3*ur_w+jj), vreg(2*ur_w+jj)); + } else { + vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset); + vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj)); + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp); + vmovups(vmmword[aux_reg_input + + sizeof(float)*aux_input_offset], vreg(2*ur_w+jj)); + } + } + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + + mov(tmp_gpr, reg_kd_pad_shift); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, vmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + pop(dst_ptr); + } + pop(reg_output); + pop(reg_input); + } +} + +template +void jit_uni_pool_kernel_f32::maybe_zero_diff_src() { + assert(jpp.c_block * sizeof(float) % cpu_isa_traits::vlen == 0); + Label l_skip, l_zero; + + auto reg_oh = tmp_gpr; + mov(reg_oh, ptr[reg_param + GET_OFF(oh)]); + cmp(reg_oh, 0); + jz(l_skip, T_NEAR); + + if (jpp.ndims == 5) { + mov(zero_size, ptr[reg_param + GET_OFF(oh)]); + mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float)); + imul(zero_size, tmp_gpr); + } + + auto vzero = vmm_tmp; + uni_vpxor(vzero, vzero, vzero); + + auto reg_off = tmp_gpr; + xor_(reg_off, reg_off); + + L(l_zero); + { + const int dim = jpp.iw * jpp.c_block * sizeof(float); + for (int i = 0; i < dim; i += cpu_isa_traits::vlen) + uni_vmovups(ptr[reg_input + reg_off + i], vzero); + add(reg_off, dim); + if (jpp.ndims == 5) cmp(reg_off, zero_size); + else cmp(reg_off, jpp.ih * dim); + jl(l_zero, T_NEAR); + } + + L(l_skip); +} + +template +void jit_uni_pool_kernel_f32::generate() { + + this->preamble(); + + int ow = jpp.ow; + int iw = jpp.iw; + int kw = jpp.kw; + int kh = jpp.kh; + int ur_w = jpp.ur_w; + int c_block = jpp.c_block; + int stride_w = jpp.stride_w; + int l_pad = jpp.l_pad; + int ur_w_tail = jpp.ur_w_tail; + + int n_oi = ow / ur_w; + + prev_kw = 0; + + int vlen = cpu_isa_traits::vlen; + +#if defined(_WIN32) + // Always mimic the Unix ABI (see the note about maskmovdqu in the header + // file). 
+ xor_(rdi, rcx); + xor_(rcx, rdi); + xor_(rdi, rcx); +#endif + + mov(reg_input, ptr[reg_param + GET_OFF(src)]); + mov(reg_output, ptr[reg_param + GET_OFF(dst)]); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + mov(reg_index, ptr[reg_param + GET_OFF(indices)]); + mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]); + mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]); + mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]); + + if (jpp.is_backward) + maybe_zero_diff_src(); + + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { + mov(tmp_gpr, 1); + movq(xmm_one, tmp_gpr); + uni_vpbroadcastd(vmm_one, xmm_one); + + if (isa == avx) { + mov(reg_shuf_mask, 0x0c080400); + } else if (isa >= avx512_common) { + mov(tmp_gpr.cvt32(), 0x000f); + kmovw(k_index_mask, tmp_gpr.cvt32()); + } + } + + int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1)); + int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (jpp.alg == pooling_avg_exclude_padding) { + movq(xmm_ker_area_h, reg_ker_area_h); + uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h); + } + + if (jpp.alg == pooling_avg_include_padding) { + mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd))); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) { + step(ur_w, l_pad, r_pad1); + } else { + step(ur_w, l_pad, 0); + } + + if (isa == sse42) { + if (n_oi < 0 && r_pad1 > 0) { + step_high_half(ur_w, l_pad, r_pad1); + } else { + step_high_half(ur_w, l_pad, 0); + } + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + xor_(oi_iter, oi_iter); + if (n_oi > 0) { + Label ow_loop; + L(ow_loop); { + step(ur_w, 0, 0); + + if (isa == sse42) { + step_high_half(ur_w, 0, 0); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + } + + if (r_pad1 > 0 && n_oi >= 0) { + step(ur_w, 0, r_pad1); + + if (isa == sse42) { + step_high_half(ur_w, 0, r_pad1); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && 
(jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + if (ur_w_tail != 0) { + step(ur_w_tail, 0, r_pad); + + if (isa == sse42) { + step_high_half(ur_w_tail, 0, r_pad); + } + } + + this->postamble(); +} + +template struct jit_uni_pool_kernel_f32; +template struct jit_uni_pool_kernel_f32; // implements both and +template struct jit_uni_pool_kernel_f32; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp new file mode 100644 index 000000000000..992b526587fd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp @@ -0,0 +1,192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_POOL_KERNEL_F32_HPP +#define JIT_UNI_POOL_KERNEL_F32_HPP + +#include + +#include "c_types_map.hpp" +#include "pooling_pd.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template +struct jit_uni_pool_kernel_f32: public jit_generator { + jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + jit_pool_conf_t jpp; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) + + void operator()(jit_pool_call_s *arg) { jit_ker(arg); } + static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); + +private: + using Vmm = typename utils::conditional3::type; + Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } + Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } + + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx) ? yword : zword; + + Xmm vmm_mask = Xmm(0); + Xmm xmm_ker_area_h = Xmm(2); + Xmm xmm_one = Xmm(2); + Xmm xmm_tmp = Xmm(3); + + Vmm vmm_ker_area_h = Vmm(2); + Vmm vmm_one = Vmm(2); + Vmm vmm_tmp = Vmm(3); + + Vmm vmm_k_offset = Vmm(1); + + Opmask k_index_mask = Opmask(6); + Opmask k_store_mask = Opmask(7); + + // Here be some (tame) dragons. This kernel does not follow the regular + // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu + // instruction which has its destination hardcoded in rdi. Therefore: + // - all registers are hardcoded + // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI + // + // While this is only required by the backward pass, the quirk above + // is applied to the forward pass as well to keep things simpler. 
+ + using reg64_t = const Xbyak::Reg64; + reg64_t reg_param = rdi; // Always mimic the Unix ABI + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t reg_index = r10; + reg64_t reg_output = r12; + reg64_t reg_kd_pad_shift = r13; + reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu + + reg64_t kj = r14; + reg64_t oi_iter = r15; + reg64_t reg_kh = rax; + reg64_t reg_k_shift = rbx; + reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above + reg64_t reg_ker_area_h = rdx; + + reg64_t zero_size = r15; + reg64_t ki = r12; + reg64_t aux_reg_input_d = r8; + + Xbyak::Reg32 reg_shuf_mask = esi; + + int prev_kw; + void (*jit_ker)(jit_pool_call_s *); + + void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); + void avg_step(int ur_w, int pad_l, int pad_r); + void max_step_fwd(int ur_w, int pad_l, int pad_r); + void max_step_bwd(int ur_w, int pad_l, int pad_r); + + void maybe_zero_diff_src(); + + void step(int ur_w, int pad_l, int pad_r) { + if (jpp.alg == alg_kind::pooling_max) { + if(jpp.is_backward) + max_step_bwd(ur_w, pad_l, pad_r); + else + max_step_fwd(ur_w, pad_l, pad_r); + } + else + avg_step(ur_w, pad_l, pad_r); + } + + void step_high_half(int ur_w, int pad_l, int pad_r) { + add(reg_input, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + if (jpp.alg == alg_kind::pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, types::data_type_size(jpp.ind_dt) * 4); + + step(ur_w, pad_l, pad_r); + } + + void generate(); + + void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + assert(y0.getIdx() != x1.getIdx()); + vextractf128(xtmp, y0, 0); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 0); + vextractf128(xtmp, y0, 1); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + paddd(x0, x1); + } + + void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + Xmm x0(y0.getIdx()); + pshufd(xmm_tmp, x1, 1); + pmovzxbd(x0, x1); + pmovzxbd(xmm_tmp, xmm_tmp); + vinsertf128(y0, y0, xmm_tmp, 1); + } + + void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + pmovzxbd(x0, x1); + } + + void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { + assert(y0.getIdx() != y1.getIdx()); + assert(y0.getIdx() != y2.getIdx()); + Xmm x0(y0.getIdx()); + Xmm x2(y2.getIdx()); + vextractf128(x0, y1, 1); + vextractf128(xtmp, y2, 1); + pcmpeqd(xtmp, x0); + vextractf128(x0, y1, 0); + pcmpeqd(x0, x2); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { + assert(false /*function should not be used*/); + pcmpeqd(x0, x1); + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp new file mode 100644 index 000000000000..afbcf996d8ac --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" + +#include "jit_uni_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = oh == 0; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.oh, + [&](int n, int b_c, int oh) { + ker(n, b_c, oh); + }); +} + +template +void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? 
types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, id, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh + od == 0); + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad) + -jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow); + } + }); +} + +template +void jit_uni_pooling_bwd_t::execute_backward(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh == 0); + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, oh); + } + }); +} + +template +void jit_uni_pooling_bwd_t::execute_backward_3d(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? 
types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow, int zero_size, int kd) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = zero_size; + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh + + kd * jpp.kw * jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + (*kernel_)(&arg); + }; + + if (jpp.simple_alg) { + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad - ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + int zero_s = jpp.stride_d - d_t_overflow - (nstl::max( + jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + (oh == 0) ? zero_s : 0, 0); + } + }); + } else { + ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c + * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw; + + parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; }); + + for (int kd = 0; kd < jpp.kd; ++kd) { + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int od = 0; od < jpp.od; ++od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + if (kd >= jpp.kd - d_t_overflow - d_b_overflow) + continue; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + 0, kd); + } + } + }); + } + } +} + + +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp new file mode 100644 index 000000000000..57bebacdeef3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp @@ -0,0 +1,182 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_POOLING_HPP +#define CPU_JIT_UNI_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return jit_uni_pool_kernel_f32::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() == 4 + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? 
nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } + + ~jit_uni_pooling_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE); + + if (pd()->ndims() == 5) + execute_forward_3d(src, dst, ws); + else + execute_forward(src, dst, ws); + + return status::success; + } + +private: + void execute_forward(const data_t *src, data_t *dst, char *indices) const; + void execute_forward_3d(const data_t *src, data_t *dst, + char *indices) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32 *kernel_; +}; + +template +struct jit_uni_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + diff_src_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return jit_uni_pool_kernel_f32::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? 
nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } + + ~jit_uni_pooling_bwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + if (pd()->ndims() == 5) + execute_backward_3d(diff_dst, ws, diff_src); + else + execute_backward(diff_dst, ws, diff_src); + + return status::success; + } + +private: + void execute_backward(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + void execute_backward_3d(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32 *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp new file mode 100644 index 000000000000..98796503b793 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp @@ -0,0 +1,1006 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "jit_uni_reorder.hpp" + +#include "jit_generator.hpp" + +// #define TR_DEBUG +#if defined(TR_DEBUG) +#define DEBUg(...) do { __VA_ARGS__ } while (0) +#else +#define DEBUg(...) +#endif +#define DEBUG(...) DEBUg(__VA_ARGS__) + +#ifdef _WIN32 +/* seems like s_addr is a reserved macro on Windows */ +#undef s_addr +#endif + +using namespace Xbyak; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. 
*/ +const size_t ker_prb_size_min = 64; + +/* kernel */ +struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) + + enum { + len_unroll_max = 256, + ndims_jit_loop_max = 3, + }; + + struct simple_impl_desc_t { + int ndims_full_unroll; + int len_last_dim_unroll; + int len_unroll; + }; + + static bool simple_impl_desc_init(const prb_t &prb, + simple_impl_desc_t *desc) { + const int ndims = prb.ndims; + + int ndims_full_unroll = 0; + int len_last_dim_unroll = 1; + int len_unroll = 1; + + for (int d = 0; d < ndims; ++d) { + auto &node = prb.nodes[d]; + if (len_unroll * node.n <= len_unroll_max) { + ndims_full_unroll++; + len_unroll *= node.n; + } else { + len_last_dim_unroll = len_unroll_max / len_unroll; + while (node.n % len_last_dim_unroll) + --len_last_dim_unroll; + len_unroll *= len_last_dim_unroll; + break; + } + } + + if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) + return false; + + if (desc) { + desc->ndims_full_unroll = ndims_full_unroll; + desc->len_last_dim_unroll = len_last_dim_unroll; + desc->len_unroll = len_unroll; + } + + return true; + } + + static bool applicable(const prb_t &p) { + using namespace data_type; + + bool ok = true + && p.ndims > 0 + && utils::one_of(p.itype, f32, s32, s8, u8) + && utils::one_of(p.otype, f32, s32, s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ + && simple_impl_desc_init(p, nullptr) + && mayiuse(sse42) + && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype), + mayiuse(avx)); + if (!ok) return false; + + const ptrdiff_t max_stride = (1LL<<31) - 1; + for (int d = 0; d < p.ndims; ++d) { + const ptrdiff_t cms = max_stride / p.nodes[d].n; + bool strides_ok = true + && p.nodes[d].is < cms / (int)data_type_size(p.itype) + && p.nodes[d].os < cms / (int)data_type_size(p.otype); + if (!strides_ok) return false; + } + + return true; + } + + int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; } + int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; } + int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; } + int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } + + Address i_addr(int i_off) + { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; } + + Address o_addr(int o_off) + { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; } + + Address s_addr(int s_off) + { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; } + + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, + int &i_off, int &o_off, int &s_off, int step_size = 1) { + i_off = prev_i_off; + o_off = prev_o_off; + s_off = prev_s_off; + + if (off == 0) return; + + int start_dim = 0, dims_prod = 1; + for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) + dims_prod *= n(start_dim); + assert(start_dim < prb_.ndims); + off /= step_size; + + for (int d = start_dim; d < prb_.ndims; ++d) { + i_off += is(d); + o_off += os(d); + s_off += ss(d); + + if (off % n(d)) break; + + i_off += - n(d) * is(d); + o_off += - n(d) * os(d); + s_off += - n(d) * ss(d); + off /= n(d); + + if (off == 0) break; /* FIXME: is it really required? 
*/ + } + } + + void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, + int step_size = 1) { + int dummy = 0; + step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, + step_size); + } + + void tr8x8_avx2(int i_off, int o_off) { + for (int i = 0; i < 8; i++) + vmovups(Ymm(i), i_addr(i_off + i * 8)); + + for (int i = 0; i < 8 / 2; i++) { + vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1)); + vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); + } + + const unsigned int lfloat = 0x44; + const unsigned int ufloat = 0xee; + for (int i = 0; i < 8 / 2; i++) { + int j = i % 2 == 0 ? 8 + i : i - 1; + vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); + vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); + } + + const unsigned int lquad = 0x20; + for (int i = 0; i < 8 / 2; i++) + vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad); + + const unsigned int uquad = 0x31; + for (int i = 8 / 2; i < 8; i++) + vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad); + + for (int i = 0; i < 8; i++) + vmovups(o_addr(o_off + i * 8), Ymm(i)); + } + + bool process_unroll_tr8x8(int len) { + bool can_do = true + && mayiuse(avx2) + && prb_.ndims >= 2 + && utils::everyone_is(4, itype_sz, otype_sz) + && utils::everyone_is(8, n(0), n(1)) + && utils::everyone_is(1, os(0), is(1)) + && utils::everyone_is(8, os(1), is(0)) + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + const int step_size = n(0) * n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); + tr8x8_avx2(i_off, o_off); + } + + return true; + } + + template + bool process_direct_copy(int len) { + using namespace data_type; + + using Vmm = typename cpu_isa_traits::Vmm; + const int simd_w = cpu_isa_traits::vlen / itype_sz; + + bool can_do = true + && mayiuse(isa) + && utils::everyone_is(1, os(0), is(0)) + && (false + || prb_.itype == prb_.otype + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32) + ) + && len % simd_w == 0 + && n(0) % len == 0 + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + for (int off = 0; off < len;) { + const int unroll = nstl::min(16, (len - off) / simd_w); + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); + + if (prb_.itype != prb_.otype) { + for (int ur = 0; ur < unroll; ++ur) { + if (prb_.itype == s32 && prb_.otype == f32) + uni_vcvtdq2ps(Vmm(ur), Vmm(ur)); + else if (prb_.itype == f32 && prb_.otype == s32) + uni_vcvtps2dq(Vmm(ur), Vmm(ur)); + else assert(!"unreachable"); + } + } + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur)); + + off += unroll * simd_w; + } + + return true; + } + + void process_unroll_generic_step(int reg_unroll, const int *i_off, + const int *o_off, const int *s_off) { + using namespace data_type; + + auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) { + Xmm dst_pure = Xmm(dst.getIdx()); + switch (idt) { + case f32: + if (src.isMEM() || src.getIdx() != dst.getIdx()) + vmovups(dst, src); + break; + case s32: vcvtdq2ps(dst, src); break; + case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + default: assert(!"unreachable"); + } + }; + + auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { + switch (odt) { + case s32: + if (idt == f32) vcvtps2dq(xmm, xmm); + else if (idt == 
s8) vpmovsxbd(xmm, xmm); + else if (idt == u8) vpmovzxbd(xmm, xmm); + break; + case s8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmovsdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpacksswb(xmm, xmm, xmm_zero); + } + } + if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); + break; + case u8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmaxsd(xmm, xmm, xmm_zero); + vpmovusdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpackuswb(xmm, xmm, xmm_zero); + } + } + if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); + break; + default: assert(!"unreachable"); + } + }; + + auto load = [=](const Xmm &xmm, const Address &addr, int size) { + switch (size) { + case 16: movups(xmm, addr); break; + case 4: movss(xmm, addr); break; + case 1: pinsrb(xmm, addr, 0x0); break; + default: assert(!"unreachable"); + } + }; + + auto store = [=](const Address &addr, const Xmm &xmm, int size) { + switch (size) { + case 16: movups(addr, xmm); break; + case 4: movss(addr, xmm); break; + case 1: pextrb(addr, xmm, 0x0); break; + default: assert(!"unreachable"); + } + }; + + /* check whether loading 4 values at once is possible */ + bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (i_off[ur] != i_off[ur - 1] + 1) + can_load_xmm = false; + const int load_step = can_load_xmm ? 4 : 1; + + /* check whether storing 4 values at once is possible */ + bool can_store_xmm = reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (o_off[ur] != o_off[ur - 1] + 1) + can_store_xmm = false; + const int ur_step = can_store_xmm ? 4 : 1; + + const bool interim_f32 = false + || utils::one_of(f32, prb_.itype, prb_.otype) + || prb_.scale_type != scale_type_t::NONE + || prb_.beta != 0.f; + + if (!can_load_xmm && can_store_xmm) { + assert(ur_step == 4); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + for (int r = 0; r < ur_step; ++r) { + if (itype_sz == 4) + pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r); + else + pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r); + } + } + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); + } + + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); + for (int ur = 0; ur < reg_unroll; ur += cvt_step) + cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); + } + + if (can_load_xmm && !can_store_xmm) { + const bool fast_return = true // transposition on the fly + && prb_.scale_type != scale_type_t::MANY + && prb_.beta == 0.f; + if (fast_return) { + for (int ur = 0; ur < reg_unroll; ur += load_step) { + if (prb_.scale_type == scale_type_t::COMMON) + mulps(Xmm(ur), xmm_scale); + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, + interim_f32 ? 
f32 : prb_.itype); + for (int r = 0; r < load_step; ++r) { + if (otype_sz == 4) + pextrd(o_addr(o_off[ur + r]), Xmm(ur), r); + else + pextrb(o_addr(o_off[ur + r]), Xmm(ur), r); + } + } + return; + } + + /* scatter elements of xmm into 4 xmms */ + if (itype_sz == 4 || interim_f32) { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } + } + + /* scale and beta processing */ + if (can_store_xmm) { + /* xmm <-- scale * xmm[:] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulps(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + enum class scale_load_type_t { bcast, load, gather }; + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + scale_load_type_t scale_load_type = + scale_load_type_t::bcast; // the best case + + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 0) + scale_load_type = scale_load_type_t::load; + + if (scale_load_type == scale_load_type_t::bcast) { + movss(xmm_scale, s_addr(s_off[ur])); + shufps(xmm_scale, xmm_scale, 0x0); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // bcast doesn't work, the next try -- load + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 1) + scale_load_type = scale_load_type_t::gather; + + if (scale_load_type == scale_load_type_t::load) { + movups(xmm_scale, s_addr(s_off[ur])); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // load doesn't work as well + // so gather the scale factors one by one + for (int r = ur; r < ur + ur_step; ++r) + pinsrd(xmm_scale, s_addr(s_off[r]), r - ur); + mulps(Xmm(ur), xmm_scale); + } + } + + /* dst <-- beta * dst + xmm[:] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + /* non VEX instructions do not support unaligned + * memory for instructions other than movups. */ + if (mayiuse(avx)) { + vaddps(Xmm(ur), o_addr(o_off[ur])); + } else { + /* register xmm(1) is unused */ + movups(Xmm(1), o_addr(o_off[ur])); + addps(Xmm(ur), Xmm(1)); + } + } else { + cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); + vaddps(Xmm(ur), Xmm(1)); + } + } + } + } else { + /* xmm[0] <-- scale * xmm[0] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulss(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + mulss(Xmm(ur), s_addr(s_off[ur])); + } + } + + /* dst <-- beta * dst + xmm[0] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + addss(Xmm(ur), o_addr(o_off[ur])); + } else { + if (prb_.otype == s32) { + vmovss(xmm_tmp, o_addr(o_off[ur])); + } else if (utils::one_of(prb_.otype, s8, u8)) { + pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0); + } else { + assert(!"unsupported o_type"); + } + cvt2ps(xmm_tmp, xmm_tmp, prb_.otype); + addps(Xmm(ur), xmm_tmp); + } + } + } + } + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, interim_f32 ? 
f32 : prb_.itype); + store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz); + } + } + + void process_unroll_generic(int len) { + const int blk = 8; + + int i_off[2 * blk] = {0}; + int o_off[2 * blk] = {0}; + int s_off[2 * blk] = {0}; + + int curr = 0; // will switch between 0 and 1 + + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; + + /* compute offsets */ + for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { + const int ur_c = curr * blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur + step(off + ur, + i_off[ur_p], o_off[ur_p], s_off[ur_p], + i_off[ur_c], o_off[ur_c], s_off[ur_c]); + } + + process_unroll_generic_step(reg_unroll, i_off + curr * blk, + o_off + curr * blk, s_off + curr * blk); + + curr = 1 - curr; + } + } + + void loop_begin(Label &l, Reg64 reg_cnt, int len) { + mov(reg_cnt, len); + L(l); + } + + void loop_end(Label &l, Reg64 reg_cnt, int len, + int i_step, int o_step, int s_step) { + add(reg_off_in, i_step * itype_sz); + add(reg_off_out, o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + add(reg_off_scale, s_step * stype_sz); + dec(reg_cnt); + jnz(l); + + sub(reg_off_in, len * i_step * itype_sz); + sub(reg_off_out, len * o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + sub(reg_off_scale, len * s_step * stype_sz); + } + + bool simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; + + const int nfu = d.ndims_full_unroll; + const int ldu = d.len_last_dim_unroll; + const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; + assert(n_jit_loops <= ndims_jit_loop_max); + + xor_(reg_off_in, reg_off_in); + xor_(reg_off_out, reg_off_out); + if (prb_.scale_type == scale_type_t::MANY) + xor_(reg_off_scale, reg_off_scale); + + Label l_loop[3]; + Reg64 reg_cnt[3] = {r15, r14, r13}; + + if (n_jit_loops > 2) + loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); + + if (n_jit_loops > 1) + loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); + + if (n_jit_loops > 0) + loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); + + const bool optimized = false + || process_direct_copy(d.len_unroll) + || process_direct_copy(d.len_unroll) + || process_unroll_tr8x8(d.len_unroll); + if (!optimized) + process_unroll_generic(d.len_unroll); + + if (n_jit_loops > 0) + loop_end(l_loop[0], reg_cnt[0], + n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu, + ss(nfu + 0) * ldu); + + if (n_jit_loops > 1) + loop_end(l_loop[1], reg_cnt[1], + n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1)); + + if (n_jit_loops > 2) + loop_end(l_loop[2], reg_cnt[2], + n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2)); + + return true; + } + + void impl() { + if (simple_impl()) return; + assert(!"no implementation available"); + } + + jit_uni_reorder_kernel_f32(const desc_t &desc) + : kernel_t(desc), jit_generator() { + itype_sz = data_type_size(prb_.itype); + otype_sz = data_type_size(prb_.otype); + stype_sz = sizeof(float); + + preamble(); +# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)] + if (prb_.scale_type == scale_type_t::COMMON) { + auto reg_ptr_scale_tmp = reg_ptr_in; + mov(reg_ptr_scale_tmp, PARAM(scale)); + movups(xmm_scale, ptr[reg_ptr_scale_tmp]); + } else if (prb_.scale_type == scale_type_t::MANY) { + mov(reg_ptr_scale, PARAM(scale)); + } + mov(reg_ptr_in, PARAM(in)); + mov(reg_ptr_out, PARAM(out)); +# undef PARAM + + if (mayiuse(avx)) { + vxorps(xmm_zero, xmm_zero, xmm_zero); + + if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { + 
mov(reg_tmp.cvt32(), 0x7f7f7f7f); + movd(xmm_4x127b, reg_tmp.cvt32()); + } + } + + impl(); + postamble(); + ker_ = (void (*)(const call_param_t *))getCode(); + } + +private: + int itype_sz; + int otype_sz; + int stype_sz; + + Reg64 reg_ptr_in = rsi; + Reg64 reg_ptr_out = rdx; + Reg64 reg_ptr_scale = abi_not_param1; + + Reg64 reg_off_in = r8; + Reg64 reg_off_out = r9; + Reg64 reg_off_scale = r10; + + Reg64 reg_tmp = rax; + + Xmm xmm_scale = xmm15; + Xmm xmm_zero = xmm14; + Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero + Xmm xmm_tmp = xmm12; +}; + +status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb, + int ndims_ker_max) { + desc.prb = prb; + desc.prb.ioff = desc.prb.ooff = 0; + + if (ndims_ker_max > prb.ndims) + return status::invalid_arguments; + + auto ndims_ker_max_f = [&]() { + size_t cur_size = 1; + for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) + if (cur_size >= ker_prb_size_min) return d; + return prb.ndims; + }; + + if (ndims_ker_max <= 0) + ndims_ker_max = ndims_ker_max_f(); + + /* traverse through kernel implementations */ + /* TODO: find a better way to do that... */ + desc.id = 0; + for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { + desc.prb.ndims = ndims_ker; + if (jit_uni_reorder_kernel_f32::applicable(desc.prb)) + return status::success; + } + + return status::unimplemented; +} + +kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { + switch (desc.id) { + case 0: return new jit_uni_reorder_kernel_f32(desc); + default: assert(!"unknown kernel id"); return nullptr; + } + + return nullptr; +} + +} + +static void prb_block_for_cache(tr::prb_t &prb) { + if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) { + /** an attempt to use caches more efficient and + * address the 4K-aliasing issue */ + /* TODO: improve the logic around here */ + int j = 1; + for (; j < prb.ndims && prb.nodes[j].is != 1; ++j); + if (j == prb.ndims) return; + + /* it makes sense to re-prioritize sequential read over + * sequential write if the former would not trash the + * cache, i.e. is == 1 and os % 2^smth != 0. Smth is + * set to 2 at the moment */ + const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1; + if (j == move_to) return; + + if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0) + prb_node_split(prb, j, 16); + + prb_node_move(prb, j, move_to); + DEBUG({ printf("cache: "); prb_dump(prb); }); + } +} + +/** finds the maximum number of dimension the kernel should process and + * optionally splits one of the dimension to achieve better balance between + * parallel driver and the kernel. */ +static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) { + size_t sz_total = 1; + for (int d = 0; d < prb.ndims; ++d) + sz_total *= prb.nodes[d].n; + + /* sz_drv_min is the minimal size for the parallel + * driver required for good parallelization */ + const size_t sz_drv_min = nstl::min( + 16 * mkldnn_get_max_threads(), + utils::div_up(sz_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel + * sz_ker_cur -- product of the dimension processed by a kernel + * sz_drv_cur -- product of the dimension processed by a driver */ + + int kdims = prb.ndims; + size_t sz_drv_cur = 1; + for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) + sz_drv_cur *= prb.nodes[kdims - 1].n; + + size_t sz_ker_cur = 1; + for (int d = 0; d < kdims; ++d) + sz_ker_cur *= prb.nodes[d].n; + + /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. 
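+     * (For illustration: with four nodes of sizes 8, 8, 32 and 64 and
+     * sz_drv_min = 128, the loop above stops at kdims = 2, giving the
+     * driver 32 * 64 = 2048 iterations while each kernel call handles
+     * 8 * 8 = 64 elements.)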
+ * + * It might happen that for chosen kdims the sz_ker_cur is too small + * (less than tr::ker_prb_size_min). In that case try to split the + * innermost driver dimension into two, to increase sz_ker_cur. */ + bool want_borrow_ker_from_drv = true + && kdims < prb.ndims + && sz_ker_cur < tr::ker_prb_size_min + && sz_drv_cur > sz_drv_min; + if (want_borrow_ker_from_drv) { + /* sz_want_borrow is the minimal sz, so that: + * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min + * o) current innermost driver dimension is divisible by + * sz_want_borrow (so that we can evenly split that + * dimension into two) + * + * In the worst case the minimal sz_want_borrow is equal + * to the innermost driver dimension itself. In that case + * we will sacrifice it in favor of kernel (is it fine?). */ + size_t sz_want_borrow + = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); + for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims].n) + prb_node_split(prb, kdims, sz_want_borrow); + kdims += 1; + } + + /* On the other hand it might happen that for chosen kdims + * the sz_drv_cur is too small (less than sz_drv_min). In that case + * try to split the outermost kernel dimension into two, to increase + * sz_drv_cur. */ + bool want_borrow_drv_from_ker = true + && sz_ker_cur > tr::ker_prb_size_min + && sz_drv_cur < sz_drv_min; + if (want_borrow_drv_from_ker) { + size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); + for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims - 1].n) + prb_node_split(prb, kdims - 1, + prb.nodes[kdims - 1].n / sz_want_borrow); + } + + ndims_ker_max = kdims; + + if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { + DEBUG({ printf("split: "); prb_dump(prb); + printf("ndims_ker_max = %d\n", ndims_ker_max); }); + } +} + +struct jit_uni_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + auto prb = tr::prb_t(); + + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + + DEBUG({ printf("init : "); prb_dump(prb); }); + prb_normalize(prb); + DEBUG({ printf("norm : "); prb_dump(prb); }); + prb_simplify(prb); + DEBUG({ printf("smpl : "); prb_dump(prb); }); + + prb_block_for_cache(prb); + + int ndims_ker_max; + prb_thread_kernel_balance(prb, ndims_ker_max); + + tr::kernel_t::desc_t ker_desc; + status_t ker_init_status + = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); + if (ker_init_status != status::success) return ker_init_status; + + const int ndims_driver = prb.ndims - ker_desc.prb.ndims; + if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) + return status::unimplemented; + + DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); }); + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + _pd->prb_ = prb; + _pd->ker_desc_ = ker_desc; + return safe_ptr_assign(*reorder_pd, _pd); + } + + tr::prb_t prb_; + tr::kernel_t::desc_t ker_desc_; + }; + + jit_uni_reorder_t(const pd_t *apd): 
cpu_primitive_t(apd) { + kernel_ = tr::kernel_t::create(pd()->ker_desc_); + assert(kernel_); + } + ~jit_uni_reorder_t() { delete kernel_; } + + void omp_driver_0d(int off, const char *in, char *out, + const float *scale) const { + tr::call_param_t c{in, out, scale}; + (*kernel_)(&c); + } + + void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); + c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is + + d3 * ns[3].is) * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os + + d3 * ns[3].os) * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss + + d3 * ns[3].ss; + (*kernel_)(&c); + }); + } + + void omp_driver(const char *in, char *out, const float *scale) const { + in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); + out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); + + DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); }); + DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); }); + + int ndims = pd()->prb_.ndims; + int ndims_ker = pd()->ker_desc_.prb.ndims; + assert(ndims - ndims_ker <= ndims_driver_max); + + if (ndims - ndims_ker == 0) { + omp_driver_0d(ndims_ker, in, out, scale); + } else { + parallel(0, [&](const int ithr, const int nthr) { + switch (ndims - ndims_ker) { + case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break; + case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break; + case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break; + case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break; + default: assert(!"unimplemented"); + } + }); + } + } + + 
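+    // Illustrative decomposition: for a problem with six nodes where the
+    // kernel covers the innermost four (ndims_ker = 4), the two outermost
+    // nodes are distributed across threads by omp_driver_2d(), and each
+    // thread invokes the JIT kernel once per (d1, d0) pair with the
+    // corresponding input/output/scale offsets.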
virtual status_t execute(const exec_ctx_t &ctx) const override { + auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM); + auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + + omp_driver(in, out, pd()->attr()->output_scales_.scales_); + + return status::success; + } + + enum { ndims_driver_max = 4 }; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + tr::kernel_t *kernel_; +}; + +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr, + src_engine, src_md, dst_engine, dst_md); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp new file mode 100644 index 000000000000..0746ea61d396 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef _JIT_UNI_REORDER_HPP +#define _JIT_UNI_REORDER_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +constexpr int max_ndims = MKLDNN_MAX_NDIMS; + +struct node_t { + size_t n; + ptrdiff_t is; // input stride + ptrdiff_t os; // output stride + ptrdiff_t ss; // scale stride +}; + +enum class scale_type_t { NONE, COMMON, MANY }; + +struct prb_t { + data_type_t itype; + data_type_t otype; + int ndims; + node_t nodes[max_ndims]; + ptrdiff_t ioff; + ptrdiff_t ooff; + scale_type_t scale_type; + float beta; +}; + +status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + +/** sorts the problem nodes so that output strides come in ascending order */ +void prb_normalize(prb_t &p); + +/** folds nodes together if possible */ +void prb_simplify(prb_t &p); + +/** splits the node dim into two of sizes n1 and n / n1 + * @warning n must be multiple of n1 */ +void prb_node_split(prb_t &p, int dim, size_t n1); + +/** swaps d0 and d1 nodes */ +void prb_node_swap(prb_t &p, int d0, int d1); + +/** moves node d0 to the d1 position. 
+ * nodes (d0, d1] are shifted to the left if d0 < d1 or + * to the right if d0 > d1 */ +void prb_node_move(prb_t &p, int d0, int d1); + +/** dumps the problem to stdout */ +void prb_dump(const prb_t &p); + +struct call_param_t { + const void *in; + void *out; + const float *scale; +}; + +struct kernel_t { + struct desc_t { + int id; + prb_t prb; + }; + + kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {} + void operator()(const call_param_t *c) const { assert(ker_); ker_(c); } + virtual ~kernel_t() {} + + /** inits kernel descriptor: + * desc -- kernel descriptor (output) + * prb -- transposition problem (input) + * ndims_ker_max -- limit the maximum number of dimensions kernel + * will process (optional, 0 -- no limitation) */ + static status_t desc_init(desc_t &desc, const prb_t &prb, + int ndims_ker_max = 0); + + /** creates kernel for the problem described in desc */ + static kernel_t *create(const desc_t &desc); + +protected: + const desc_t desc_; + const prb_t &prb_ = desc_.prb; + void (*ker_)(const call_param_t *); +}; + +/* TODO: add trans_t class */ + +} + +/* for cpu reorder list */ +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp new file mode 100644 index 000000000000..69b7a33604ce --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp @@ -0,0 +1,313 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_reorder.hpp" + +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::status; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** ad-hoc structure to describe blocked memory layout */ +struct layout_desc_t { + data_type_t dt; + int ndims; + dims_t id; + dims_t dims; + strides_t strides; +}; + +status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + layout_desc_t &ld) { + const auto md = memory_desc_wrapper(md_); + + bool ok = true + && md.is_blocking_desc() + && md.extra().flags == 0; + if (!ok) return invalid_arguments; + + const auto &bd = md.blocking_desc(); + + ld.ndims = 0; + ld.dt = md.data_type(); + + auto P = [&ld](int id, int dim, ptrdiff_t stride) { + assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); + ld.id[ld.ndims] = id; + ld.dims[ld.ndims] = dim; + ld.strides[ld.ndims] = stride; + ++ld.ndims; + }; + + dims_t blocks; + md.compute_blocks(blocks); + + for (int d = 0; d < md.ndims(); ++d) { + const int ld_ndims_start = ld.ndims; + if (blocks[d] != 1) { + stride_t stride = 1; + for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { + if (bd.inner_idxs[iblk] == d) + P(d, bd.inner_blks[iblk], stride); + stride *= bd.inner_blks[iblk]; + } + } + P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); + + // TODO: NOW: revisit, do we need a reverse? + // TODO: NOW: consider using strides instead of block sizes in md + // reverse the order of dims + for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { + const int idx0 = ld_ndims_start + ld_d; + const int idx1 = ld.ndims - 1 - ld_d; + nstl::swap(ld.dims[idx0], ld.dims[idx1]); + nstl::swap(ld.strides[idx0], ld.strides[idx1]); + } + } + + return success; +} + +status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + const primitive_attr_t *attr) { + auto im_d = memory_desc_wrapper(imd); + auto om_d = memory_desc_wrapper(omd); + + bool ok = true + && im_d.is_blocking_desc() + && om_d.is_blocking_desc() + && !im_d.has_zero_dim() + && !om_d.has_zero_dim(); + if (!ok) + return unimplemented; + + dims_t iblocks, oblocks; + im_d.compute_blocks(iblocks); + om_d.compute_blocks(oblocks); + + /* padding_dim consistency check */ + for (int d = 0; d < im_d.ndims(); ++d) { + const auto pdim = im_d.padded_dims()[d]; + bool ok = true + && pdim == om_d.padded_dims()[d] + && pdim % iblocks[d] == 0 + && pdim % oblocks[d] == 0; + if (!ok) return unimplemented; + } + + layout_desc_t ild, old; + status_t status = cvt_mem_desc_to_layout_desc(imd, ild); + if (status != success) return status; + status = cvt_mem_desc_to_layout_desc(omd, old); + if (status != success) return status; + + p.itype = ild.dt; + p.otype = old.dt; + + p.scale_type = attr->output_scales_.has_default_values() + ? scale_type_t::NONE + : (attr->output_scales_.mask_ == 0 + ? 
scale_type_t::COMMON + : scale_type_t::MANY); + + ptrdiff_t ss[max_ndims] = {0}; + if (p.scale_type == scale_type_t::MANY) { + ptrdiff_t last_ss = 1; + for (int d = old.ndims - 1; d >=0; --d) { + assert((d == 0 || old.id[d - 1] <= old.id[d]) + && "logical dimensions should be in ascending order"); + if (attr->output_scales_.mask_ & (1 << old.id[d])) { + ss[d] = last_ss; + last_ss *= old.dims[d]; + } + } + } + + int ndims = 0; + + int i_pos = 0; /* state for input -- current dimension */ + int o_pos = 0; /* state for output -- current dimension */ + + while (i_pos < ild.ndims && o_pos < old.ndims) { + assert(ild.id[i_pos] == old.id[o_pos]); + if (ild.id[i_pos] != old.id[o_pos]) + return runtime_error; + + assert(ndims < max_ndims); + if (ndims == max_ndims) + return runtime_error; + + if (ild.dims[i_pos] == old.dims[o_pos]) { + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++i_pos; + ++o_pos; + } else if (ild.dims[i_pos] < old.dims[o_pos]) { + assert(old.dims[o_pos] % ild.dims[i_pos] == 0); + int factor = old.dims[o_pos] / ild.dims[i_pos]; + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos] * factor; + p.nodes[ndims].ss = ss[o_pos] * factor; + ++ndims; + ++i_pos; + old.dims[o_pos] = factor; + } else if (ild.dims[i_pos] > old.dims[o_pos]) { + assert(ild.dims[i_pos] % old.dims[o_pos] == 0); + int factor = ild.dims[i_pos] / old.dims[o_pos]; + p.nodes[ndims].n = old.dims[o_pos]; + p.nodes[ndims].is = ild.strides[i_pos] * factor; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++o_pos; + ild.dims[i_pos] = factor; + } + } + p.ndims = ndims; + + dims_t zero_pos = {0}; + p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); + p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); + + const int sum_idx = attr->post_ops_.find(primitive_kind::sum); + p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; + + return success; +} + +void prb_normalize(prb_t &p) { + for (int d = 0; d < p.ndims; ++d) { + int min_pos = d; + for (int j = d + 1; j < p.ndims; ++j) { + bool new_min = false + || p.nodes[j].os < p.nodes[min_pos].os + || (true + && p.nodes[j].os == p.nodes[min_pos].os + && p.nodes[j].n < p.nodes[min_pos].n); + if (new_min) min_pos = j; + } + if (min_pos != d) + nstl::swap(p.nodes[d], p.nodes[min_pos]); + } +} + +void prb_simplify(prb_t &p) { +#if defined(__GNUC__) && __GNUC__ >= 4 +/* GCC produces bogus array subscript is above array bounds warning for + * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. 
*/ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif + for (int d = 0; d < p.ndims - 1; ++d) { + auto &this_node = p.nodes[d + 0]; + auto &next_node = p.nodes[d + 1]; + const bool fold = false + || next_node.n == (size_t)1 // trivial case, just drop next node + || (true // or real folding if possible + && next_node.is == (ptrdiff_t)this_node.n * this_node.is + && next_node.os == (ptrdiff_t)this_node.n * this_node.os + && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); + if (fold) { + this_node.n *= next_node.n; + for (int j = d + 2; j < p.ndims; ++j) + p.nodes[j - 1] = p.nodes[j]; + --p.ndims; + --d; // make another try + } + } +#if defined(__GNUC__) && __GNUC__ >= 4 +#pragma GCC diagnostic pop +#endif +} + +void prb_node_split(prb_t &p, int dim, size_t n1) { + assert(dim < p.ndims); + assert(p.ndims < max_ndims); + assert(p.nodes[dim].n % n1 == 0); + + p.ndims += 1; + + for (int d = p.ndims; d > dim + 1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[dim + 1].n = p.nodes[dim].n / n1; + p.nodes[dim + 1].is = p.nodes[dim].is * n1; + p.nodes[dim + 1].os = p.nodes[dim].os * n1; + p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; + + p.nodes[dim].n = n1; +} + +void prb_node_swap(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + nstl::swap(p.nodes[d0], p.nodes[d1]); +} + +void prb_node_move(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + node_t node = p.nodes[d0]; + + if (d0 < d1) + for (int d = d0; d < d1; ++d) + p.nodes[d] = p.nodes[d + 1]; + else + for (int d = d0; d > d1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[d1] = node; +} + +void prb_dump(const prb_t &p) { + printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype), + mkldnn_dt2str(p.otype), p.ndims); + for (int d = 0; d < p.ndims; ++d) + printf("[%zu:%td:%td:%td]", + p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); + printf(" off:%zu:%zu\n", p.ioff, p.ooff); +} + +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp new file mode 100644 index 000000000000..08747aa89c9c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "utils.hpp" + +#ifndef MKLDNN_ENABLE_JIT_PROFILING +#define MKLDNN_ENABLE_JIT_PROFILING 1 +#endif + +#ifndef MKLDNN_ENABLE_JIT_DUMP +#define MKLDNN_ENABLE_JIT_DUMP 1 +#endif + +#if MKLDNN_ENABLE_JIT_PROFILING +#include "jitprofiling/jitprofiling.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +// WARNING: These functions are not thread safe and must be protected by a +// mutex + +void dump_jit_code(const void *code, size_t code_size, const char *code_name) +{ +#if MKLDNN_ENABLE_JIT_DUMP + if (code && jit_dump_enabled()) { + static int counter = 0; +#define MAX_FNAME_LEN 256 + char fname[MAX_FNAME_LEN + 1]; + // TODO (Roma): support prefix for code / linux perf dumps + snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name, + counter); + counter++; + + FILE *fp = fopen(fname, "w+"); + // Failure to dump code is not fatal + if (fp) { + size_t unused = fwrite(code, code_size, 1, fp); + UNUSED(unused); + fclose(fp); + } + } +#undef MAX_FNAME_LEN +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); +#endif +} + +void register_jit_code_vtune(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ +#if MKLDNN_ENABLE_JIT_PROFILING + if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) { + auto jmethod = iJIT_Method_Load(); + jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe + jmethod.method_name = (char *)code_name; // XXX: dropping const + jmethod.class_file_name = NULL; + jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const + jmethod.method_load_address = (void *)code; + jmethod.method_size = (unsigned int)code_size; + + iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, + (void*)&jmethod); + } +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ + // The #ifdef guards are required to avoid generating a function that only + // consists of lock and unlock code +#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP + static std::mutex m; + std::lock_guard guard(m); + + dump_jit_code(code, code_size, code_name); + register_jit_code_vtune(code, code_size, code_name, source_file_name); +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp new file mode 100644 index 000000000000..2f52dba4acd9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef JIT_SUPPORT_HPP +#define JIT_SUPPORT_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name); + +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD new file mode 100644 index 000000000000..4fd21cea5774 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD @@ -0,0 +1,27 @@ +Copyright (c) 2011, Intel Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md new file mode 100644 index 000000000000..fc67c4f134fe --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md @@ -0,0 +1 @@ +This code is from [Intel SEAPI library](https://github.com/intel/IntelSEAPI) diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h new file mode 100644 index 000000000000..edbf4a15f0fc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h @@ -0,0 +1,595 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef _ITTNOTIFY_CONFIG_H_ +#define _ITTNOTIFY_CONFIG_H_ + +/** @cond exclude_from_documentation */ +#ifndef ITT_OS_WIN +# define ITT_OS_WIN 1 +#endif /* ITT_OS_WIN */ + +#ifndef ITT_OS_LINUX +# define ITT_OS_LINUX 2 +#endif /* ITT_OS_LINUX */ + +#ifndef ITT_OS_MAC +# define ITT_OS_MAC 3 +#endif /* ITT_OS_MAC */ + +#ifndef ITT_OS_FREEBSD +# define ITT_OS_FREEBSD 4 +#endif /* ITT_OS_FREEBSD */ + +#ifndef ITT_OS +# if defined WIN32 || defined _WIN32 +# define ITT_OS ITT_OS_WIN +# elif defined( __APPLE__ ) && defined( __MACH__ ) +# define ITT_OS ITT_OS_MAC +# elif defined( __FreeBSD__ ) +# define ITT_OS ITT_OS_FREEBSD +# else +# define ITT_OS ITT_OS_LINUX +# endif +#endif /* ITT_OS */ + +#ifndef ITT_PLATFORM_WIN +# define ITT_PLATFORM_WIN 1 +#endif /* ITT_PLATFORM_WIN */ + +#ifndef ITT_PLATFORM_POSIX +# define ITT_PLATFORM_POSIX 2 +#endif /* ITT_PLATFORM_POSIX */ + +#ifndef ITT_PLATFORM_MAC +# define ITT_PLATFORM_MAC 3 +#endif /* ITT_PLATFORM_MAC */ + +#ifndef ITT_PLATFORM_FREEBSD +# define ITT_PLATFORM_FREEBSD 4 +#endif /* ITT_PLATFORM_FREEBSD */ + +#ifndef ITT_PLATFORM +# if ITT_OS==ITT_OS_WIN +# define ITT_PLATFORM ITT_PLATFORM_WIN +# elif ITT_OS==ITT_OS_MAC +# define ITT_PLATFORM ITT_PLATFORM_MAC +# elif ITT_OS==ITT_OS_FREEBSD +# define ITT_PLATFORM ITT_PLATFORM_FREEBSD +# else +# define ITT_PLATFORM ITT_PLATFORM_POSIX +# endif +#endif /* ITT_PLATFORM */ + +#if defined(_UNICODE) && !defined(UNICODE) +#define UNICODE +#endif + +#include +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE || _UNICODE */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef ITTAPI_CDECL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define ITTAPI_CDECL __cdecl +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define ITTAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define ITTAPI_CDECL /* actual only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* ITTAPI_CDECL */ + +#ifndef STDCALL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define STDCALL __stdcall +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define STDCALL __attribute__ ((stdcall)) +# else /* _M_IX86 || 
__i386__ */ +# define STDCALL /* supported only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* STDCALL */ + +#define ITTAPI ITTAPI_CDECL +#define LIBITTAPI ITTAPI_CDECL + +/* TODO: Temporary for compatibility! */ +#define ITTAPI_CALL ITTAPI_CDECL +#define LIBITTAPI_CALL ITTAPI_CDECL + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +/* use __forceinline (VC++ specific) */ +#define ITT_INLINE __forceinline +#define ITT_INLINE_ATTRIBUTE /* nothing */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/* + * Generally, functions are not inlined unless optimization is specified. + * For functions declared inline, this attribute inlines the function even + * if no optimization level was specified. + */ +#ifdef __STRICT_ANSI__ +#define ITT_INLINE static +#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) +#else /* __STRICT_ANSI__ */ +#define ITT_INLINE static inline +#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) +#endif /* __STRICT_ANSI__ */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/** @endcond */ + +#ifndef ITT_ARCH_IA32 +# define ITT_ARCH_IA32 1 +#endif /* ITT_ARCH_IA32 */ + +#ifndef ITT_ARCH_IA32E +# define ITT_ARCH_IA32E 2 +#endif /* ITT_ARCH_IA32E */ + +#ifndef ITT_ARCH_ARM +# define ITT_ARCH_ARM 4 +#endif /* ITT_ARCH_ARM */ + +#ifndef ITT_ARCH_PPC64 +# define ITT_ARCH_PPC64 5 +#endif /* ITT_ARCH_PPC64 */ + +#ifndef ITT_ARCH +# if defined _M_IX86 || defined __i386__ +# define ITT_ARCH ITT_ARCH_IA32 +# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__ +# define ITT_ARCH ITT_ARCH_IA32E +# elif defined _M_IA64 || defined __ia64__ +# define ITT_ARCH ITT_ARCH_IA64 +# elif defined _M_ARM || defined __arm__ +# define ITT_ARCH ITT_ARCH_ARM +# elif defined __powerpc64__ +# define ITT_ARCH ITT_ARCH_PPC64 +# endif +#endif + +#ifdef __cplusplus +# define ITT_EXTERN_C extern "C" +# define ITT_EXTERN_C_BEGIN extern "C" { +# define ITT_EXTERN_C_END } +#else +# define ITT_EXTERN_C /* nothing */ +# define ITT_EXTERN_C_BEGIN /* nothing */ +# define ITT_EXTERN_C_END /* nothing */ +#endif /* __cplusplus */ + +#define ITT_TO_STR_AUX(x) #x +#define ITT_TO_STR(x) ITT_TO_STR_AUX(x) + +#define __ITT_BUILD_ASSERT(expr, suffix) do { \ + static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \ + __itt_build_check_##suffix[0] = 0; \ +} while(0) +#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix) +#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__) + +#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 } + +/* Replace with snapshot date YYYYMMDD for promotion build. 
*/ +#define API_VERSION_BUILD 20151119 + +#ifndef API_VERSION_NUM +#define API_VERSION_NUM 0.0.0 +#endif /* API_VERSION_NUM */ + +#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \ + " (" ITT_TO_STR(API_VERSION_BUILD) ")" + +/* OS communication functions */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +typedef HMODULE lib_t; +typedef DWORD TIDT; +typedef CRITICAL_SECTION mutex_t; +#define MUTEX_INITIALIZER { 0 } +#define strong_alias(name, aliasname) /* empty for Windows */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */ +#endif /* _GNU_SOURCE */ +#ifndef __USE_UNIX98 +#define __USE_UNIX98 1 /* need for PTHREAD_MUTEX_RECURSIVE, on SLES11.1 with gcc 4.3.4 wherein pthread.h missing dependency on __USE_XOPEN2K8 */ +#endif /*__USE_UNIX98*/ +#include +typedef void* lib_t; +typedef pthread_t TIDT; +typedef pthread_mutex_t mutex_t; +#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER +#define _strong_alias(name, aliasname) \ + extern __typeof (name) aliasname __attribute__ ((alias (#name))); +#define strong_alias(name, aliasname) _strong_alias(name, aliasname) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_get_proc(lib, name) GetProcAddress(lib, name) +#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex) +#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex) +#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex) +#define __itt_load_lib(name) LoadLibraryA(name) +#define __itt_unload_lib(handle) FreeLibrary(handle) +#define __itt_system_error() (int)GetLastError() +#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2) +#define __itt_fstrnlen(s, l) strnlen_s(s, l) +#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l) +#define __itt_fstrdup(s) _strdup(s) +#define __itt_thread_id() GetCurrentThreadId() +#define __itt_thread_yield() SwitchToThread() +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return InterlockedIncrement(ptr); +} +#endif /* ITT_SIMPLE_INIT */ + +#define DL_SYMBOLS (1) +#define PTHREAD_SYMBOLS (1) + +#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */ +#define __itt_get_proc(lib, name) dlsym(lib, name) +#define __itt_mutex_init(mutex) {\ + pthread_mutexattr_t mutex_attr; \ + int error_code = pthread_mutexattr_init(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \ + error_code); \ + error_code = pthread_mutexattr_settype(&mutex_attr, \ + PTHREAD_MUTEX_RECURSIVE); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \ + error_code); \ + error_code = pthread_mutex_init(mutex, &mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutex_init", \ + error_code); \ + error_code = pthread_mutexattr_destroy(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \ + error_code); \ +} +#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex) +#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex) +#define __itt_load_lib(name) dlopen(name, RTLD_LAZY) +#define __itt_unload_lib(handle) dlclose(handle) +#define __itt_system_error() errno +#define __itt_fstrcmp(s1, s2) strcmp(s1, s2) + +/* makes customer code define safe APIs for 
SDL_STRNLEN_S and SDL_STRNCPY_S */ +#ifdef SDL_STRNLEN_S +#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l) +#else +#define __itt_fstrnlen(s, l) strlen(s) +#endif /* SDL_STRNLEN_S */ +#ifdef SDL_STRNCPY_S +#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l) +#else +#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l) +#endif /* SDL_STRNCPY_S */ + +#define __itt_fstrdup(s) strdup(s) +#define __itt_thread_id() pthread_self() +#define __itt_thread_yield() sched_yield() +#if ITT_ARCH==ITT_ARCH_IA64 +#ifdef __INTEL_COMPILER +#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val) +#else /* __INTEL_COMPILER */ +/* TODO: Add Support for not Intel compilers for IA-64 architecture */ +#endif /* __INTEL_COMPILER */ +#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */ +ITT_INLINE long +__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend) +{ + long result; + __asm__ __volatile__("lock\nxadd %0,%1" + : "=r"(result),"=m"(*(int*)ptr) + : "0"(addend), "m"(*(int*)ptr) + : "memory"); + return result; +} +#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64 +#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val) +#endif /* ITT_ARCH==ITT_ARCH_IA64 */ +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return __TBB_machine_fetchadd4(ptr, 1) + 1L; +} +#endif /* ITT_SIMPLE_INIT */ + +void* dlopen(const char*, int) __attribute__((weak)); +void* dlsym(void*, const char*) __attribute__((weak)); +int dlclose(void*) __attribute__((weak)); +#define DL_SYMBOLS (dlopen && dlsym && dlclose) + +int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak)); +int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak)); +pthread_t pthread_self(void) __attribute__((weak)); +#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self) + +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +typedef enum { + __itt_collection_normal = 0, + __itt_collection_paused = 1 +} __itt_collection_state; + +typedef enum { + __itt_thread_normal = 0, + __itt_thread_ignored = 1 +} __itt_thread_state; + +#pragma pack(push, 8) + +typedef struct ___itt_thread_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. 
*/ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + TIDT tid; + __itt_thread_state state; /*!< Thread state (paused or normal) */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_thread_info* next; +} __itt_thread_info; + +#include "ittnotify_types.h" /* For __itt_group_id definition */ + +typedef struct ___itt_api_info_20101001 +{ + const char* name; + void** func_ptr; + void* init_func; + __itt_group_id group; +} __itt_api_info_20101001; + +typedef struct ___itt_api_info +{ + const char* name; + void** func_ptr; + void* init_func; + void* null_func; + __itt_group_id group; +} __itt_api_info; + +typedef struct __itt_counter_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + const char* domainA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* domainW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* domainW; +#endif /* UNICODE || _UNICODE */ + int type; + long index; + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct __itt_counter_info* next; +} __itt_counter_info_t; + +struct ___itt_domain; +struct ___itt_string_handle; + +typedef struct ___itt_global +{ + unsigned char magic[8]; + unsigned long version_major; + unsigned long version_minor; + unsigned long version_build; + volatile long api_initialized; + volatile long mutex_initialized; + volatile long atomic_counter; + mutex_t mutex; + lib_t lib; + void* error_handler; + const char** dll_path_ptr; + __itt_api_info* api_list_ptr; + struct ___itt_global* next; + /* Joinable structures below */ + __itt_thread_info* thread_list; + struct ___itt_domain* domain_list; + struct ___itt_string_handle* string_list; + __itt_collection_state state; + __itt_counter_info_t* counter_list; +} __itt_global; + +#pragma pack(pop) + +#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = NULL; \ + h->nameW = n ? _wcsdup(n) : NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = n ? __itt_fstrdup(n) : NULL; \ + h->nameW = NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = NULL; \ + h->nameW = name ? 
_wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = NULL; \ + h->strW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = name ? __itt_fstrdup(name) : NULL; \ + h->strW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->domainA = NULL; \ + h->domainW = name ? _wcsdup(domain) : NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->domainA = domain ? __itt_fstrdup(domain) : NULL; \ + h->domainW = NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#endif /* _ITTNOTIFY_CONFIG_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h new file mode 100644 index 000000000000..99fbc240547f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h @@ -0,0 +1,94 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. 
+ * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _ITTNOTIFY_TYPES_H_ +#define _ITTNOTIFY_TYPES_H_ + +typedef enum ___itt_group_id +{ + __itt_group_none = 0, + __itt_group_legacy = 1<<0, + __itt_group_control = 1<<1, + __itt_group_thread = 1<<2, + __itt_group_mark = 1<<3, + __itt_group_sync = 1<<4, + __itt_group_fsync = 1<<5, + __itt_group_jit = 1<<6, + __itt_group_model = 1<<7, + __itt_group_splitter_min = 1<<7, + __itt_group_counter = 1<<8, + __itt_group_frame = 1<<9, + __itt_group_stitch = 1<<10, + __itt_group_heap = 1<<11, + __itt_group_splitter_max = 1<<12, + __itt_group_structure = 1<<12, + __itt_group_suppress = 1<<13, + __itt_group_arrays = 1<<14, + __itt_group_all = -1 +} __itt_group_id; + +#pragma pack(push, 8) + +typedef struct ___itt_group_list +{ + __itt_group_id id; + const char* name; +} __itt_group_list; + +#pragma pack(pop) + +#define ITT_GROUP_LIST(varname) \ + static __itt_group_list varname[] = { \ + { __itt_group_all, "all" }, \ + { __itt_group_control, "control" }, \ + { __itt_group_thread, "thread" }, \ + { __itt_group_mark, "mark" }, \ + { __itt_group_sync, "sync" }, \ + { __itt_group_fsync, "fsync" }, \ + { __itt_group_jit, "jit" }, \ + { __itt_group_model, "model" }, \ + { __itt_group_counter, "counter" }, \ + { __itt_group_frame, "frame" }, \ + { __itt_group_stitch, "stitch" }, \ + { __itt_group_heap, "heap" }, \ + { __itt_group_structure, "structure" }, \ + { __itt_group_suppress, "suppress" }, \ + { __itt_group_arrays, "arrays" }, \ + { __itt_group_none, NULL } \ + } + +#endif /* _ITTNOTIFY_TYPES_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c new file mode 100644 index 000000000000..15f4b9929baa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c @@ -0,0 +1,293 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. 
+ * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "ittnotify_config.h" + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD +#include +#endif +#include + +#include "jitprofiling.h" + +static const char rcsid[] = "\n@(#) $Revision: 471937 $\n"; + +#define DLL_ENVIRONMENT_VAR "VS_PROFILER" + +#ifndef NEW_DLL_ENVIRONMENT_VAR +#if ITT_ARCH==ITT_ARCH_IA32 +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32" +#else +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64" +#endif +#endif /* NEW_DLL_ENVIRONMENT_VAR */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define DEFAULT_DLLNAME "JitPI.dll" +HINSTANCE m_libHandle = NULL; +#elif ITT_PLATFORM==ITT_PLATFORM_MAC +#define DEFAULT_DLLNAME "libJitPI.dylib" +void* m_libHandle = NULL; +#else +#define DEFAULT_DLLNAME "libJitPI.so" +void* m_libHandle = NULL; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/* default location of JIT profiling agent on Android */ +#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so" + +/* the function pointers */ +typedef unsigned int(JITAPI *TPInitialize)(void); +static TPInitialize FUNC_Initialize=NULL; + +typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*); +static TPNotify FUNC_NotifyEvent=NULL; + +static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING; + +/* end collector dll part. */ + +/* loadiJIT_Funcs() : this function is called just in the beginning + * and is responsible to load the functions from BistroJavaCollector.dll + * result: + * on success: the functions loads, iJIT_DLL_is_missing=0, return value = 1 + * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0 + */ +static int loadiJIT_Funcs(void); + +/* global representing whether the collector can't be loaded */ +static int iJIT_DLL_is_missing = 0; + +ITT_EXTERN_C int JITAPI +iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData) +{ + int ReturnValue = 0; + + /* initialization part - the collector has not been loaded yet. 
*/ + if (!FUNC_NotifyEvent) + { + if (iJIT_DLL_is_missing) + return 0; + + if (!loadiJIT_Funcs()) + return 0; + } + + if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED || + event_type == iJVM_EVENT_TYPE_METHOD_UPDATE) + { + if (((piJIT_Method_Load)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2) + { + if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3) + { + if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED) + { + if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 || + ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0) + return 0; + } + + ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData); + + return ReturnValue; +} + +ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive() +{ + if (!iJIT_DLL_is_missing) + { + loadiJIT_Funcs(); + } + + return executionMode; +} + +/* This function loads the collector dll and the relevant functions. + * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1 + * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0 + */ +static int loadiJIT_Funcs() +{ + static int bDllWasLoaded = 0; + char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + DWORD dNameLength = 0; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if(bDllWasLoaded) + { + /* dll was already loaded, no need to do it for the second time */ + return 1; + } + + /* Assumes that the DLL will not be found */ + iJIT_DLL_is_missing = 1; + FUNC_NotifyEvent = NULL; + + if (m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FreeLibrary(m_libHandle); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dlclose(m_libHandle); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = NULL; + } + + /* Try to get the dll name from the environment */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryExA(dllName, + NULL, LOAD_WITH_ALTERED_SEARCH_PATH); + } + free(dllName); + } + } else { + /* Try to use old VS_PROFILER variable */ + dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryA(dllName); + } + free(dllName); + } + } + } +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dllName = getenv(NEW_DLL_ENVIRONMENT_VAR); + if (!dllName) + dllName = getenv(DLL_ENVIRONMENT_VAR); +#if defined(__ANDROID__) || defined(ANDROID) + if (!dllName) + dllName = ANDROID_JIT_AGENT_PATH; +#endif + if (dllName) + { + /* Try to load the dll from the PATH... 
*/ + m_libHandle = dlopen(dllName, RTLD_LAZY); + } +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if (!m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + m_libHandle = LoadLibraryA(DEFAULT_DLLNAME); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + } + + /* if the dll wasn't loaded - exit. */ + if (!m_libHandle) + { + iJIT_DLL_is_missing = 1; /* don't try to initialize + * JIT agent the second time + */ + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_NotifyEvent) + { + FUNC_Initialize = NULL; + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_Initialize) + { + FUNC_NotifyEvent = NULL; + return 0; + } + + executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize(); + + bDllWasLoaded = 1; + iJIT_DLL_is_missing = 0; /* DLL is ok. */ + + return 1; +} + +ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID() +{ + static unsigned int methodID = 1; + + if (methodID == 0) + return 0; /* ERROR : this is not a valid value */ + + return methodID++; +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h new file mode 100644 index 000000000000..bf0489b1a188 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h @@ -0,0 +1,673 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef __JITPROFILING_H__ +#define __JITPROFILING_H__ + +/** + * @brief JIT Profiling APIs + * + * The JIT Profiling API is used to report information about just-in-time + * generated code that can be used by performance tools. The user inserts + * calls in the code generator to report information before JIT-compiled + * code goes to execution. This information is collected at runtime and used + * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics + * associated with JIT-compiled code. + * + * These APIs can be used to\n + * - **Profile trace-based and method-based JIT-compiled + * code**. Some examples of environments that you can profile with these APIs: + * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM) + * software technology, Java/.NET managed execution environments, and custom + * ISV JIT engines. + * @code + * #include + * + * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) { + * return; + * } + * + * iJIT_Method_Load jmethod = {0}; + * jmethod.method_id = iJIT_GetNewMethodID(); + * jmethod.method_name = "method_name"; + * jmethod.class_file_name = "class_name"; + * jmethod.source_file_name = "source_file_name"; + * jmethod.method_load_address = code_addr; + * jmethod.method_size = code_size; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL); + * @endcode + * + * * Expected behavior: + * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and its + * memory region is treated as unloaded. VTune Amplifier displays the metrics + * collected by the method until it is overwritten. + * * If supplied line number information contains multiple source lines for + * the same assembly instruction (code location), then VTune Amplifier picks up + * the first line number. + * * Dynamically generated code can be associated with a module name. + * Use the iJIT_Method_Load_V2 structure.\n + * Clarification of some cases: + * * If you register a function with the same method ID multiple times, + * specifying different module names, then the VTune Amplifier picks up + * the module name registered first. If you want to distinguish the same + * function between different JIT engines, supply different method IDs for + * each function. Other symbolic information (for example, source file) + * can be identical. + * + * - **Analyze split functions** (multiple joint or disjoint code regions + * belonging to the same function) **including re-JIT** + * with potential overlapping of code regions in time, which is common in + * resource-limited environments. + * @code + * #include + * + * unsigned int method_id = iJIT_GetNewMethodID(); + * + * iJIT_Method_Load a = {0}; + * a.method_id = method_id; + * a.method_load_address = 0x100; + * a.method_size = 0x20; + * + * iJIT_Method_Load b = {0}; + * b.method_id = method_id; + * b.method_load_address = 0x200; + * b.method_size = 0x30; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b); + * @endcode + * + * * Expected behaviour: + * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and + * its memory region is treated as unloaded. + * * All code regions reported with the same method ID are considered as + * belonging to the same method. 
Symbolic information (method name, + * source file name) will be taken from the first notification, and all + * subsequent notifications with the same method ID will be processed + * only for line number table information. So, the VTune Amplifier will map + * samples to a source line using the line number table from the current + * notification while taking the source file name from the very first one.\n + * Clarification of some cases:\n + * * If you register a second code region with a different source file + * name and the same method ID, then this information will be saved and + * will not be considered as an extension of the first code region, but + * VTune Amplifier will use the source file of the first code region and map + * performance metrics incorrectly. + * * If you register a second code region with the same source file as + * for the first region and the same method ID, then the source file will be + * discarded but VTune Amplifier will map metrics to the source file correctly. + * * If you register a second code region with a null source file and + * the same method ID, then provided line number info will be associated + * with the source file of the first code region. + * + * - **Explore inline functions** including multi-level hierarchy of + * nested inline methods which shows how performance metrics are distributed through them. + * @code + * #include + * + * // method_id parent_id + * // [-- c --] 3000 2000 + * // [---- d -----] 2001 1000 + * // [---- b ----] 2000 1000 + * // [------------ a ----------------] 1000 n/a + * + * iJIT_Method_Load a = {0}; + * a.method_id = 1000; + * + * iJIT_Method_Inline_Load b = {0}; + * b.method_id = 2000; + * b.parent_method_id = 1000; + * + * iJIT_Method_Inline_Load c = {0}; + * c.method_id = 3000; + * c.parent_method_id = 2000; + * + * iJIT_Method_Inline_Load d = {0}; + * d.method_id = 2001; + * d.parent_method_id = 1000; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d); + * @endcode + * + * * Requirements: + * * Each inline (iJIT_Method_Inline_Load) method should be associated + * with two method IDs: one for itself; one for its immediate parent. + * * Address regions of inline methods of the same parent method cannot + * overlap each other. + * * Execution of the parent method must not be started until it and all + * its inline methods are reported. + * * Expected behaviour: + * * In case of nested inline methods an order of + * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important. + * * If any event overwrites either inline method or top parent method, + * then the parent, including inline methods, becomes invalid and its memory + * region is treated as unloaded. + * + * **Life time of allocated data**\n + * The client sends an event notification to the agent with event-specific + * data, which is a structure. The pointers in the structure refer to memory + * allocated by the client, which responsible for releasing it. The pointers are + * used by the iJIT_NotifyEvent method to copy client's data in a trace file, + * and they are not used after the iJIT_NotifyEvent method returns. 
+ */ + +/** + * @defgroup jitapi JIT Profiling + * @ingroup internal + * @{ + */ + +/** + * @brief Enumerator for the types of notifications + */ +typedef enum iJIT_jvm_event +{ + iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent. + * Use NULL for event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load as event + * data. */ +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic + * code is being unloaded from memory. + * Use iJIT_Method_Load as event data.*/ +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for + * a previously reported dynamic code. + * The previous content will be invalidated + * starting from the time of the notification. + * Use iJIT_Method_Load as event data but + * required fields are following: + * - method_id identify the code to update. + * - method_load_address specify start address + * within identified code range + * where update should be started. + * - method_size specify length of updated code + * range. */ + + + iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic + * code is JIT compiled and loaded + * into memory by the JIT engine, + * but before the parent code region + * starts executing. + * Use iJIT_Method_Inline_Load as event data.*/ + +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UPDATE_V2, +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V2 as event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V3 as event data. */ +} iJIT_JVM_EVENT; + +/** + * @brief Enumerator for the agent's mode + */ +typedef enum _iJIT_IsProfilingActiveFlags +{ + iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running; + * iJIT_NotifyEvent calls will + * not be processed. */ + iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and + * ready to process notifications. */ +} iJIT_IsProfilingActiveFlags; + +/** + * @brief Description of a single entry in the line number information of a code region. + * @details A table of line number entries gives information about how the reported code region + * is mapped to source file. + * Intel(R) VTune(TM) Amplifier uses line number information to attribute + * the samples (virtual address) to a line number. \n + * It is acceptable to report different code addresses for the same source line: + * @code + * Offset LineNumber + * 1 2 + * 12 4 + * 15 2 + * 18 1 + * 21 30 + * + * VTune Amplifier constructs the following table using the client data + * + * Code subrange Line number + * 0-1 2 + * 1-12 4 + * 12-15 2 + * 15-18 1 + * 18-21 30 + * @endcode + */ +typedef struct _LineNumberInfo +{ + unsigned int Offset; /**<\brief Offset from the begining of the code region. */ + unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */ + +} *pLineNumberInfo, LineNumberInfo; + +/** + * @brief Enumerator for the code architecture. 
+ */ +typedef enum _iJIT_CodeArchitecture +{ + iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */ + + iJIT_CA_32, /**<\brief 32-bit machine code. */ + + iJIT_CA_64 /**<\brief 64-bit machine code. */ + +} iJIT_CodeArchitecture; + +#pragma pack(push, 8) + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, data provided with + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table.0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + unsigned int class_id; /**<\brief This field is obsolete. */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Load, iJIT_Method_Load; + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load_V2 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V2 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. 
*/ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + The module name can be useful for distinguishing among + different JIT engines. VTune Amplifier will display + reported methods grouped by specific module. */ + +} *piJIT_Method_Load_V2, iJIT_Method_Load_V2; + +/** + * @brief Description of a JIT-compiled method + * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2 + * with a newly introduced 'arch' field that specifies architecture of the code region. + * When you use the iJIT_Method_Load_V3 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V3 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise they are + * treated as regions of different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Cannot be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + * The module name can be useful for distinguishing among + * different JIT engines. VTune Amplifier will display + * reported methods grouped by specific module. */ + + iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region. + * By default, it is the same as the process + * architecture that is calling it. + * For example, you can use it if your 32-bit JIT + * engine generates 64-bit code. + * + * If JIT engine reports both 32-bit and 64-bit types + * of methods then VTune Amplifier splits the methods + * with the same module name but with different + * architectures in two different modules. VTune Amplifier + * modifies the original name provided with a 64-bit method + * version by ending it with '(64)' */ + +} *piJIT_Method_Load_V3, iJIT_Method_Load_V3; + +/** + * @brief Description of an inline JIT-compiled method + * @details When you use the_iJIT_Method_Inline_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Inline_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. 
+ * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID. + * Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /** <\brief The virtual address on which the method + * is inlined. If NULL, then data provided with + * the event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load; + +/** @cond exclude_from_documentation */ +/** + * @brief Description of a segment type + * @details Use the segment type to specify a type of data supplied + * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to + * a certain code trace. + */ +typedef enum _iJIT_SegmentType +{ + iJIT_CT_UNKNOWN = 0, + + iJIT_CT_CODE, /**<\brief Executable code. */ + + iJIT_CT_DATA, /**<\brief Data (not executable code). + * VTune Amplifier uses the format string + * (see iJIT_Method_Update) to represent + * this data in the VTune Amplifier GUI */ + + iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace. + * Can be used for the following + * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events, + * if the type of the previously reported segment + * type is the same. */ + iJIT_CT_EOF +} iJIT_SegmentType; + +/** + * @brief Description of a dynamic update of the content within JIT-compiled method + * @details The JIT engine may generate the methods that are updated at runtime + * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update + * structure to describe the update of the content within a JIT-compiled method, + * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it. + * + * On the first Update event, VTune Amplifier copies the original code range reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and + * adds the modified range to the original method. For next update events, VTune Amplifier + * does the same but it uses the latest modified version of a code region for update. + * Eventually, VTune Amplifier GUI displays multiple code ranges for the method reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event. + * Notes: + * - Multiple update events with different types for the same trace are allowed + * but they must be reported for the same code ranges. 
+ * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [code] Ignored + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + * - The types of previously reported events can be changed but they must be reported + * for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + */ + +typedef struct _iJIT_Method_Update +{ + void* load_address; /**<\brief Start address of the update within a method */ + + unsigned int size; /**<\brief The update size */ + + iJIT_SegmentType type; /**<\brief Type of the update */ + + const char* data_format; /**<\brief C string that contains a format string + * that follows the same specifications as format in printf. + * The format string is used for iJIT_CT_CODE only + * and cannot be NULL. + * Format can be changed on the fly. */ +} *piJIT_Method_Update, iJIT_Method_Update; + +/** @endcond */ + +#pragma pack(pop) + +/** @cond exclude_from_documentation */ +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifndef JITAPI_CDECL +# if defined WIN32 || defined _WIN32 +# define JITAPI_CDECL __cdecl +# else /* defined WIN32 || defined _WIN32 */ +# if defined _M_IX86 || defined __i386__ +# define JITAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define JITAPI_CDECL /* actual only on x86_64 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* defined WIN32 || defined _WIN32 */ +#endif /* JITAPI_CDECL */ + +#define JITAPI JITAPI_CDECL +/** @endcond */ + +/** + * @brief Generates a new unique method ID. + * + * You must use this API to obtain unique and valid method IDs for methods or + * traces reported to the agent if you don't have your own mechanism to generate + * unique method IDs. + * + * @return a new unique method ID. When out of unique method IDs, this API + * returns 0, which is not an accepted value. + */ +unsigned int JITAPI iJIT_GetNewMethodID(void); + +/** + * @brief Returns the current mode of the agent. + * + * @return iJIT_SAMPLING_ON, indicating that agent is running, or + * iJIT_NOTHING_RUNNING if no agent is running. + */ +iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void); + +/** + * @brief Reports infomation about JIT-compiled code to the agent. + * + * The reported information is used to attribute samples obtained from any + * Intel(R) VTune(TM) Amplifier collector. This API needs to be called + * after JIT compilation and before the first entry into the JIT-compiled + * code. + * + * @param[in] event_type - type of the data sent to the agent + * @param[in] EventSpecificData - pointer to event-specific data + * + * @returns 1 on success, otherwise 0. + */ +int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +/** @endcond */ + +/** @} jitapi group */ + +#endif /* __JITPROFILING_H__ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp new file mode 100644 index 000000000000..ef4c42bacf00 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp @@ -0,0 +1,317 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nchw_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void nchw_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t ws_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[ws_offset] = value; + } else + reinterpret_cast(ws)[ws_offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + auto s = src[src_offset]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? 
KD*KW*KH + : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + d[0] += src[src_offset]; + } + } + } + + d[0] = math::out_round((float)d[0] / num_summands); + }; + + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = nstl::numeric_limits::lowest(); + set_ws(mb, c, od, oh, ow, 0); + ker_max(d, mb, c, od, oh, ow); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = 0; + ker_avg(d, mb, c, od, oh, ow); + }); + } +} + +template +void nchw_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int mb, int c) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW; + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_offset++] = 0; + } + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto b_c = ws_d.blocking_desc().inner_nblks == 0 + ? 1 : ws_d.blocking_desc().inner_blks[0]; + auto ws_offset = is_3d + ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c + : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c; + + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. 
+ if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + size_t diff_src_offset = + (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + size_t num_summands = (alg == pooling_avg_include_padding) + ? (size_t)KW*KH*KD + : (size_t)(id_end - id_start)*(ih_end - ih_start) + *(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0] / num_summands; + } + } + } + }; + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_max(d, mb, c, od, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_avg(d, mb, c, od, oh, ow); + } + } + } + }); + } +} + +template struct nchw_pooling_fwd_t; +template struct nchw_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp new file mode 100644 index 000000000000..bbdd04f6b924 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp @@ -0,0 +1,147 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_NCHW_POOLING_HPP +#define CPU_NCHW_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct nchw_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct nchw_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? 
format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + bool ws_ok = true + && hint_fwd_pd_ + && hint_fwd_pd_->workspace_md(); + if (!ws_ok) + return status::unimplemented; + + const auto &ws_blk = + hint_fwd_pd_->workspace_md()->format_desc.blocking; + ws_ok = ws_ok + && ws_blk.inner_nblks < 1 + && IMPLICATION(ws_blk.inner_nblks == 1, + ws_blk.inner_idxs[0] == 1); + if (!ws_ok) + return status::unimplemented; + + ws_md_ = *hint_fwd_pd_->workspace_md(); + } + + return status::success; + } + }; + + nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp new file mode 100644 index 000000000000..c0e93fefe4d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp @@ -0,0 +1,382 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "ncsp_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void ncsp_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool calculate_stats = !pd()->stats_is_src(); + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = scratchpad.get(key_bnorm_tmp_mean); + variance = scratchpad.get(key_bnorm_tmp_var); + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t N = pd()->MB(); + dim_t C = pd()->C(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. 
+ if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off; + + if (calculate_stats) { + data_t *mean_blk = mean + C_off; + data_t *variance_blk = variance + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = (c + C_off) * SP; + data_t sum = 0; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + sum += src[off + n * C * SP + sp]; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + mean_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + mean_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + mean_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sum = 0.; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + data_t m = src[off * SP + n * C * SP + sp] + - mean[off]; + sum += m * m; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + variance_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + variance_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + variance_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + } + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sqrt_variance + = static_cast(sqrtf(variance[off] + eps)); + data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? 
scaleshift[C + off] : 0; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + size_t d_off = off * SP + n * C * SP + sp; + data_t bn_res + = sm * (src[d_off] - mean[off]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void ncsp_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + if (diff_scaleshift == nullptr) + diff_scaleshift = scratchpad.get(key_bnorm_tmp_diff_ss); + + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t C = pd()->C(), N = pd()->MB(); + const bool use_scaleshift = pd()->use_scaleshift(); + const float eps = pd()->desc()->batch_norm_epsilon; + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = 2 * N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. 
+ if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off; + + data_t *diff_gamma_blk = diff_scaleshift + C_off; + data_t *diff_beta_blk = diff_scaleshift + C + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t diff_gamma = 0.0, diff_beta = 0.0; + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + diff_gamma += (src[d_off] - v_mean) * dd; + diff_beta += dd; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = diff_gamma; + ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter + + SP_N_ithr * C_blks_per_iter + c] = diff_beta; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + data_t sqrt_variance = static_cast( + 1.0f / sqrtf(variance[c + C_off] + eps)); + diff_gamma_blk[c] = 0.; + diff_beta_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) { + diff_gamma_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + diff_beta_blk[c] += ws_reduce[ws_iter_off + + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter + + c]; + } + diff_gamma_blk[c] *= sqrt_variance; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t gamma = use_scaleshift ? scaleshift[off] : 1; + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[off] + eps)); + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 
0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_blk[c] / (SP * N) + + (src[d_off] - v_mean) * diff_gamma_blk[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp new file mode 100644 index 000000000000..97ca3b003f69 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp @@ -0,0 +1,160 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP +#define CPU_NCSP_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + using namespace format_tag; + + bool ok = true + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * C() * mkldnn_get_max_threads()); + + if (!is_training()) { + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C()); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C()); + } + } + } + }; + + typedef typename prec_traits::type data_t; + + ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + 
using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace format_tag; + + bool ok = true + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward)) + scratchpad.book(key_bnorm_tmp_diff_ss, + sizeof(data_t) * 2 * C()); + } + }; + + typedef typename prec_traits::type data_t; + + ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp new file mode 100644 index 000000000000..38cfb28dce30 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp @@ -0,0 +1,392 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nhwc_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define MEM_D(name) name##_d + +#define DECLARE_READ_STRIDES(name) \ + const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \ + const size_t name##_d_stride = (!is_3d) \ + ? 0 \ + : MEM_D(name).blocking_desc().strides[2]; \ + const size_t name##_h_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[2] \ + : MEM_D(name).blocking_desc().strides[3]; \ + const size_t name##_w_stride = (!is_3d) \ + ? 
MEM_D(name).blocking_desc().strides[3] \ + : MEM_D(name).blocking_desc().strides[4]; + +namespace nhwc_pooling { + size_t strided_offset(const int _n, const size_t _sn, + const int _d, const size_t _sd, + const int _h, const size_t _sh, + const int _w, const size_t _sw) + { + return _n * _sn + + _d * _sd + + _h * _sh + + _w * _sw; + } +} + +template +void nhwc_pooling_fwd_t::array_div_by_const(const int n, + const data_t *src, const size_t num, data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + float ftmp = (float)src[i]; + ftmp = ftmp / num; + dst[i] = math::out_round(ftmp); + } +} + +template +void nhwc_pooling_fwd_t::array_add(const int n, const data_t *src, + data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + dst[i] += src[i]; + } +} + +template +void nhwc_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + using namespace nhwc_pooling; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper MEM_D(src)(pd()->src_md()); + const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + DECLARE_READ_STRIDES(src); + DECLARE_READ_STRIDES(dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? 
index - offset : 0; + }; + + parallel_nd(MB, OD, OH, OW, + [&](int mb, int od, int oh, int ow) { + size_t dst_offset_init = strided_offset(mb, dst_n_stride, + od, dst_d_stride, + oh, dst_h_stride, + ow, dst_w_stride); + if (alg == pooling_max) { + size_t ws_offset_init = 0; + if (ws) + { + DECLARE_READ_STRIDES(ws); + ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + } + // Note: GCC 4.8.5 won't vectorize below + // simple loops unless they are singled out + // into separate helper routines: + // array_nhwc_initialize, array_nhwc_max + if (!ws) + array_nhwc_initialize(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + else + array_nhwc_initialize(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + + + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) + continue; + if (ih < 0 || ih >= IH) + continue; + if (iw < 0 || iw >= IW) + continue; + + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + + if (!ws) + array_nhwc_max(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + else + array_nhwc_max(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + } + } else { + // pooling_avg + auto d = dst + dst_offset_init; + + utils::array_set(d, 0, OC); + + auto id_start = apply_offset(od * SD, padF); + auto ih_start = apply_offset(oh * SH, padT); + auto iw_start = apply_offset(ow * SW, padL); + auto id_end = nstl::min(od * SD - padF + KD, ID); + auto ih_end = nstl::min(oh * SH - padT + KH, IH); + auto iw_end = nstl::min(ow * SW - padL + KW, IW); + + // it is cheaper to actually count this in a loop + // as the typical kernel is small + size_t num_summands = 0; + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + auto s = src + src_offset_init; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_add(OC, s, d); + + num_summands++; + } + + num_summands = (alg == pooling_avg_include_padding) ? 
+ KW * KH * KD : num_summands; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_div_by_const(OC, d, num_summands, d); + } + }); +} + +template +void nhwc_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace nhwc_pooling; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); + const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int OC = pd()->C(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + auto alg = pd()->desc()->alg_kind; + + DECLARE_READ_STRIDES(diff_src); + DECLARE_READ_STRIDES(diff_dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + const int MB = pd()->MB(); + + parallel_nd(MB, ID, IH, IW, + [&](int mb, int id, int ih, int iw) { + size_t src_offset_init = strided_offset(mb, diff_src_n_stride, + id, diff_src_d_stride, + ih, diff_src_h_stride, + iw, diff_src_w_stride); + + // check if kernel windows are disjoint, in this case there's no + // update needed and we just write there once, no initialization + // required. + if (!(KD == SD && KH == SH && KW == SW)) + for (int oc = 0; oc < OC; ++oc) + diff_src[src_offset_init + oc] = data_type_t(0); + + // Find out which output cells may correspond to current + // input position. Current input postition divided by + // stride, with integer divide rounding down, is the + // right-most output. + // Left-most output may be computed if we decrement input + // by (kernel_size - 1) and then do the same division by + // stride. + int od_left = nstl::max((id + padF - KD + 1) / SD, 0); + int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0); + int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0); + // Notice +1 here to preserve the C loop "less than" + // condition for continuing the for loop. 
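+        // Worked example with illustrative numbers (SD = 2, KD = 3, padF = 0,
+        // id = 5, OD >= 3): od_left = max((5 - 3 + 1) / 2, 0) = 1 and
+        // od_right = min(5 / 2 + 1, OD) = 3, so od iterates over {1, 2}.
+        // For od = 1 the recomputed kd = 5 - 1*2 + 0 = 3 falls outside
+        // [0, KD) and is skipped by the bounds check below; only od = 2
+        // (kd = 1) actually contributes to the gradient.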
+ int od_right = nstl::min((id + padF) / SD + 1 , OD); + int oh_right = nstl::min((ih + padT) / SH + 1 , OH); + int ow_right = nstl::min((iw + padL) / SW + 1 , OW); + + for (int od = od_left; od < od_right; ++od) + for (int oh = oh_left; oh < oh_right; ++oh) + for (int ow = ow_left; ow < ow_right; ++ow) { + const int kd = id - od*SD + padF; + const int kh = ih - oh*SH + padT; + const int kw = iw - ow*SW + padL; + + if (kd < 0 || kd >= KD) + continue; + if (kh < 0 || kh >= KH) + continue; + if (kw < 0 || kw >= KW) + continue; + + size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, + od, diff_dst_d_stride, + oh, diff_dst_h_stride, + ow, diff_dst_w_stride); + + if (alg == pooling_max) { + DECLARE_READ_STRIDES(ws); + size_t ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + const int index = kd * KH * KW + kh * KW + kw; + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const int index_from_ws = + (MEM_D(ws).data_type() == data_type::u8) + ? (int)ws[ws_offset_init + oc] + : ((int *)ws)[ws_offset_init + oc]; + + const data_t d = diff_dst[dst_offset_init + oc]; + + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += + (index_from_ws == index) + ? d + : data_type_t(0); + else + diff_src[src_offset_init + oc] = + (index_from_ws == index) + ? d + : data_type_t(0); + } + } else { + // pooling_avg + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) + ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const data_t d = diff_dst[dst_offset_init + oc]; + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += d / num_summands; + else + diff_src[src_offset_init + oc] = d / num_summands; + } + } + } + }); +} + +template struct nhwc_pooling_fwd_t; +template struct nhwc_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp new file mode 100644 index 000000000000..7e33b6869fba --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp @@ -0,0 +1,210 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_NHWC_POOLING_HPP +#define CPU_NHWC_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace nhwc_pooling { +size_t strided_offset(const int _n, const size_t _sn, const int _d, + const size_t _sd, const int _h, const size_t _sh, const int _w, + const size_t _sw); +} + +template +struct nhwc_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + void array_div_by_const(const int n, const data_t *src, const size_t num, + data_t *dst) const; + void array_add(const int n, const data_t *src, data_t *dst) const; + + template + void array_nhwc_max(const int n, data_t *dst, const data_t *src, + unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, + const int index) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < n; ++oc) { + auto s = src[oc]; + data_t mv = dst[oc]; + + // update index of maximum +#if defined __INTEL_COMPILER + if ((use_workspace) && (s > mv)) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + ws[ws_offset + oc] = index; + } else + reinterpret_cast(ws)[ws_offset + oc] = index; + } +#else + // Need to add explicit predicates for GCC to vectorize this. + // And although the resulting code is ugly, it is still 4 times + // faster than scalar + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + unsigned char predicate = (s > mv) ? 0xff : 0; + unsigned char current_value = ws[ws_offset + oc]; + current_value = (predicate & (unsigned char)index) + | ((~predicate) & current_value); + ws[ws_offset + oc] = current_value; + } else { + auto wint = reinterpret_cast(ws); + unsigned int predicate = (s > mv) ? 
0xffffffff : 0; + unsigned int current_value = wint[ws_offset + oc]; + current_value = (predicate & (unsigned int)index) + | ((~predicate) & current_value); + wint[ws_offset + oc] = current_value; + } + } +#endif + // update maximum + dst[oc] = nstl::max(s, mv); + } + } + + template + void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws, + const size_t ws_offset, const data_type_t ws_dt) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + for (int oc = 0; oc < n; ++oc) { + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + ws[ws_offset + oc] = 0; + } else + reinterpret_cast(ws)[ws_offset + oc] = 0; + } + dst[oc] = nstl::numeric_limits::lowest(); + } + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct nhwc_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +}// namespace cpu +}// namespace impl +}// namespace mkldnn + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp new file mode 100644 index 000000000000..e20333e66fe0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp @@ -0,0 +1,288 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "nspc_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void nspc_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + const bool with_relu = pd()->with_relu_post_op(); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_mean = scratchpad.get(key_bnorm_tmp_mean); + auto tmp_var = scratchpad.get(key_bnorm_tmp_var); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = tmp_mean; + variance = tmp_var; + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->H() * pd()->W() * pd()->D(); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 
0 : res; }; + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr; + data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr; + + if (calculate_stats) { + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] = 0.; + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] += src[(size_t)n * SP * C + + sp * C + c]; + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + mean[c] = 0; + for (dim_t n = 0; n < nthr; n++) + mean[c] += ws_reduce[C * n + c]; + mean[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + mean_loc[c] = mean[c]; + ws_reduce[C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) { + data_t m = src[(size_t)n * SP * C + sp * C + c] + - mean_loc[c]; + ws_reduce[C * ithr + c] += m * m; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + variance[c] = 0; + for (dim_t n = 0; n < nthr; n++) + variance[c] += ws_reduce[C * n + c]; + variance[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) + variance_loc[c] = variance[c]; + } else { + variance_loc = variance; + mean_loc = mean; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + data_t sqrt_variance = static_cast( + sqrtf(variance_loc[c] + eps)); + data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? 
scaleshift[C + c] : 0; + size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void nspc_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_diff_ss = scratchpad.get(key_bnorm_tmp_diff_ss); + + if (diff_scaleshift == nullptr) + diff_scaleshift = tmp_diff_ss; + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->D() * pd()->H() * pd()->W(); + data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + + data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr; + data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr); + + for (dim_t c = 0; c < C; c++) { + ws_reduce[C * ithr + c] = 0.; + ws_reduce[C * nthr + C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd; + ws_reduce[C * nthr + C * ithr + c] += dd; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[c] + eps)); + diff_gamma[c] = 0; + diff_beta[c] = 0; + for (dim_t n = 0; n < nthr; n++) { + diff_gamma[c] += ws_reduce[C * n + c]; + diff_beta[c] += ws_reduce[C * nthr + C * n + c]; + } + diff_gamma[c] *= sqrt_variance; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + diff_gamma_loc[c] = diff_gamma[c]; + diff_beta_loc[c] = diff_beta[c]; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t gamma = use_scaleshift ? scaleshift[c] : 1; + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[c] + eps)); + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 
0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_loc[c] / (SP * N) + + (src[d_off] - mean[c]) * diff_gamma_loc[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp new file mode 100644 index 000000000000..aad86b05a772 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp @@ -0,0 +1,169 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP +#define CPU_NSPC_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct nspc_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + dim_t sz = nstl::max(C(), 16) * mkldnn_get_max_threads(); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz); + } + } + }; + + typedef typename prec_traits::type data_t; + + nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + 
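+    // Forward NSPC (nhwc) kernel: per-thread partial sums for mean and
+    // variance are combined through the key_bnorm_reduction scratchpad,
+    // with mkldnn_thr_barrier() between the reduction phases (see
+    // nspc_batch_normalization.cpp).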
void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct nspc_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C() + * (mkldnn_get_max_threads() + 1)); + } + }; + + typedef typename prec_traits::type data_t; + + nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp new file mode 100644 index 000000000000..d79b1a034ba1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp @@ -0,0 +1,265 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "ref_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + /* fast return */ + if (this->pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN); + auto variance = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift();; + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op = [&](float res) { + return (with_relu && res < 0.0f) ? 0.0f : res; + }; + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + float v_mean = calculate_stats ? 0 : mean[c]; + float v_variance = calculate_stats ? 0 : variance[c]; + + if (calculate_stats) { + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) + v_mean += src[data_offset(data_d, n, c, d, h, w)]; + v_mean /= W*N*H*D; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean; + v_variance += m*m; + } + v_variance /= W*H*N*D; + } + + float sqrt_variance = sqrtf(v_variance + eps); + float sm = (use_scaleshift + ? scaleshift[scaleshift_d.off(0, c)] + : 1.0f) / sqrt_variance; + float sv = use_scaleshift ? 
scaleshift[scaleshift_d.off(1, c)] : 0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + auto d_off = data_offset(data_d,n,c,d,h,w); + float bn_res = sm * ((float)src[d_off] - v_mean) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + if (data_type == data_type::s8) { + dst[d_off] = qz_a1b0()(maybe_post_op(bn_res)); + } else { + dst[d_off] = static_cast(maybe_post_op(bn_res)); + } + } + + if (calculate_stats) { + if (save_stats) { + mean[c] = v_mean; + variance[c] = v_variance; + } + } + }); +} + +template struct ref_batch_normalization_fwd_t; +template struct ref_batch_normalization_fwd_t; + +template +void ref_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md()); + + const dim_t C = pd()->C(); + + /* fast return */ + if (this->pd()->has_zero_dim_memory()) { + if (diff_scaleshift) { + for (dim_t c = 0; c < C; ++c) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0; + } + } + return; + } + + const dim_t N = pd()->MB(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + data_t v_mean = mean[c]; + data_t v_variance = variance[c]; + data_t sqrt_variance = static_cast(1.0f / sqrtf(v_variance + eps)); + data_t gamma = use_scaleshift ? 
scaleshift[scaleshift_d.off(0, c)] : 1; + data_t diff_gamma = data_t(0); + data_t diff_beta = data_t(0); + diff_gamma = 0.0; + diff_beta = 0.0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + diff_gamma += (src[s_off] - v_mean) * dd; + diff_beta += dd; + } + diff_gamma *= sqrt_variance; + + if (diff_scaleshift) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta; + } + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w); + data_t dd = diff_dst[dd_off]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + data_t v_diff_src = dd; + if (calculate_diff_stats) { + v_diff_src -= diff_beta/(D*W*H*N) + + (src[s_off] - v_mean) * + diff_gamma*sqrt_variance/(D*W*H*N); + } + v_diff_src *= gamma*sqrt_variance; + diff_src[dd_off] = v_diff_src; + } + }); +} + +template struct ref_batch_normalization_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp new file mode 100644 index 000000000000..aa9f74125a69 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_REF_BATCH_NORMALIZATION_HPP +#define CPU_REF_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && IMPLICATION(use_scaleshift(), + weights_md()->data_type == data_type::f32) + && (attr()->has_default_values() || with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (src_md()->data_type == data_type::s8 && !stats_is_src()) + return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + return status::success; + } + }; + + ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t); + + status_t init() { + bool ok = true + && is_bwd() + && utils::everyone_is(data_type, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type, + weights_md()->data_type, + diff_weights_md()->data_type)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp new file mode 100644 index 000000000000..4c534b550820 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_CONCAT_HPP +#define REF_CONCAT_HPP + +#include "reorder_pd.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_CONCAT_PD_T("ref:any", ref_concat_t); + + status_t init() { + bool ok = cpu_concat_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + const primitive_attr_t attr; /* alpha == 1. */ + reorder_pd_t *r_pd = nullptr; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, src_image_md(i)) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector reorder_pds_; + }; + + ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_concat_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp new file mode 100644 index 000000000000..c0a979c4cf44 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp @@ -0,0 +1,395 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "type_helpers.hpp" + +#include "ref_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template +void ref_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool with_relu = 0; // TODO: change if support post_ops + const float nslope = 0.f; + + const int ndims = pd()->desc()->src_desc.ndims; + + auto ker = [=](int g, int mb, int oc, int od, int oh, + int ow) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * KSD - padFront + kd * (1 + KDD); + const int ih = oh * KSH - padT + kh * (1 + KDH); + const int iw = ow * KSW - padL + kw * (1 + KDW); + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + if (ndims == 5) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + + } + return d; + }; + + parallel_nd(G, MB, OC, OD, OH, OW, + [&](int g, int mb, int oc, int od, int oh, int ow) { + float a = bias + ? 
get_bias(bias, bias_d.off(g * OC + oc), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, oc, od, oh, ow); + if (with_relu && a < 0) + a = a * nslope; + if (ndims == 5) + dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate(a); + else if (ndims == 4) + dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate(a); + else if (ndims == 3) + dst[dst_d.off(mb, g*OC + oc, ow)] = saturate(a); + else + assert(false); + }); +} + +template +void ref_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->diff_src_desc.ndims; + + auto ker = [=](int g, int mb, int ic, int id, int ih, + int iw) { + acc_data_t d = 0; + for (int oc = 0; oc < OC; ++oc) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + if (iw + padL < kw * (1 + KDW) + || ih + padT < kh * (1 + KDH) + || id + padFront < kd * (1 + KDD)) + continue; + int ow = iw - kw * (1 + KDW) + padL; + int oh = ih - kh * (1 + KDH) + padT; + int od = id - kd * (1 + KDD) + padFront; + if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) + continue; + + ow /= KSW; + oh /= KSH; + od /= KSD; + + if (od < OD && oh < OH && ow < OW) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, od, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + } + } + return d; + }; + + parallel_nd(G, MB, IC, ID, IH, IW, + [&](int g, int mb, int ic, int id, int ih, int iw) { + auto ds_idx = (ndims == 5) + ? diff_src_d.off(mb, g*IC + ic, id, ih, iw) + : (ndims == 4) + ? diff_src_d.off(mb, g*IC + ic, ih, iw) + : diff_src_d.off(mb, g*IC + ic, iw); + float a = bias + ? 
get_bias(bias, bias_d.off(g * IC + ic), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, ic, id, ih, iw); + diff_src[ds_idx] = saturate(a); + }); +} + +template +void ref_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->src_desc.ndims; + +auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ow*KSW + kw * (1 + KDW) < padL + || oh*KSH + kh * (1 + KDH) < padT + || od*KSD + kd * (1 + KDD) < padFront + || ow*KSW + kw * (1 + KDW) >= IW + padL + || oh*KSH + kh * (1 + KDH) >= IH + padT + || od*KSD + kd * (1 + KDD) >= ID + padFront) + continue; + + int id = od*KSD - padFront + kd * (1 + KDD); + int ih = oh*KSH - padT + kh * (1 + KDH); + int iw = ow*KSW - padL + kw * (1 + KDW); + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, + oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)] + * src[src_d.off(mb, g*IC + ic, ih, iw)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)] + * src[src_d.off(mb, g*IC + ic, iw)]; + else + assert(false); + } + }; + + auto ker_bias = [=](acc_data_t &d, int g, int oc) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh, + ow)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, + ow)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]; + else + assert(false); + } + }; + + parallel_nd(G, OC, [&](int g, int oc) { + if (diff_bias) { + // XXX: loss of precision when bias is a float... 
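+            // db is accumulated in acc_data_t over the whole MB*OD*OH*OW
+            // reduction and only converted (saturated) to the diff-weights
+            // data type once, when it is stored into diff_bias below.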
+ acc_data_t db = 0; + ker_bias(db, g, oc); + diff_bias[diff_bias_d.off(g*OC+oc)] + = saturate(db); + } + + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t dw = 0; + ker(dw, g, oc, ic, kd, kh, kw); + + if (ndims == 5) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kd, kh, kw) + : diff_weights_d.off(oc, ic, kd, kh, kw); + diff_weights[idx] = saturate(dw); + } else if (ndims == 4) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kh, kw) + : diff_weights_d.off(oc, ic, kh, kw); + diff_weights[idx] = saturate(dw); + } else if (ndims == 3) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kw) + : diff_weights_d.off(oc, ic, kw); + diff_weights[idx] = saturate(dw); + } else { + assert(false); + } + } + }); +} + +using namespace data_type; + +template struct ref_convolution_fwd_t; + +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; + +template struct ref_convolution_bwd_data_t; + +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; + +template struct ref_convolution_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp new file mode 100644 index 000000000000..7c83d0c6d4fc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_CONVOLUTION_HPP +#define CPU_REF_CONVOLUTION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, data_type::undef, + dst_type, acc_type) + && IMPLICATION(with_bias(), true + && IMPLICATION(src_type == u8, + utils::one_of(bias_md_.data_type, f32, s32, s8, u8)) + && IMPLICATION(src_type == f32, + bias_md_.data_type == f32)) + && set_default_formats() + && attr()->has_default_values(); + return ok ? 
status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; + } + + virtual bool support_bias() const override { return true; } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_wei_type, diff_wei_type, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? 
utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type diff_wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp new file mode 100644 index 000000000000..541a303aabee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp @@ -0,0 +1,199 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_deconvolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OD = pd()->OD(); + const int OC = pd()->OC() / G; + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(MB, G, OC, OD, OH, OW, + [&](int mb, int g, int oc, int od, int oh, int ow) { + auto b = bias[g * OC + oc]; + switch (ndims) { + case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break; + case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break; + case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break; + default: assert(!"invalid dimension size"); + } + }); +} + +void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW()*pd()->OH()*pd()->OD(); + + parallel_nd(MB, OC, [&](int mb, int oc) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + dst[offset] += bias[oc]; + } + }); +} + +template +void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW() * pd()->OH() * pd()->OD(); 
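+    // In the blocked nC[d]hwXc layout each group of blksize channels is
+    // stored contiguously per spatial point, so the bias for one channel
+    // block can be added with a single SIMD loop; blk below handles the
+    // tail block when OC is not a multiple of blksize.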
+ + const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0]; + + parallel_nd(MB, utils::div_up(OC, blksize), SP, + [&](int mb, int oc_blk, int sp) { + int oc = oc_blk * blksize; + auto offset = mb * stride_mb + oc * SP + sp * blksize; + const int blk = nstl::min(blksize, OC - oc); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + dst[offset + i] += bias[oc + i]; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst, + data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OC = pd()->OC() / G; + const int OD = pd()->OD(); + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(G, OC, [&](int g, int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + switch (ndims) { + case 5: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, od, oh, ow)]; + break; + case 4: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, oh, ow)]; + break; + case 3: + db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)]; + break; + default: assert(!"invalid dimension size"); + } + } + } + } + } + diff_bias[g * OC + oc] = db; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH()*pd()->OW()*pd()->OD(); + + parallel_nd(OC, [&](int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + db += diff_dst[offset]; + } + } + diff_bias[oc] = db; + }); +} + +template +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH() * pd()->OW() * pd()->OD(); + + const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0]; + + parallel_nd(utils::div_up(OC, blksize), [&](int ocb) { + data_t db[blksize] = {0}; + + for (int mb = 0; mb < MB; ++mb) { + for (int sp = 0; sp < SP; ++sp) { + auto offset = mb * stride_mb + (ocb * SP + sp) * blksize; + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; ++i) + db[i] += diff_dst[offset+i]; + } + } + + const int blk = nstl::min(blksize, OC - ocb * blksize); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + diff_bias[ocb * blksize + i] = db[i]; + }); +} + +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp new file mode 100644 index 000000000000..d61903c32d6f --- /dev/null +++ 
b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp @@ -0,0 +1,502 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_DECONVOLUTION_HPP +#define CPU_REF_DECONVOLUTION_HPP + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static status_t compute_blocked_format(bool with_groups, + const memory_desc_t *oi_md, memory_desc_t *io_md) +{ + /* Computes blocking for *i*o* format from *o*i* format */ + + bool sanity_check_ok = true + && oi_md->ndims == io_md->ndims + && oi_md->format_kind == format_kind::blocked; + if (!sanity_check_ok) return status::invalid_arguments; + + const blocking_desc_t &oi_blk = oi_md->format_desc.blocking; + blocking_desc_t io_blk = io_md->format_desc.blocking; + + io_md->format_kind = format_kind::blocked; + io_blk = oi_blk; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]); + for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) { + if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) { + io_blk.inner_idxs[i_blk] = + (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC); + } + } + + return memory_desc_init_by_blocking_desc(*io_md, io_blk); +} + +static status_t conv_descr_create(const deconvolution_desc_t *dd, + convolution_desc_t *cd) +{ + using namespace prop_kind; + alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct + ? 
alg_kind::convolution_direct : alg_kind::convolution_winograd; + + const memory_desc_t *src_md, *dst_md, *d_weights_d; + prop_kind_t prop_kind; + memory_desc_t c_weights_d; + if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { + prop_kind = backward_data; + src_md = &dd->dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->weights_desc; + } else if (dd->prop_kind == backward_data) { + prop_kind = forward_training; + src_md = &dd->diff_dst_desc; + dst_md = &dd->diff_src_desc; + d_weights_d = &dd->weights_desc; + } else { + prop_kind = dd->prop_kind; + src_md = &dd->diff_dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->diff_weights_desc; + } + + const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; + + /* create weights desc for convolution */ + c_weights_d = *d_weights_d; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]); + nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]); + nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]); + + if (c_weights_d.format_kind != format_kind::any) + CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d)); + + return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, + prop_kind != backward_weights ? &dd->bias_desc : nullptr, + dst_md, dd->strides, dd->dilates, + dd->padding[0], dd->padding[1], dd->padding_kind); +} + +struct ref_deconvolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , conv_supports_bias_(other.conv_supports_bias_) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + CHECK(conv_descr_create(desc(), &cd)); + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + conv_supports_bias_ = + static_cast(conv_pd_) + ->support_bias(); + bool output_f32 = utils::everyone_is(data_type::f32, + desc()->accum_data_type, desc()->dst_desc.data_type); + + bool ok = true + && conv_pd_->weights_md()->extra.flags == 0 + /* deconv reference code can process only f32 bias */ + && IMPLICATION(with_bias(), + conv_supports_bias_ || output_f32); + if (ok) return status::success; + + delete conv_pd_; + } + conv_pd_ = nullptr; + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->post_ops_.has_default_values(); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (dst_md_.format_kind == format_kind::any) + dst_md_ = *conv_pd_->diff_src_md(); + if (bias_md_.format_kind == format_kind::any) + 
CHECK(memory_desc_init_by_tag(bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + bool conv_supports_bias_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_fwd_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + if (pd()->with_bias() && pd()->conv_supports_bias_) + conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS); + conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + + if (pd()->with_bias() && !pd()->conv_supports_bias_) { + using namespace format_tag; + + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_fwd_bias_ncdhw(bias, dst); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_fwd_bias_nCdhwXc<8>(bias, dst); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_fwd_bias_nCdhwXc<16>(bias, dst); + break; + default: + compute_fwd_bias(bias, dst); + break; + } + } + return status::success; + } + +private: + void compute_fwd_bias(const data_t *bias, data_t *dst) const; + void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const; + template void compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_data_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_data_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + + return status::unimplemented; + } + + status_t init() { + using namespace data_type; + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && utils::everyone_is(data_type::f32, + desc()->diff_src_desc.data_type, + desc()->weights_desc.data_type, + 
desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (diff_src_md_.format_kind == format_kind::any) + diff_src_md_ = *conv_pd_->dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_data_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_weights_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->diff_weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type::f32, + desc()->src_desc.data_type, + desc()->diff_weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->has_default_values(); + if (ok) { + CHECK(init_convolution()); + if (diff_weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->diff_weights_md(), + &desc_.diff_weights_desc)); + diff_weights_md_ = desc_.diff_weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + 
diff_dst_md_ = *conv_pd_->src_md(); + if (diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_weights_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + status_t status = conv_p_->execute(conv_ctx); + if (status != status::success) return status; + + if (pd()->with_bias()) { + using namespace format_tag; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_bwd_bias_ncdhw(diff_dst, diff_bias); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias); + break; + default: + compute_bwd_bias(diff_dst, diff_bias); + break; + } + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const; + void compute_bwd_bias_ncdhw(const data_t *diff_dst, + data_t *diff_bias) const; + template void compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const; + + primitive_t *conv_p_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp new file mode 100644 index 000000000000..7beee8d32370 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp @@ -0,0 +1,297 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace alg_kind; +using namespace math; + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, + float beta): alg_(alg), alpha_(alpha), beta_(beta) { + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); +} + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( + const post_ops_t::entry_t::eltwise_t &eltwise) + : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {} + +float ref_eltwise_scalar_fwd_t::compute_scalar(float s) { + switch (alg_) { + case eltwise_relu: return relu_fwd(s, alpha_); + case eltwise_tanh: return tanh_fwd(s); + case eltwise_elu: return elu_fwd(s, alpha_); + case eltwise_square: return square_fwd(s); + case eltwise_abs: return abs_fwd(s); + case eltwise_sqrt: return sqrt_fwd(s); + case eltwise_linear: return linear_fwd(s, alpha_, beta_); + case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_); + case eltwise_soft_relu: return soft_relu_fwd(s); + case eltwise_logistic: return logistic_fwd(s); + default: assert(!"unknown eltwise alg_kind"); + } + + return 0.f; +} + +template +void ref_eltwise_fwd_t::execute_forward_nCspBc_padded( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + const blocking_desc_t &blk = data_d.blocking_desc(); + const int block = blk.inner_blks[0]; + + const int MB = pd()->MB(); + const int C = pd()->C() / block; + const int C_PADDED = data_d.padded_dims()[1] / block; + const int tail = pd()->C() % block; + const int SP = pd()->D() * pd()->H() * pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + auto ker = [=] (data_t &d, data_t s) { + switch (alg_kind) { + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }; + + // FIXME: integer overflow? 
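+    // The FIXME above likely refers to d_off in the loop below: it is computed
+    // in plain int as (n*C_PADDED*SP + c*SP + sp) * block, so once
+    // MB * C_PADDED * SP * block exceeds INT_MAX the offset wraps around;
+    // a wider index type (e.g. ptrdiff_t) would be needed for such tensors.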
+ + parallel_nd(MB, C_PADDED, SP, + [&](int n, int c, int sp) { + auto d_off = (n*C_PADDED*SP + c*SP + sp) * block; + if (c < C) { + for (int v = 0; v < block; v++) + ker(dst[d_off + v], src[d_off + v]); + } else { + for (int v = 0; v < tail; v++) + ker(dst[d_off + v], src[d_off + v]); + } + }); +} + +template +void ref_eltwise_fwd_t::execute_forward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int id, int h, int w) { + auto d_off = is_3d + ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w); + data_t s = src[d_off]; + data_t &d = dst[d_off]; + switch (alg_kind) { + case eltwise_relu: d = relu_fwd(s, alpha); break; + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_fwd_t::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const ptrdiff_t nelems = static_cast(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + dst += data_d.offset0(); + + if (alg_kind == eltwise_relu) { + // a fast path for relu as the most popular activation + parallel_nd(nelems, [&](ptrdiff_t e) { + dst[e] = relu_fwd(src[e], alpha); + }); + return; + } + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t s = src[e]; + data_t &d = dst[e]; + + switch (alg_kind) { + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_bwd_t::execute_backward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper 
diff_data_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int d, int h, int w) { + auto data_off = is_3d + ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w); + auto diff_data_off = is_3d + ? diff_data_d.off(n, c, d, h, w) + : diff_data_d.off(n, c, h, w); + data_t s = src[data_off]; + data_t dd = diff_dst[diff_data_off]; + data_t &ds = diff_src[diff_data_off]; + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: + ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: + ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_bwd_t::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const ptrdiff_t nelems = static_cast(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t dd = diff_dst[e]; + const data_t s = src[e]; + data_t &ds = diff_src[e]; + + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; + +template struct ref_eltwise_bwd_t; +template struct ref_eltwise_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp new file mode 100644 index 000000000000..8f4ab354135c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp @@ -0,0 +1,168 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache 
License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_ELTWISE_HPP +#define CPU_REF_ELTWISE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_eltwise_scalar_fwd_t { +public: + ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta); + + // note that eltwise.scale is ignored + ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise); + + float compute_scalar(float s); + + const alg_kind_t alg_; + const float alpha_; + const float beta_; +}; + +template +struct ref_eltwise_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t); + + status_t init() { + using namespace utils; + + auto src_d = memory_desc_wrapper(src_md()); + + use_dense_ = false + || src_d.is_dense() + || (src_d.is_dense(true) && is_zero_preserved()); + + use_nCspBc_padded_ = !use_dense_ + && src_d.blocking_desc().inner_nblks == 1 + && one_of(src_d.blocking_desc().inner_blks[0], 8, 16) + && src_d.blocking_desc().inner_idxs[0] == 1 + && src_d.only_padded_dim(1) + && src_d.is_dense(true); + + if (has_zero_dim_memory()) + use_dense_ = use_nCspBc_padded_ = false; + + const bool use_generic = !use_dense_ && !use_nCspBc_padded_; + + bool ok = true + && is_fwd() + && everyone_is(data_type, desc()->data_desc.data_type) + && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + + bool use_dense_, use_nCspBc_padded_; + }; + + ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_forward_dense(ctx); + else if (pd()->use_nCspBc_padded_) + execute_forward_nCspBc_padded(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const; + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_eltwise_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && !is_fwd() + && everyone_is(data_type, + desc()->data_desc.data_type, + desc()->diff_data_desc.data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + auto diff_dst_d = memory_desc_wrapper(diff_dst_md()); + const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md()); + 
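+            // The dense path below walks a single flat index over all (padded)
+            // elements, so it is only taken when diff_dst and src share the same
+            // format, the data is dense and the algorithm preserves zeros in the
+            // padded area; otherwise the generic per-coordinate kernel is used,
+            // which only supports 4D and 5D tensors.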
+ use_dense_ = true + && same_fmt_ + && diff_dst_d.is_dense(true) + && is_zero_preserved() + && !has_zero_dim_memory(); + const bool use_generic = !use_dense_; + + if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5)) + return status::unimplemented; + + return status::success; + } + + bool use_dense_; + }; + + ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp new file mode 100644 index 000000000000..c807a9ffd066 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp @@ -0,0 +1,285 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template +void ref_inner_product_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5); + const int ndims = src_d.ndims() - 2; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? 
post_ops.entry_[0].eltwise.alpha : 0.f; + + auto ker_has_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int ic = 0; ic < IC; ++ic) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + switch (ndims) { + case 3: + d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)] + * weights[weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)] + * weights[weights_d.off(oc, ic, kh, kw)]; + break; + case 1: + d += (acc_data_t)src[src_d.off(mb, ic, kw)] + * weights[weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + return d; + }; + + auto ker_no_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) { + d += (acc_data_t)src[src_d.off(mb, ic)] + * weights[weights_d.off(oc, ic)]; + } + return d; + }; + + parallel_nd(MB, OC, [&](int mb, int oc) { + float a = bias + ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type) + : 0; + if (src_has_spatial) + a += ker_has_spatial(mb, oc); + else + a += ker_no_spatial(mb, oc); + if (do_relu && a < (acc_data_t)0) + a *= nslope; + dst[dst_d.off(mb, oc)] = saturate(a); + }); +} + +using namespace data_type; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; + +template +void ref_inner_product_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool diff_src_has_spatial + = utils::one_of(diff_src_d.ndims(), 3, 4, 5); + const int ndims = diff_src_d.ndims() - 2; + + parallel_nd(MB, IC, [&](int mb, int ic) { + if (diff_src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + switch (ndims) { + case 3: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kd, kh, kw)]); + break; + case 2: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kh, kw)]); + break; + case 1: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kw)]); + break; + default: assert(!"unsupported ndims size"); + } + } + switch (ndims) { + case 3: + diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] + = (diff_src_data_t)ds; + break; + case 2: + diff_src[diff_src_d.off(mb, ic, kh, kw)] + = (diff_src_data_t)ds; + break; + case 1: + diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds; + break; + default: assert(!"unsupported ndims size"); + } + } + } else { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] * + weights[weights_d.off(oc, ic)]); + } + 
diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds; + } + }); +} + +template struct ref_inner_product_bwd_data_t; + +template +void ref_inner_product_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5); + const int ndims = src_d.ndims() - 2; + + parallel_nd(OC, IC, [&](int oc, int ic) { + if (src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + data_t *dw(nullptr); + switch (ndims) { + case 3: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kh, kw)]; + break; + case 1: + dw = &diff_weights[diff_weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + switch (ndims) { + case 3: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kd, kh, kw)]; + break; + case 2: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kh, kw)]; + break; + case 1: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + } else { + data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)]; + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + *dw += diff_dst[diff_dst_d.off(mb, oc)] * + src[src_d.off(mb, ic)]; + } + } + }); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + + parallel_nd(OC, [&](int oc) { + data_t *db = &diff_bias[oc]; + *db = data_t(0); + for (int mb = 0; mb < MB; ++mb) + *db += diff_dst[diff_dst_d.off(mb, oc)]; + }); + } +} + +template struct ref_inner_product_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp new file mode 100644 index 000000000000..bf87dbd51459 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_REF_INNER_PRODUCT_HPP +#define CPU_REF_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && src_md()->data_type == src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && dst_md()->data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && diff_src_md()->data_type == diff_src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && diff_dst_md()->data_type == diff_dst_type + && attr()->has_default_values(); + return ok ? 
status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type, + src_md()->data_type, + diff_dst_md()->data_type, + diff_weights_md()->data_type) + && IMPLICATION(with_bias(), + data_type == diff_weights_md(1)->data_type) + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp new file mode 100644 index 000000000000..325e97963bf6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp @@ -0,0 +1,252 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static inline float fast_negative_powf(float omega, float beta) { + float Y; +/* + * Y = omega^(-3/4) = + * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega)) + * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega) / omega) + * = sqrtf(1.0f / (sqrtf(omega) * omega)) + */ + if (beta == 0.75f) { + Y = sqrtf(1.0f / (sqrtf(omega) * omega)); + } else { + Y = 1.0f / powf(omega, beta); + } + return Y; +}; + +template +template +void ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize + + h * W * blksize + w * blksize + c % blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const float alpha = static_cast(pd()->desc()->lrn_alpha); + const float beta = static_cast(pd()->desc()->lrn_beta); + const float k = static_cast(pd()->desc()->lrn_k); + + const int size = pd()->desc()->local_size; + const int half_size = (size - 1) / 2; + + float sum = 0; + if (across_channels) { + const int c_st = nstl::max(oc - half_size + 0, 0); + const int c_en = nstl::min(oc + half_size + 1, C); + + for (int c = c_st; c < c_en; ++c) { + const float s = src[data_off(mb, c, oh, ow)]; + sum += s * s; + } + } else { + int h_st = nstl::max(oh - half_size + 0, 0); + int h_en = nstl::min(oh + half_size + 1, H); + int w_st = nstl::max(ow - half_size + 0, 0); + int w_en = nstl::min(ow + half_size + 1, W); + for (int h = h_st; h < h_en; ++h) { + for (int w = w_st; w < w_en; ++w) { + const float s = src[data_off(mb, oc, h, w)]; + sum += s * s; + } + } + } + const int summands = across_channels ? 
size : size * size; + sum = k + alpha * sum / summands; + size_t off = data_off(mb, oc, oh, ow); + d[0] = static_cast(src[off] * fast_negative_powf(sum, beta)); + }; + + const int MB = pd()->MB(); + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&dst[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&dst[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&dst[off], mb, c, h, w); + }); + } +} + +template +template +void ref_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + const float alpha = static_cast(pd()->desc()->lrn_alpha); + const float beta = static_cast(pd()->desc()->lrn_beta); + const float k = static_cast(pd()->desc()->lrn_k); + const int kernel_size = pd()->desc()->local_size; + const int half_ksize = (kernel_size - 1) / 2; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize + + h * W * blksize + w * blksize + c%blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const int c_st = nstl::max(oc - half_ksize + 0, 0); + const int c_en = nstl::min(oc + half_ksize + 1, C); + + float A = 0, B = 0, omega_mid = 0; + for (int c = c_st; c < c_en; c++) { + float sum = 0.0; + const int i_st = nstl::max(c - half_ksize, 0); + const int i_en = nstl::min(c + kernel_size - half_ksize, C); + + for (int i = i_st; i < i_en; ++i) { + const float value = src[data_off(mb, i, oh, ow)]; + sum += value * value; + } + const float omega = static_cast(k + sum * alpha / kernel_size); + if (c == oc) omega_mid = omega; + float t = src[data_off(mb, c, oh, ow)] + * fast_negative_powf(omega, beta); + B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)]; + } + + const size_t off = data_off(mb, oc, oh, ow); + A = fast_negative_powf(omega_mid, beta) * diff_dst[off]; + B *= src[off]; + B *= (2.0f * alpha * beta) / kernel_size; + *d = static_cast(A - B); // final cast down to data_t + }; + + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&diff_src[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + 
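+        // nhwc stores channels innermost, so the offset is simply
+        // mb * stride_mb + (h * W + w) * C + c and neighbouring channels are
+        // adjacent in memory; the blocked branches above do the same per
+        // 8/16-channel block, and any other layout falls back to
+        // data_d.off() via the default case of data_off().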
parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&diff_src[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&diff_src[off], mb, c, h, w); + }); + } +} + +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp new file mode 100644 index 000000000000..f25cfb7faefa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_REF_LRN_HPP +#define CPU_REF_LRN_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t); + + status_t init() { + using namespace format_tag; + + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_forward(ctx); break; + case nChw8c: execute_forward(ctx); break; + case nchw: execute_forward(ctx); break; + case nhwc: execute_forward(ctx); break; + default: execute_forward(ctx); + } + return status::success; + } + +private: + template + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t); + + status_t init() { + using namespace format_tag; + using namespace alg_kind; + + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, lrn_across_channels + /*, lrn_within_channel */) // not supported yet + && utils::everyone_is(data_type, + src_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_backward(ctx); break; + case nChw8c: execute_backward(ctx); break; + case nchw: execute_backward(ctx); break; + case nhwc: execute_backward(ctx); break; + default: execute_backward(ctx); + } + return status::success; + } + +private: + template + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp new file mode 100644 index 000000000000..65b934e12335 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp @@ -0,0 +1,381 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "ref_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t offset = is_3d + ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[offset] = value; + } else + reinterpret_cast(ws)[offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, 1, oh, ow, kh*KW + kw); + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? 
KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + acc_data_t dst = 0; + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, ih, iw)]; + } + } + + d[0] = math::out_round((float)dst / num_summands); + }; + + auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, id, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + acc_data_t dst = 0; + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, id, ih, iw)]; + } + } + } + + d[0] = math::out_round((float)dst / num_summands); + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (alg == pooling_max) { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = nstl::numeric_limits::lowest(); + set_ws(mb, oc, od, oh, ow, 0); + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + }); + } else { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = 0; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + }); + } +} + +template +void ref_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? 
index - offset : 0; + }; + + auto ker_zero = [=](int _mb, int _oc) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_type_t(0); + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) { + const size_t ws_off = ws_d.off(mb, oc, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = index / KW; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands; + } + } + }; + + auto ker_zero_3d = [=](int _mb, int _oc) { + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] = + data_type_t(0); + } + } + } + }; + + auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + const size_t ws_off = ws_d.off(mb, oc, od, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0]; + }; + + auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? 
KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands; + } + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (pd()->desc()->alg_kind == alg_kind::pooling_max) { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + } + } + } + }); + } +} + +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; + +template struct ref_pooling_bwd_t; +template struct ref_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp new file mode 100644 index 000000000000..e43ceaa82b24 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_REF_POOLING_HPP +#define CPU_REF_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && desc()->accum_data_type == acc_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::everyone_is(data_type, diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp new file mode 100644 index 000000000000..af2774311064 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp @@ -0,0 +1,153 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_shuffle.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace format_tag; + +template +template +void ref_shuffle_t::execute_(const exec_ctx_t &ctx) const { + using namespace prop_kind; + using namespace utils; + + const memory_desc_wrapper data_d(pd()->data_md()); + + auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST; + auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC; + auto input = CTX_IN_MEM(const data_t *, i_arg); + auto output = CTX_OUT_MEM(data_t *, o_arg); + + const int axis = pd()->axis(); + const int axis_size = pd()->axis_size(); + + const int MB = pd()->MB(); + const int C = pd()->C(); + int H = 1, W = 1, D = 1, HW = 1, SP = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4 ,5); + if (has_spatial) + { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + HW = H * W; + SP = D * HW; + } + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8; + + if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw16c)) { +#if MKLDNN_THR == MKLDNN_THR_OMP +# pragma omp parallel for collapse(3) schedule(static) + for (int mb = 0; mb < MB; ++mb) + for (int cb = 0; cb < C; cb += blksize) + for (int sp = 0; sp < SP; ++sp) { + const size_t off = mb * stride_mb + sp * blksize; + const size_t output_off = off + cb * SP; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + } +#else + parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c, + int sp) { + const size_t off = mb * stride_mb + sp * blksize; + const int cb = c * blksize; + const size_t output_off = off + cb * SP; + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + }); +#endif + } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) { + parallel_nd(MB, SP, [&](int mb, int sp) { + const size_t off = mb * stride_mb + sp * C; + PRAGMA_OMP_SIMD() + for (int c = 0; c < C; ++c) + output[off + c] = input[off + rev_transposed_[c]]; + }); + } else if (axis == 1 && one_of(tag, nchw, ncdhw)) { + parallel_nd(MB, C, [&](int mb, int c) { + const size_t output_off = mb * stride_mb + c * SP; + const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP; + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + output[output_off + sp] = input[input_off + sp]; + } + }); + } else { + auto dims = pd()->desc()->data_desc.dims; + auto ndims = pd()->desc()->data_desc.ndims; + const size_t outer_size = utils::array_product(dims, axis); + const size_t inner_size = utils::array_product(dims + axis + 1, + ndims - axis - 1); + const size_t dim = axis_size * inner_size; + + parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a, + size_t in) + { + const size_t off = ou * dim + in; + auto &o = output[data_d.off_l(off + a * inner_size)]; + o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)]; + }); + } +} + +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; 
+template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; + +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp new file mode 100644 index 000000000000..5e09a1a69bc8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SHUFFLE_HPP +#define CPU_REF_SHUFFLE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_shuffle_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_shuffle_t : public cpu_primitive_t { + using shuffle_class = ref_shuffle_t; + + struct pd_t: public cpu_shuffle_pd_t { + using cpu_shuffle_pd_t::cpu_shuffle_pd_t; + + DECLARE_COMMON_PD_T("ref:any", shuffle_class); + + status_t init() { + using namespace format_tag; + + bool ok = true + && data_type_size + == types::data_type_size(data_md()->data_type); + if (!ok) return status::unimplemented; + + if (ndims() == 5) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc); + } else if (ndims() == 4) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nChw16c, nChw8c, nchw, nhwc); + } else + dat_tag_ = any; + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) { + const int axis_size = pd()->axis_size(); + const int group_size = pd()->group_size(); + const int transpose_row = pd()->is_fwd() ? 
group_size + : axis_size / group_size; + const int transpose_col = pd()->is_fwd() ? axis_size / group_size + : group_size; + rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64); + parallel_nd(transpose_col, transpose_row, [&](int i, int j) { + rev_transposed_[j * transpose_col + i] = i * transpose_row + j; + }); + } + + ~ref_shuffle_t() { free(rev_transposed_); } + + typedef typename typesize_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nCdhw16c: execute_(ctx); break; + case nChw16c: execute_(ctx); break; + case nCdhw8c: execute_(ctx); break; + case nChw8c: execute_(ctx); break; + case ncdhw: execute_(ctx); break; + case nchw: execute_(ctx); break; + case ndhwc: execute_(ctx); break; + case nhwc: execute_(ctx); break; + default: execute_(ctx); break; + } + return status::success; + } + +private: + template + void execute_(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int *rev_transposed_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp new file mode 100644 index 000000000000..36d5237f564e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_softmax.hpp" +#include "gemm/os_blas.hpp" + +#ifdef USE_MKL +#include "mkl_vml_functions.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_softmax_fwd_t::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + parallel_nd(outer_size_, [&](int ou) { + const data_t *src_data = src + ou * channels_; + data_t *dst_data = dst + ou * channels_; + data_t scalar = 0; + + _max(channels_, src_data, &scalar); + _sub(channels_, scalar, src_data, dst_data); + _exp(channels_, dst_data, dst_data); + _sum(channels_, dst_data, &scalar); + _scal(channels_, data_t(1)/scalar, dst_data); + }); +} + +template +void ref_softmax_fwd_t::execute_forward_generic( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + data_t space_max_val = 0, space_denom_val = 0; + data_t *space_max = &space_max_val, *space_denom = &space_denom_val; + if (inner_size_ > 1) { + using namespace memory_tracking::names; + space_max = scratchpad(ctx).template get(key_softmax_reduction); + space_denom = space_max + inner_size_; + } + + const memory_desc_wrapper data_d(pd()->src_md()); + const size_t dim = channels_ * inner_size_; + + for (int ou = 0; ou < outer_size_; ou++) { + utils::array_set(space_max, -FLT_MAX, inner_size_); + utils::array_set(space_denom, 0, inner_size_); + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_max[in] = nstl::max(space_max[in], src[off]); + } + } + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_denom[in] += dst[off] = exp(src[off] - space_max[in]); + } + } + + for (int c = 0; c < channels_; c++) { + for (int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + dst[off] /= space_denom[in]; + } + } + } +} + +template +void ref_softmax_fwd_t::_max(int n, const data_t *x, + data_t *max_data) const { +// Intel(R) C++ Compiler generates the maxps + shuffle pattern +// for the max search which works faster +#if !defined(__INTEL_COMPILER) + // The code below makes a compiler to generate maxps instruction + // rather than maxss, which is generated for the 'else' code path + auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); }; + auto min_wrapper = [](int a, int b) { return nstl::min(a, b); }; + + constexpr int unroll_factor = 32; + data_t max_values[unroll_factor]; + + if (n < unroll_factor) { + data_t max_val = x[0]; + for (int i = 1; i < n; i++) { + max_val = max_wrapper(max_val, x[i]); + } + max_data[0] = max_val; + return; + } + for (int i = 0; i < unroll_factor; i++) { + max_values[i] = x[i]; + } + for (int i = unroll_factor; i < n; i += unroll_factor) { + int offset = min_wrapper(i, n - unroll_factor); + for (int j = 0; j < unroll_factor; j++) { + max_values[j] = max_wrapper(max_values[j], x[offset + j]); + } + } + data_t max_val = max_values[0]; + for (int i = 1; i < unroll_factor; i++) { + max_val = max_wrapper(max_val, max_values[i]); + } + max_data[0] = max_val; +#else + 
max_data[0] = x[0]; + for (int c = 1; c < n; ++c) + max_data[0] = nstl::max(max_data[0], x[c]); +#endif +} + +template +void ref_softmax_fwd_t::_sub(int n, data_t alpha, const data_t *x, + data_t *y) const { + constexpr int unroll_factor = 32; + int tail = n % unroll_factor; + for (int i = 0; i < n - tail; i += unroll_factor) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < unroll_factor; j++) { + y[i + j] = x[i + j] - alpha; + } + } + PRAGMA_OMP_SIMD() + for (int i = n - tail; i < n; i++) { + y[i] = x[i] - alpha; + } +} + +template +void ref_softmax_fwd_t::_exp(int n, const data_t *a, + data_t *r) const { +#ifdef USE_MKL + if (data_type == data_type::f32) { + vsExp(n, a, r); + return; + } +#endif + parallel_nd(n, [&](int c) { r[c] = expf(a[c]); }); +} + +template +void ref_softmax_fwd_t::_sum(int n, const data_t *x, + data_t *sum_data) const { +#ifdef USE_CBLAS + // Here we are summing x's eg. e^z , which are positives + // so we can use BLAS ASUM + if (data_type == data_type::f32) { + sum_data[0] = cblas_sasum(n, x, 1); + return; + } +#endif + data_t tsum = static_cast(0); + PRAGMA_OMP_SIMD(reduction(+ : tsum)) + for (int c = 0; c < n; ++c) + tsum += x[c]; + sum_data[0] = tsum; +} + +template +void ref_softmax_fwd_t::_scal(int n, data_t alpha, data_t *x) const { +#ifdef USE_CBLAS + if (data_type == data_type::f32) { + cblas_sscal(n, alpha, x, 1); + return; + } +#endif + parallel_nd(n, [&](int c) { x[c] *= alpha; }); +} + +template struct ref_softmax_fwd_t; + + +// NC/NCHW softmax for along final axe (1 for NC, 3 for NCHW) +template +void ref_softmax_bwd_t::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + parallel_nd(outer_size_, [&](int ou) { + data_t sbr = 0; + size_t off = channels_*ou; + for (int c = 0; c < channels_; c++) { + size_t loff = off + c; + data_t ldata = dst[loff]; + sbr += diff_dst[loff]*ldata; + diff_src[loff] = ldata; + } + + for(int c=0; c < channels_ ; ++c) { + size_t loff = off + c; + diff_src[loff] *= (diff_dst[loff] - sbr); + } + }); +} + +template +void ref_softmax_bwd_t::execute_backward_generic( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_d(pd()->diff_src_md()); + const memory_desc_wrapper data_d(pd()->dst_md()); + + const size_t dim = channels_ * inner_size_; + + parallel_nd(outer_size_, [&](int ou) { + for (int in = 0; in < inner_size_; in++) { + data_t sbr = 0; + for (int c = 0; c < channels_; c++) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = diff_d.off_l(ou * dim + c * inner_size_ + in); + sbr += diff_dst[off_diff] * dst[off_data]; + } + + for(int c=0; c < channels_ ; ++c) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in); + diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr); + } + } + }); +} + +template struct ref_softmax_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp new file mode 100644 index 000000000000..5cb74d8007a9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp @@ 
-0,0 +1,186 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SOFTMAX_HPP +#define CPU_REF_SOFTMAX_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_softmax_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_softmax_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_fwd_pd_t { + using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + const int inner_size = utils::array_product( + desc()->data_desc.dims + desc()->softmax_axis + 1, + desc()->data_desc.ndims - desc()->softmax_axis - 1); + + if (inner_size > 1) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_softmax_reduction, + sizeof(data_t) * 2 * inner_size); + } + } + }; + + ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { + auto ndims = pd()->desc()->data_desc.ndims; + auto dims = pd()->desc()->data_desc.dims; + auto axis = pd()->desc()->softmax_axis; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->src_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk) + if (data_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = inner_size_ == 1 && data_d.is_dense() + && no_axis_blocking + && data_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_forward_dense(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + + void _max(int n, const data_t *x, data_t *max_data) const; + void _sub(int n, data_t alpha, const data_t *x, data_t *y) const; + void _exp(int n, const data_t *a, data_t *r) const; + void _sum(int n, const data_t *x, data_t *sum_data) const; + void _scal(int n, data_t alpha, data_t *x) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + +template +struct ref_softmax_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_bwd_pd_t { + using 
cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t); + + status_t init() { + bool ok = true + && !is_fwd() + && utils::everyone_is(data_type, + dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + }; + + ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) { + auto dims = pd()->desc()->diff_desc.dims; + auto axis = pd()->desc()->softmax_axis; + auto ndims = pd()->desc()->diff_desc.ndims; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->dst_md()); + const memory_desc_wrapper diff_d(pd()->diff_dst_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk) + if (diff_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = true + && inner_size_ == 1 + && diff_d == data_d + && diff_d.is_dense() + && no_axis_blocking + && diff_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp new file mode 100644 index 000000000000..3b2a75d99b87 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#ifndef REF_SUM_HPP
+#define REF_SUM_HPP
+
+#include "reorder_pd.hpp"
+
+#include "cpu_sum_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct ref_sum_t: public cpu_primitive_t {
+    struct pd_t: public cpu_sum_pd_t {
+        using cpu_sum_pd_t::cpu_sum_pd_t;
+
+        pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) {
+            for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i)
+                reorder_pds_.push_back(
+                        (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
+        }
+
+        ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; }
+
+        DECLARE_SUM_PD_T("ref:any", ref_sum_t);
+
+        status_t init() {
+            bool ok = cpu_sum_pd_t::init() == status::success;
+            if (!ok) return status::unimplemented;
+
+            for (int i = 0; i < n_; ++i) {
+                auto r_impls = engine_->get_reorder_implementation_list();
+                for (auto r = r_impls; *r; ++r) {
+                    primitive_attr_t attr;
+                    attr.output_scales_.set(scales_[i]);
+                    if (i != 0) attr.post_ops_.append_sum(1.0);
+
+                    reorder_pd_t *r_pd;
+                    if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i),
+                                engine_, dst_md()) == status::success) {
+                        r_pd->init_info();
+                        reorder_pds_.push_back(r_pd);
+                        break;
+                    }
+                }
+            }
+
+            ok = reorder_pds_.size() == (size_t)n_;
+            return ok ? status::success : status::unimplemented;
+        }
+
+        nstl::vector<const reorder_pd_t *> reorder_pds_;
+    };
+
+    ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) {
+        const int n = pd()->n_inputs();
+        reorders_.resize(n);
+        for (int i = 0; i < n; ++i)
+            pd()->reorder_pds_[i]->create_primitive(&reorders_[i]);
+    }
+
+    ~ref_sum_t() { for (auto &r: reorders_) delete r; }
+
+    virtual status_t execute(const exec_ctx_t &ctx) const override {
+        const auto n = pd()->n_inputs();
+        for (int i = 0; i < n; ++i) {
+            exec_args_t r_args;
+            r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i);
+            r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST);
+            exec_ctx_t r_ctx(ctx.stream(), std::move(r_args));
+            reorders_[i]->execute(r_ctx);
+        }
+        return status::success;
+    }
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+    nstl::vector<primitive_t *> reorders_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
new file mode 100644
index 000000000000..537084db9120
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+*******************************************************************************/ + +/* + * Common for RNN and LSTM cell execution + */ +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +using namespace rnn_utils; + +template +rnn_cell_execution_sig( + (_ref_rnn_common_t::cell_execution)) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + if (rnn_postgemm_ != nullptr) + rnn_postgemm_->execute(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + else + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} +template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) { + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + /// bwd by data on the cell + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, + 0.0, diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + + /// bwd by weights on the cell + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + + if (!rnn.merge_gemm_iter) + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, + diff_w_iter_, rnn.diff_weights_iter_ld); + + /// bwd by bias we just accumulate diffs from the gates + gates_reduction(rnn, ws_gates_, diff_bias_); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp new file mode 100644 index 000000000000..e1a61d4c6233 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +#define AOC array_offset_calculator +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_[0]); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + // 1. gemm Wx[0-2],x + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + + // 2. gemm Wh[0-1],h + (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb, + rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + // 3. activation zt and rt + elemwise multiplication rt,ht-1 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); + } + }); + + // 4. gemm Wh[2],h~t + (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1], + rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0, + &(ws_gates(0, 2, 0)), rnn.gates_ws_ld); + + // 5. activation h~t + calculate ht + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) { + assert(!"GRU int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + // use state memory for intermediate computations + // TODO: use cell ws for that + float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0)); + float *hG1_ = dhG1_; + AOC dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld); + AOC hG1(hG1_, rnn.states_nld, rnn.states_ws_ld); + + // 1. 
calculate dG2, dG1, and part of dht-1 + // dG2^ = dh * (1 - G0) * (1 - G2^2) + // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) + // dht-1 (part) = dh * G0 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt + * one_m_square(ws_gates(i, 2, j)); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 0, j) = dG0; + ws_gates(i, 2, j) = dG2; + } + }); + + // 2. calculate intermediate d(hG1) + // d(hG1) = dG2 * W2h^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1], + rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0, + dhG1_, rnn.states_ws_ld); + + // 3. calculate dG1^ and part of dht-1 + // dG1^ = d(hG1) * h * G1 * (1 - G1) + // dht-1 (part) += d(hG1) * G1 + // h * G1 (required for dWh) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float G1 = ws_gates(i, 1, j); + diff_states_t_l(0, i, j) += dhG1(i, j) * G1; + ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1); + hG1(i, j) = G1 * h; + } + }); + + // 4. calculate diff weights + // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h) + gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_iter_ld); + gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)), + rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0, + &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld); + + // 5. calculate diff states + // dht-1 += dG1 * W1h + dG0 * W0h + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, + (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0], + rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0, + diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + // dWx += [dG0 dG1 dG2] * [x] + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld); + } + + // 6. calculate diff bias + gates_reduction(rnn, ws_gates_, diff_bias_); +} +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp new file mode 100644 index 000000000000..8dea8c90a402 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU with linear before reset + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_); + AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); + ws_gates(i, 0, j) = logistic_fwd( + ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd( + ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd( + ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + if (rnn.is_training) + ws_Wh_b(i, j) = Wh_b; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + // 1. 
calculate dG1 dG2 dG3 + // dG0 = (dht - G2) * dht * (1 - G0) * G0 + // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 + // dG2 = (1 - G0) * dht * (1 - G2*G2) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + float dG2 = (1.0f - ws_gates(i, 0, j)) + * one_m_square(ws_gates(i, 2, j)) * dHt; + float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 2, j) = dG2; + ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j); + ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; + ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + if (!rnn.merge_gemm_layer) { + // dx = dG * Wx^t + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + // dWx += dG^t * x + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + // dh += dGr * Wh^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld, + 1.0, diff_states_t_l_, rnn.states_ws_ld); + + // dWh += dGr^t * h + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_layer_ld); + + // db1-3 += e * dG + // db4 += e * (r * dG2) + gates_reduction(rnn, ws_gates_, diff_bias_); + + parallel_nd(rnn.dic, [&](int j) { + for (int i = 0; i < rnn.mb; i++) { + diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j); + } + }); +} + +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp new file mode 100644 index 000000000000..a15ba00d4c4f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "../simple_q10n.hpp" +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j)); + + float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j) + + ws_gates(i, 0, j) * ws_gates(i, 2, j); + states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) { + ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_u8_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + auto q_d = [&](float f) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + }; + + auto deq_w = [&](acc_data_t s, int gate, int j) { + return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ? 
+ saturate(s) * (1.f / (weights_scales[0] * data_scale)) : + saturate(s) * (1.f / (weights_scales[gate * rnn.dic + j] + * data_scale)); + }; + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float G0 = logistic_fwd( + deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j)); + float G1 = logistic_fwd( + deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j)); + float G2 = tanh_fwd( + deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j)); + float G3 = logistic_fwd( + deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j)); + float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2; + states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp)); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Ct = c_states_t_l(i, j); + /// @todo save it in the workspace in fwd pass or recompute it to + /// save bw + float tanhCt = tanh_fwd(Ct); + // we have 2 incoming diffs on Ht + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dCt = diff_states_tp1_l(1, i, j) + + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt; + + float dG1 = c_states_tm1_l(i, j) * dCt + * x_m_square(ws_gates(i, 1, j)); + float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j)); + float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j)); + float dG2 + = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j)); + + diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j); + + ws_gates(i, 0, j) = dG0; + ws_gates(i, 1, j) = dG1; + ws_gates(i, 2, j) = dG2; + ws_gates(i, 3, j) = dG3; + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp new file mode 100644 index 000000000000..4536e8dfad7c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/* + * Cell execution of Vanilla RNN + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return relu_fwd(s, alpha); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return relu_bwd(dd, s, alpha); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return tanh_fwd(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return dd * one_m_square(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return logistic_fwd(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return dd * x_m_square(s); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; j++) { + const float h + = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0); + ws_gates(i, 0, j) = states_t_l(i, j) = h; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) { + assert(!"VANILLA RNN int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; ++j) { + const float dH = diff_states_t_lp1(rnn.n_states, i, j) + + diff_states_tp1_l(0, i, j); + auto g = ws_gates(i, 0, j); + ws_gates(i, 0, j) = activation_func(dH, g, 0, 0); + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp new file mode 100644 index 000000000000..b39427caf94d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp @@ -0,0 +1,191 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_RNN_PD_HPP +#define CPU_RNN_PD_HPP + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "rnn_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "rnn_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { + using rnn_fwd_pd_t::rnn_fwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace data_type; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + /* Int8 is supported only for packed weights */ + data_type_t weights_iter_dt = weights_iter_md_.data_type; + data_type_t weights_layer_dt = weights_layer_md_.data_type; + ok = ok && IMPLICATION( + weights_iter_dt == s8, weights_iter_md_.format_kind + == format_kind::rnn_packed); + ok = ok && IMPLICATION( + weights_layer_dt == s8, weights_layer_md_.format_kind + == format_kind::rnn_packed); + + return ok ? 
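+        // every blocking/packing check above must hold; a single failure
+        // rejects this primitive descriptor as unimplemented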
status::success : status::unimplemented; + } +}; + +struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { + using rnn_bwd_pd_t::rnn_bwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + if (diff_src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); + if (diff_weights_layer_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); + } + if (diff_weights_iter_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); + } + if (diff_dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc)); + if (with_bias() && diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); + if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + ok = ok && is_blocked(diff_src_layer_md_, 3) + && is_blocked(diff_dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_), + is_blocked(diff_src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), + is_blocked(diff_dst_iter_md_, 5)); + + ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) + && rnn_utils::is_ldigo(&diff_weights_iter_md_); + ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_), + memory_desc_matches_tag(diff_bias_md_, ldgo)); + + return ok ? 
status::success : status::unimplemented; + } +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp new file mode 100644 index 000000000000..09445648aa97 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp @@ -0,0 +1,401 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "rnn_utils.hpp" +#include "../jit_generator.hpp" +#include "../jit_uni_eltwise.hpp" +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "mkldnn_thread.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_uni_rnn_postgemm_kernel : public jit_generator { + + typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_, + void *c_states_t_l_, void *c_states_tm1_l_); + + jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){} + + virtual void init() = 0; + +template + rnn_elemwise_sig(execute) { + rnn_utils::ws_gates_aoc ws_gates(rnn, ws_gates_); + rnn_utils::bias_aoc_t bias(rnn, bias_); + rnn_utils::ws_states_aoc states_t_l(rnn, states_t_l_); + rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + // Todo: add parallelization on dic for the batch 1 case + // Assumption: the kernel runs a loop on dic elements + parallel_nd(rnn.mb, [&](int i) { + auto b_ = &bias(0, 0); + auto g_ = &ws_gates(i, 0, 0); + auto s_tl_ = &states_t_l(i, 0); + auto c_tl_ = &c_states_t_l(i, 0); + auto c_tm1l_ = &c_states_tm1_l(i, 0); + kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_); + }); + } + +protected: + kernel_t kernel_; + const rnn_utils::rnn_conf_t &rnn_; + const primitive_attr_t *attr_; +}; + +template +struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd) + + typedef typename utils::conditional::type acc_data_t; + typedef typename utils::conditional, + jit_uni_eltwise_injector_f32>::type injector_t; + + jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr) + : jit_uni_rnn_postgemm_kernel(rnn, attr){} + + void init() override { + // we use rax for both constant tables as they use the same table + sigmoid_injector_ = new injector_t(this, + alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax); + tanh_injector_ = new injector_t(this, + alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax); + generate(); + kernel_ = (kernel_t) this->getCode(); + } + +protected: + injector_t *sigmoid_injector_; + injector_t *tanh_injector_; + + // register size in bytes + using Vmm = typename jit_uni_eltwise_injector_f32::Vmm; + size_t vlen = cpu_isa_traits::vlen; + size_t vlen_dst = (src_data_t 
== data_type::u8) ? vlen/4 : vlen; + size_t cstate_dt_size = sizeof(float); + size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float); + size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float); + size_t qscale_dt_size = sizeof(float); + size_t bias_dt_size = sizeof(float); + + void generate() { + using namespace Xbyak; + + int mask = attr_->rnn_weights_qparams_.mask_; + float *weights_scales = attr_->rnn_weights_qparams_.scales_; + float data_scale = attr_->rnn_data_qparams_.scale_; + float data_shift = attr_->rnn_data_qparams_.shift_; + + // Labels declaration + Label vector_loop_start_label, vector_loop_end_label; + Label rem_loop_start_label, rem_loop_end_label; + Label table_label; + + // Register map + Reg64 loop_cnt(r11); // loop counter + Reg64 table_reg(rbx); // table is used for data scale and shifts + Reg64 weights_scales_reg(r13); + // We skip vmm0 as it can be used by the injector for masks on sse4.2 + Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7); + + // constant table map + Address dscale_off_addr = ptr[table_reg]; + Address dshift_off_addr = ptr[table_reg + vlen]; + Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen]; + Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits::vlen]; + + // quantize from float to u8 + auto q_d = [&](Vmm f, Vmm tmp_vmm) { + uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm); + uni_vmulps(f, f, dscale_off_addr); // apply scale + uni_vaddps(f, f, dshift_off_addr); // apply shift + uni_vcvtps2dq(f, f); // convert to int32 + uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16 + uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation + // Note that the results are interleaved by 128 bit chunks, so we need to merge them together + switch (vlen) { + case 64: { //avx512 + Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx()); + uni_vmovups(tmpz, zmm_perm_mask_addr); + vpermd(fz, tmpz, fz); + break; } + case 32: { //avx + Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx()); + uni_vmovups(tmpy, ymm_perm_mask_addr); + vpermd(fy, tmpy, fy); + break; } + case 16: // sse: nothing to do + break; + default: assert(!"Unsupported case"); + }; + }; + + auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) { + if (packed) + uni_vrcpps(tmp, s); + else + uni_vrcpss(tmp, s); // prevent divide by zero + // we add one Newton iteration + uni_vmulps(s, s, tmp); + uni_vmulps(s, s, tmp); // s <- s * tmp^2 + uni_vaddps(tmp, tmp, tmp); + uni_vsubps(tmp, tmp, s); + uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 + }; + + // dequantize from s32 to float + auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) { + // TODO: if mask is 0 precompute mul and inverse + if (mask == 0) + uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); + else + uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]); + uni_vcvtdq2ps(s, s); + uni_vmulps(tmp1, tmp1, dscale_off_addr); + fast_recip(tmp1, tmp2, packed); + uni_vmulps(s, s, tmp1); + }; + + // We start code generations here + preamble(); + + // extract addresses passed as parameter +#ifdef _WIN32 + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = r10; + // Here we cannot use rbp to have initial stack pointer so we + // use rsp and offset it with the size of pushed registers in + // preamble + mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]); +#else + auto 
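+    // the System V x86-64 convention passes all five pointer arguments in
+    // registers, so unlike the Windows branch above no stack load is needed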
addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = abi_param5; +#endif + + // initialize registers with addresses and constants + mov(table_reg, table_label); + mov(weights_scales_reg, size_t(weights_scales)); + // both sigmoid and tanh use the same table so load address just once in rax + sigmoid_injector_->load_table_addr(); + + mov(loop_cnt, rnn_.dic * gate_dt_size); + cmp(loop_cnt, vlen); + jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + + L(vector_loop_start_label); + { + // load G0 G1 G2 G3 + uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true); + } + + // add biases + uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0.getIdx()); + sigmoid_injector_->compute_vector(G1.getIdx()); + tanh_injector_->compute_vector(G2.getIdx()); + sigmoid_injector_->compute_vector(G3.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0 * G2 + uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1_vmm, tmp1_vmm, G1); + uni_vfmadd231ps(tmp1_vmm, G0, G2); + uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1_vmm.getIdx()); + uni_vmulps(tmp1_vmm, tmp1_vmm, G3); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, vlen); + add(addr_bias_reg, vlen); + add(addr_states_t_l_reg, vlen_dst); + add(addr_c_states_tm1_l_reg, vlen); + add(addr_c_states_t_l_reg, vlen); + if (mask != 0) + add(weights_scales_reg, vlen); + + // increment loop counter + sub(loop_cnt, vlen); + cmp(loop_cnt, vlen); + jge(vector_loop_start_label); + } + L(vector_loop_end_label); + + cmp(loop_cnt, 0); + je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + // Same code as above, we just use movuss for accessing inputs + // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar + L(rem_loop_start_label); + { + // remaping registers to Xmms + Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx()); + Xmm tmp1s_vmm(tmp1_vmm.getIdx()); + + // load G0 G1 G2 G3 + uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic 
* gate_dt_size]); + uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false); + } + + // add biases + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0s, G0s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1s, G1s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2s, G2s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3s, G3s, tmp1s_vmm); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0s.getIdx()); + sigmoid_injector_->compute_vector(G1s.getIdx()); + tanh_injector_->compute_vector(G2s.getIdx()); + sigmoid_injector_->compute_vector(G3s.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0s * G2 + uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s); + uni_vfmadd231ps(tmp1s_vmm, G0s, G2s); + uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1s_vmm.getIdx()); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, gate_dt_size); + add(addr_bias_reg, bias_dt_size); + add(addr_states_t_l_reg, hstate_dt_size); + add(addr_c_states_tm1_l_reg, cstate_dt_size); + add(addr_c_states_t_l_reg, cstate_dt_size); + if (mask != 0) + add(weights_scales_reg, qscale_dt_size); + + // increment loop counter + sub(loop_cnt, gate_dt_size); + cmp(loop_cnt, 0); + jg(rem_loop_start_label); + + } + L(rem_loop_end_label); + + postamble(); + + // Again, only one table is needed and shared between sigmoid and tanh + sigmoid_injector_->prepare_table(false); + tanh_injector_->prepare_table(true); + + L(table_label); + { + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale)); + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift)); + // perm mask for ymm + dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7); + // perm mask for zmm + dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7); + dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14); + } + } + +}; + +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; + +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +} +} +} diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp new file mode 100644 index 000000000000..ead536816c1c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp @@ -0,0 +1,788 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + General architecture + + for diff states, we have n_states + 1 as we have n_states diff + to propagate to the previous iteration and 1 states to propagate + to the previous layer + index 0 is dh for cell(t-1, l) to consume + index 1 is dc for cell(t-1, l) to consume + index 2 is dh for cell(t, l-1) to consume + this indexing enables to have the same indexing for states in elemwise + function + only the cell execution function should be impacted + + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "../gemm/gemm.hpp" +#include "../simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::memory_tracking::names; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template +void _ref_rnn_common_t::gates_reduction( + const rnn_conf_t &rnn, const acc_data_t *ws_gates_, + float *diff_bias_) const { + auto body = [&](int i, int k) { + for (int j = 0; j < rnn.mb; j++) + diff_bias_[i * rnn.dic + k] + += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k]; + }; + + // @todo block k on simd-width +#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \ + /* icc 17.0 has a problem with simd collapse */ \ + && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) +#pragma omp parallel for simd collapse(2) + for (int i = 0; i < rnn.n_gates; i++) + for (int k = 0; k < rnn.dic; k++) + body(i, k); +#else + parallel_nd(rnn.n_gates, rnn.dic, body); +#endif +} + +template +rnn_gemm_sig((_ref_rnn_common_t::gemm)) { + assert(ldA * ldB * ldC != 0); + extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB, + &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm); +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) { + assert(!"non packed gemm is disabled for int8"); +} + +template +rnn_gemm_sig((_ref_rnn_common_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + assert(transA == 'N'); + cblas_sgemm_compute(CblasColMajor, CblasPacked, + (transB == 'T') ? 
CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_, + ldB, beta, c_, ldC); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + int8_t offseta = 0, offsetb = 0; + int32_t offsetc = 0; + cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked, + CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_, + ldB, offsetb, beta, c_, ldC, &offsetc); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +//*************** Grid computations strategy: linear ***************// +template +rnn_grid_execution_sig( + (_ref_rnn_common_t::linear_execution)) { + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, + rnn.states_nld * rnn.states_ws_ld); + AOC ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter, + rnn.gates_nld * rnn.gates_ws_ld); + AOC weights_input( + weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); + AOC weights_states( + weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); + AOC bias( + bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC diff_weights_layer(diff_weights_layer_, rnn.n_layer, + rnn.n_dir, + rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); + AOC diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir, + rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); + AOC diff_bias( + diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + AOC ws_grid( + ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); + + // We run the grid of computation + for (int dir = 0; dir < rnn.n_dir; dir++) { + for (int j = 0; j < rnn.n_layer; j++) { + int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; + + if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, + rnn.mb * rnn.n_iter, rnn.slc, 1.0, + weights_input(lay, dir, 0), rnn.weights_iter_ld, + &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0, + &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld); + } + + for (int i = 0; i < rnn.n_iter; i++) { + int iter = (aprop == prop_kind::forward) ? 
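+                // forward walks the time steps left to right; backward visits
+                // them in reverse, mirroring the reversed layer order above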
i : rnn.n_iter - i - 1; + (this->*cell_func)(rnn, + &(ws_states(lay + 1, dir, iter + 1, 0)), + &(ws_c_states(lay + 1, dir, iter + 1, 0)), + &(ws_diff_states(lay, dir, 0, iter, 0)), + &(weights_input(lay, dir, 0)), + &(weights_states(lay, dir, 0)), + &(bias(lay, dir, 0)), + &(ws_states(lay, dir, iter + 1, 0)), + &(ws_states(lay + 1, dir, iter, 0)), + &(ws_c_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay + 1, dir, 0, iter, 0)), + &(ws_diff_states(lay, dir, 0, iter + 1, 0)), + &(diff_weights_layer(lay, dir, 0)), + &(diff_weights_iter(lay, dir, 0)), + &(diff_bias(lay, dir, 0)), + &(ws_gates(lay, dir, iter, 0)), + &(ws_grid(lay, dir, iter, 0)), + ws_cell_); + } + + if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter, + rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0), + rnn.weights_layer_ld, + (src_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, 0.0, + (acc_data_t *)(&(ws_diff_states( + lay, dir, rnn.n_states, 0, 0))), + rnn.states_ws_ld); + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay, dir, 1, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))), + rnn.diff_weights_layer_ld); + } + if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))), + rnn.diff_weights_iter_ld); + } + } + } +} + +//********* GRID computations strategy: utility functions **********// + +template +void _ref_rnn_common_t::copy_init_layer( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_diff_states_, const src_data_t *__restrict xt_, + const float *__restrict diff_dst_layer_) const { + + AOC ws_states( + ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto xt_d = memory_desc_wrapper(pd()->src_md(0)); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto xxt = xt_ + xt_d.blk_off(it, b); + src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0)); + src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); + if (rnn.exec_dir != r2l) + for (int c = 0; c < rnn.slc; c++) + ws_l2r_ptr[c] = xxt[c]; + if (rnn.exec_dir != l2r) + for (int c = 0; c < rnn.slc; c++) + ws_r2l_ptr[c] = xxt[c]; + }); +} + +template <> +void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_, + const float *diff_dst_layer_) const { + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0)); + + switch (rnn.exec_dir) { + case bi_concat: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[rnn.dic + s]; + } + }); + break; + case bi_sum: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + 
diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case l2r: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case r2l: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x = diff_dst_layer_ + + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + default: assert(!"Unsupported direction"); break; + } +} + +/* For int8 configuration, input iteration states may be of types f32 or u8 + * Internally h_state is always stored in u8 and c_state is always stored in f32 + * If input states are of type u8 then h state is copied and c state is dequantized + * If input states are of type f32 then h state is quantized and c_state is copied + * */ +template +template +void _ref_rnn_common_t::copy_init_iter( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_c_states_, float *__restrict ws_diff_states_, + const input_data_t *__restrict firstit_states_, + const float *__restrict diff_dst_iter_) const { + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](input_data_t f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + } else + return (src_data_t)f; + }; + + const bool dequantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::u8; + auto maybe_deq = [&](input_data_t s) { + if (dequantize) + return (((float)s - data_shift) / data_scale); + else + return (float)s; + }; + auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1)); + if (firstit_states_) { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.sic; s++) + ws_states(lay + 1, dir, 0, b, s) = maybe_q( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 0, b, s)]); + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.sic; s++) + ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 1, b, s)]); + }); + } else { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int j = 0; j < rnn.sic; j++) { + ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0; + ws_c_states(lay + 1, dir, 0, b, j) = 0.0f; + } + }); + } +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter_) const { + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto 
diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); + if (diff_dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + array_copy(&(ws_diff_states( + lay, dir, state, rnn.n_iter, b, 0)), + diff_dst_iter_ + + diff_dst_iter_d.blk_off( + lay, dir, state, b), + rnn.dic); + }); + } else { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int i) { + for (int j = 0; j < rnn.dic; j++) + ws_diff_states(lay, dir, state, rnn.n_iter, i, j) + = 0.0f; + }); + } +} + +template +template +void _ref_rnn_common_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + + auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float shift = (pd()->attr()->rnn_data_qparams_.shift_); + float scale = (pd()->attr()->rnn_data_qparams_.scale_); + + const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (dst_data_t)(((float)s - shift) / scale); + else + return (dst_data_t)s; + }; + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + if (rnn.exec_dir != r2l) { + for (int s = 0; s < rnn.dic; s++) { + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s)); + } + dir = 1; + } + if (rnn.exec_dir != l2r) { + for (int s = 0; s < rnn.dic; s++) + switch (rnn.exec_dir) { + case bi_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] + += maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + break; + default: + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + } + } + }); +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + for (int s = 0; s < rnn.slc; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (rnn.exec_dir == r2l) ? 
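+                        // right-to-left execution stores results with a
+                        // mirrored time index so diff_src_layer comes out in
+                        // the user's original time order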
rnn.n_iter - 1 - it : it, + b, dir * rnn.slc + s); + float res = ws_diff_states(0, 0, rnn.n_states, it, b, s); + if (rnn.n_dir - 1) + res += ws_diff_states( + 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s); + dst_addr[0] = res; + } + }); +} + +template +template +void _ref_rnn_common_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::u8 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](float f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + } else + return (output_data_t)f; + }; + + const bool dequantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (output_data_t)(((float)s - data_shift) / data_scale); + else + return (output_data_t)s; + }; + if (dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, + [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)] + = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s)); + } + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)] + = maybe_q(ws_c_states( + lay + 1, dir, rnn.n_iter, b, s)); + } + }); + } +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + if (diff_src_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < rnn.sic; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, state, 0, b, s); + } + }); + } +} + +template +rnn_bias_prepare_sig((_ref_rnn_common_t::bias_prepare)) { + /* Original set of bias provided by the user */ + AOC b( + b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + /* Array of pointers initialized in packing */ + AOC bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC scratch_bias( + scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + + if (rnn.copy_bias) { + parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic, + [&](size_t i) { scratch_bias_[i] = b_[i]; }); + } + + for (int i = 0; i < rnn.n_layer; i++) { + for (int d = 0; d < rnn.n_dir; d++) { + int offset_bias = 0; + for (int p = 0; p < rnn.n_parts_bias; p++) { + bias(i, d, p) = rnn.copy_bias + ? 
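+                        // when copy_bias is set, point at the scratchpad copy
+                        // so bias_finalize can adjust it (e.g. int8
+                        // compensation) without touching the user's buffer;
+                        // otherwise alias the user-provided bias directly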
(float *) &scratch_bias(i, d, offset_bias) + : (float *) &b(i, d, offset_bias); + offset_bias += rnn.parts_bias[p] * rnn.dic; + } + } + } + +} + +template +rnn_bias_finalize_sig( + (_ref_rnn_common_t::bias_finalize)) { + if (rnn.dt_conf != all_f32) { + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; + for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) + for (int j = 0; j < rnn.n_bias * rnn.dic; j++) { + size_t off = i * rnn.n_bias * rnn.dic + j; + float weights_scale + = scale_per_oc ? weights_scales[j] : weights_scales[0]; + scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) + * data_shift / (weights_scale * data_scale); + } + } +} + +template +rnn_weights_assign_sig((_ref_rnn_common_t::assign_packed_weights)) { + assert(md->format_kind == format_kind::rnn_packed); + const auto packed_desc = md->format_desc.rnn_packed_desc; + AOC weights(weights_, + rnn.n_layer, rnn.n_dir, packed_desc.n_parts); + + size_t offset_packed = 0; + for (int l = 0; l < rnn.n_layer; l++) + for (int d = 0; d < rnn.n_dir; d++) { + for (int p = 0; p < packed_desc.n_parts; p++) { + weights(l, d, p) = (weights_data_t *)&w_[offset_packed]; + offset_packed + += packed_desc.part_pack_size[p] / sizeof(weights_data_t); + } + } +} + +template +rnn_weights_assign_sig( + (_ref_rnn_common_t::assign_weights)) { + assert(md->format_kind == format_kind::blocked); + const auto &blk = md->format_desc.blocking; + /* Original set of weights provided by the user */ + AOC w(w_, + rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); + /* Array of pointers for each part of weights */ + AOC weights(weights_, rnn.n_layer, rnn.n_dir, n_parts); + + for (int i = 0; i < rnn.n_layer; i++) + for (int d = 0; d < rnn.n_dir; d++) { + size_t offset_weights = 0; + for (int p = 0; p < n_parts; p++) { + weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights); + offset_weights += gates_per_part[p] * blk.strides[3]; + } + } +} + +//********************* Execution function *********************// +template +void _ref_rnn_common_t::execute_( + const exec_ctx_t &ctx) const { + const rnn_conf_t &rnn = this->pd()->rnn_; + auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER); + auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER); + auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER); + auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + + auto dst_last_layer = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER)); + auto dst_last_iter = rnn.is_fwd + ? 
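+            // forward writes DST_ITER as an output; backward only reads it,
+            // so the const qualifier is cast away to keep one pointer type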
CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER)); + + auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER); + auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER); + + auto w_layer = reinterpret_cast(layer_weights_n_comp); + auto w_iter = reinterpret_cast(iter_weights_n_comp); + auto w_iter_comp = reinterpret_cast( + iter_weights_n_comp + rnn.weights_iter_comp_offset); + auto w_layer_comp = reinterpret_cast( + layer_weights_n_comp + rnn.weights_layer_comp_offset); + + auto scratchpad = this->scratchpad(ctx); + + auto ptr_wei_layer + = scratchpad.template get(key_rnn_ptrs_wei_layer); + auto ptr_wei_iter + = scratchpad.template get(key_rnn_ptrs_wei_iter); + auto ptr_bias = + scratchpad.template get(key_rnn_ptrs_bia); + + // fetchihg buffers from the workspace + // if no workspace was provided we use the scratchpad + char *scratch_ptr = scratchpad.template get(key_rnn_space); + char *ws_ptr = nullptr; + if (rnn.use_workspace) + ws_ptr = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE)); + + char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; + acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_); + src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_); + float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_); + float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_); + float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_); + float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_); + + auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER); + auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER); + + auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER); + auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + // Fetching extra buffers from scratchpad + float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_); + + // initialize diff_states to 0 + if (aprop == prop_kind::backward) + array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float)); + + /* Pack(if using packed gemm API) or copy(if input arrays have bad leading + * dimension */ + (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); + + (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1), + rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, + rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter, + rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, + ptr_bias, bias, ws_bias); + (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0), + rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc, + rnn.n_parts_weights_layer, rnn.parts_weights_layer, + rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias, + bias, ws_bias); + + (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); + + // we first need to copy the initial states and input into ws + copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer); + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const float *)states, diff_dst_iter); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const uint8_t *)states, diff_dst_iter); + else + 
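+        // no other data-type configuration has a reference path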
assert(!"unimplemented"); + + // run the execution on the grid + (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias, + ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid, + diff_weights_layer, diff_weights_iter, diff_bias); + + // Finally we copy the results to the result buffers + if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states, + ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8) + copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer, + ws_states, ws_diff_states); + else + assert(!"unimplemented"); + + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else + assert(!"unimplemented"); +}; + +/* Fix for MSVS warning C4661 */ +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise); + +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; + +#undef AOC +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp new file mode 100644 index 000000000000..6f449a9016f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp @@ -0,0 +1,328 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_REF_RNN_HPP +#define CPU_REF_RNN_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "../cpu_isa_traits.hpp" +#include "../gemm/os_blas.hpp" + +#include "cpu_rnn_pd.hpp" +#include "../cpu_primitive.hpp" +#include "rnn_utils.hpp" +#include "jit_uni_rnn_postgemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +float activation(float s, float alpha, float cliping, float dd); + +template +struct _ref_rnn_common_t : public cpu_primitive_t { + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type weights_data_t; + typedef typename utils::conditional::type acc_data_t; + + using class_name = _ref_rnn_common_t; + + typedef rnn_elemwise_sig((class_name::*elemwise_f)); + typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); + typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); + + typedef rnn_gemm_sig((class_name::*gemm_t)); + typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); + typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); + typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); + + using base_pd_t = + typename utils::conditional::type; + + struct pd_t : public base_pd_t { + using base_pd_t::base_pd_t; + + DECLARE_COMMON_PD_T("ref:any", class_name); + + status_t init() { + using namespace prop_kind; + using namespace utils; + using namespace format_tag; + using namespace rnn_utils; + const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind; + + data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; + data_type_t weights_iter_dt + = this->desc()->weights_iter_desc.data_type; + data_type_t weights_layer_dt + = this->desc()->weights_layer_desc.data_type; + + bool ok = true + && one_of(cell_kind, alg_kind::vanilla_rnn, + alg_kind::vanilla_lstm, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset) + && IMPLICATION(aprop == prop_kind::forward, + one_of(this->desc()->prop_kind, forward_training, + forward_inference)) + && IMPLICATION(aprop == backward, + one_of(this->desc()->prop_kind, backward)) + && src_layer_dt == src_type + && everyone_is( + weights_type, weights_iter_dt, weights_layer_dt) + && this->set_default_params() == status::success + && this->with_bias(); + if (!ok) + return status::unimplemented; + + init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1), + this->weights_md(0), this->weights_md(1), this->dst_md(0)); + + if (rnn_.dt_conf == all_f32) + ok = ok && this->attr()->has_default_values(); + + // Set weights descriptors to desired format + memory_desc_t new_weights_layer_md = *this->weights_md(0); + CHECK(set_expected_desc(rnn_, new_weights_layer_md, false)); + if (this->weights_layer_md_.format_kind == format_kind::any) { + this->weights_layer_md_ = new_weights_layer_md; + } else if (this->weights_layer_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_layer_md_ != new_weights_layer_md) + return status::unimplemented; + } + + memory_desc_t new_weights_iter_md = *this->weights_md(1); + CHECK(set_expected_desc(rnn_, new_weights_iter_md, true)); + if (this->weights_iter_md_.format_kind == format_kind::any) { + this->weights_iter_md_ = new_weights_iter_md; + } else if (this->weights_iter_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_iter_md_ != new_weights_iter_md) + return status::unimplemented; + } + + 
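+            // with the weights formats settled, verify that every memory
+            // descriptor uses a layout this reference implementation accepts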
CHECK(this->check_layout_consistency()); + + set_conf(rnn_, *this->desc(), this->weights_md(0), + this->weights_md(1), this->diff_weights_md(0), + this->diff_weights_md(1)); + + size_t scratchpad_sz{0}, ws_sz{0}; + get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); + + // initialize the workspace if needed + if (rnn_.is_training) { + dims_t ws_dims = { (int)ws_sz }; + mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims, + data_type::u8, format_tag::x); + } + + init_scratchpad(scratchpad_sz); + + return status::success; + } + + rnn_utils::rnn_conf_t rnn_; + + private: + void init_scratchpad(size_t scratchpad_sz) { + using namespace memory_tracking::names; + auto scratchpad = this->scratchpad_registry().registrar(); + scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096); + + int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1; + int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; + scratchpad.book(key_rnn_ptrs_wei_layer, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_wei_iter, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_bia, + sizeof(float *) * ptr_wei_sz); + } + }; + + _ref_rnn_common_t(const pd_t *apd) + : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) { + /// @todo set max_feature_size assuming that we limit the number of + /// iterations and layer to one if slc != dic and sic != dic + /// respectively + + bias_preparation_func = &class_name::bias_prepare; + bias_finalization_func = &class_name::bias_finalize; + + auto set_gemm_funcs + = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) { + if (packed_gemm) { + g = &class_name::packed_gemm; + a = &class_name::assign_packed_weights; + } else { + g = &class_name::gemm; + a = &class_name::assign_weights; + } + }; + set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, + weights_iter_assign_func); + + set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, + weights_layer_assign_func); + + switch (pd()->cell_kind()) { + case alg_kind::vanilla_lstm: + cell_func = &class_name::cell_execution; + if (aprop == prop_kind::forward) { + if (mayiuse(avx512_core)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + else if (mayiuse(avx2)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + else if (mayiuse(sse42)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + assert(rnn_postgemm_ != nullptr); + rnn_postgemm_->init(); + } + elemwise_func = &class_name::lstm_elemwise; + break; + case alg_kind::vanilla_rnn: // @todo switch on cell kind + cell_func = &class_name::cell_execution; + elemwise_func = &class_name::rnn_elemwise; + switch (pd()->activation_kind()) { + case alg_kind::eltwise_relu: + activation_func = &activation; + break; + case alg_kind::eltwise_tanh: + activation_func = &activation; + break; + case alg_kind::eltwise_logistic: + activation_func = &activation; + break; + default: break; + } + break; + case alg_kind::vanilla_gru: + cell_func = &class_name::cell_execution_gru; + break; + case alg_kind::gru_linear_before_reset: + cell_func = &class_name::cell_execution_gru_lbr; + elemwise_func = &class_name::gru_lbr_elemwise; + break; + default: break; + } + + grid_computation = &class_name::linear_execution; + + size_t scratchpad_size, workspace_size; + rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_, + ws_c_states_offset_, ws_diff_states_offset_, + ws_grid_comp_offset_, ws_cell_comp_offset_, + 
ws_bias_offset_, scratchpad_size, workspace_size); + } + + ~_ref_rnn_common_t() {} + + // typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_(ctx); + return status::success; + } + +private: + void execute_(const exec_ctx_t &ctx) const; + rnn_grid_execution_sig(linear_execution); + rnn_cell_execution_sig(cell_execution); + rnn_cell_execution_sig(cell_execution_gru); + rnn_cell_execution_sig(cell_execution_gru_lbr); + rnn_elemwise_sig(rnn_elemwise); + rnn_elemwise_sig(lstm_elemwise); + rnn_elemwise_sig(gru_lbr_elemwise); + rnn_gemm_sig(gemm); + rnn_gemm_sig(packed_gemm); + rnn_bias_prepare_sig(bias_prepare); + rnn_bias_finalize_sig(bias_finalize); + rnn_weights_assign_sig(assign_weights); + rnn_weights_assign_sig(assign_packed_weights); + + float (*activation_func)(float dd, float s, float alpha, float cliping); + + void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, + const src_data_t *xt_, const float *diff_dst_layer) const; + + template + void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter) const; + + template + void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, + dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const; + + template + void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, + output_data_t *dst_iter_, float *diff_src_iter, + const src_data_t *ws_states_, float *ws_c_states, + const float *ws_diff_states_) const; + + void gates_reduction(const rnn_utils::rnn_conf_t &rnn, + const acc_data_t *ws_gates_, float *diff_bias_) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + size_t ws_gates_offset_; + size_t ws_states_offset_; + size_t ws_c_states_offset_; + size_t ws_bias_offset_; + size_t ws_diff_states_offset_; + size_t ws_grid_comp_offset_; + size_t ws_cell_comp_offset_; + jit_uni_rnn_postgemm_kernel *rnn_postgemm_; + + grid_execution_f grid_computation; + cell_execution_f cell_func; + + bias_prepare_t bias_preparation_func; + bias_finalize_t bias_finalization_func; + weights_assign_t weights_layer_assign_func; + weights_assign_t weights_iter_assign_func; + + gemm_t gemm_layer_func; + gemm_t gemm_iter_func; + elemwise_f elemwise_func; +}; + +using ref_rnn_fwd_f32_t = _ref_rnn_common_t; +using ref_rnn_bwd_f32_t = _ref_rnn_common_t; +using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t; +} +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp new file mode 100644 index 000000000000..78cdedbae456 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp @@ -0,0 +1,380 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
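A minimal standalone sketch (illustrative names, not part of the patch) of the dispatch pattern _ref_rnn_common_t uses above: the cell, elementwise and GEMM/weights-assign routines are member-function pointers resolved once in the constructor from the cell kind, activation and available ISA, so the grid/cell execution loops never branch on them.

#include <cassert>
#include <cmath>

struct rnn_sketch {
    enum class cell_kind { vanilla_rnn, vanilla_lstm };

    typedef void (rnn_sketch::*elemwise_f)(float *gates, int n) const;

    explicit rnn_sketch(cell_kind k) {
        // Resolved once here, the same way ref_rnn picks cell_func,
        // elemwise_func and the gemm functions in its constructor.
        switch (k) {
        case cell_kind::vanilla_rnn: elemwise_ = &rnn_sketch::rnn_elemwise; break;
        case cell_kind::vanilla_lstm: elemwise_ = &rnn_sketch::lstm_elemwise; break;
        }
        assert(elemwise_ != nullptr);
    }

    // The per-iteration loop only ever calls through the pointer.
    void run_cell(float *gates, int n) const { (this->*elemwise_)(gates, n); }

private:
    void rnn_elemwise(float *gates, int n) const {
        for (int i = 0; i < n; ++i)                    // e.g. ReLU activation
            gates[i] = gates[i] > 0.f ? gates[i] : 0.f;
    }
    void lstm_elemwise(float *gates, int n) const {
        for (int i = 0; i < n; ++i)                    // sigmoid on the gates
            gates[i] = 1.f / (1.f + std::exp(-gates[i]));
    }
    elemwise_f elemwise_ = nullptr;
};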
+ * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_RNN_REORDERS_HPP +#define CPU_RNN_REORDERS_HPP + +#include + +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "simple_q10n.hpp" +#include "cpu_reorder_pd.hpp" +#include "../gemm/os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct rnn_data_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc) + && od == id; + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign(*reorder_pd, _pd); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + + rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const size_t nelems = input_d.nelems(); + const float scale = pd()->attr()->rnn_data_qparams_.scale_; + const float shift = pd()->attr()->rnn_data_qparams_.shift_; + + parallel_nd(nelems, [&](size_t i) { + float in = (float)input[input_d.off_l(i)] * scale + shift; + output[output_d.off_l(i)] = qz_a1b0()(in); + }); + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct rnn_weights_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && od.format_kind() == format_kind::rnn_packed + && od.rnn_packed_desc().format == mkldnn_ldigo_p + && od.rnn_packed_desc().n_parts == 1 + && attr != nullptr; + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == 
nullptr) return out_of_memory; + _pd->itag_ = itag; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + format_tag_t itag_ = mkldnn_format_tag_undef; + + private: + void init_scratchpad() { + const memory_desc_wrapper id(src_md()); + const size_t nelems = id.nelems(); + const auto &dims = id.dims(); + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + size_t quantization_size = sizeof(int8_t) * nelems; + size_t reduction_size = itag_ == ldigo + ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0] + * dims[1] * dims[3] * dims[4] + : 0; + scratchpad.book( + key_reorder_rnn_weights_quantization, quantization_size); + scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + + /* Quantize input & compute compensation */ + auto quantized = (int8_t * __restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_rnn_weights_quantization); + auto reduction = (int32_t * __restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_rnn_weights_reduction); + float *comp = reinterpret_cast( + output + output_d.rnn_packed_desc().offset_compensation); + const float *scales = pd()->attr()->rnn_weights_qparams_.scales_; + const int mask = pd()->attr()->rnn_weights_qparams_.mask_; + + if (is_igo) { + int nthr = mkldnn_get_max_threads(); + int LD_nthr = nstl::min(L * D, nthr); + int I_nthr = nstl::min(I, nthr / LD_nthr); + parallel(nthr, [&](const int ithr, const int nthr) { + int LD_ithr = -1, LD_s = -1, LD_e = -1; + int I_ithr = -1, I_s = -1, I_e = -1; + if (ithr < LD_nthr * I_nthr) { + LD_ithr = ithr % LD_nthr; + I_ithr = ithr / LD_nthr; + balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e); + balance211(I, I_nthr, I_ithr, I_s, I_e); + } + int32_t *comp_ithr = reduction + I_ithr * L * D * G * O; + for (int ld = LD_s; ld < LD_e; ld++) { + for (int go = 0; go < G * O; go++) + comp_ithr[ld * G * O + go] = 0; + for (int i = I_s; i < I_e; i++) { + PRAGMA_OMP_SIMD() + for (int go = 0; go < G * O; go++) { + const float s = scales[(mask == 0) ? 
0 : go]; + int8_t q = qz_b0()( + input[ld * I * G * O + i * G * O + go], s); + quantized[ld * I * G * O + i * G * O + go] + = (int32_t)q; + comp_ithr[ld * G * O + go] += (int32_t)q; + } + } + } + }); + parallel_nd(L * D * G * O, + [&](int s) { comp[s] = saturate(reduction[s]); }); + for (int i = 1; i < I_nthr; i++) { + parallel_nd(L * D * G * O, [&](int s) { + comp[s] += saturate( + reduction[i * L * D * G * O + s]); + }); + } + } else { + parallel_nd(L * D, G * O, [&](int ld, int go) { + int32_t compensation = 0; + const float s = scales[(mask == 0) ? 0 : go]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < I; i++) { + int8_t q = qz_b0()( + input[ld * G * O * I + go * I + i], s); + compensation += (int32_t)q; + quantized[ld * G * O * I + go * I + i] = q; + } + comp[ld * G * O + go] = saturate(compensation); + }); + } + + /* Pack */ + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + int n_parts = output_d.rnn_packed_desc().n_parts; + const size_t *size_packed_cell + = output_d.rnn_packed_desc().part_pack_size; + const int *parts = output_d.rnn_packed_desc().parts; + const int n = output_d.rnn_packed_desc().n; + char *to_pack = output; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = parts[p] * O; + int k_p = I; + cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix, + is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p, + &quantized[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, g, 0, 0)], + is_igo ? G * O : I, to_pack); + to_pack += size_packed_cell[p]; + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <> +struct rnn_weights_reorder_t + : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == data_type::f32 + && od.data_type() == data_type::f32 + && od.format_kind() == format_kind::rnn_packed + && utils::one_of(od.rnn_packed_desc().format, + mkldnn_ldigo_p, mkldnn_ldgoi_p) + && attr->has_default_values(); + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->itag_ = itag; + return safe_ptr_assign(*reorder_pd, _pd); + } + + format_tag_t itag_; + }; + +private: + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const 
float *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc(); + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + /* Pack */ + bool cross_case = false + || (pd()->itag_ == format_tag::ldigo + && rnn_pdata.format == mkldnn_ldgoi_p) + || (pd()->itag_ == format_tag::ldgoi + && rnn_pdata.format == mkldnn_ldigo_p); + auto trans = cross_case ? CblasTrans : CblasNoTrans; + int n_parts = rnn_pdata.n_parts; + const size_t *size_packed_cell = rnn_pdata.part_pack_size; + const int *parts = rnn_pdata.parts; + const int n = rnn_pdata.n; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = is_igo ? parts[p] * O : I; + int k_p = is_igo ? I : parts[p] * O; + int ld = is_igo ? G * O : I; + cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n, + k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, 0, g, 0)], + ld, output); + output += size_packed_cell[p] / sizeof(float); + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp new file mode 100644 index 000000000000..1d60415cbcdb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
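A minimal f32 -> u8 / s8 sketch (illustrative names) of what the two reorders above compute; the real primitives use the qz_* functors from simple_q10n.hpp and write the packed GEMM layout, but the arithmetic is the same.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Data reorder: x_u8 = saturate(round(x_f32 * scale + shift)), as in rnn_data_reorder_t.
static uint8_t quantize_data(float x, float scale, float shift) {
    float q = std::nearbyint(x * scale + shift);
    return (uint8_t)std::min(255.f, std::max(0.f, q));
}

// Weights reorder: quantize one reduction row and return its compensation term.
// With shifted u8 activations, sum_k (scale*x_k + shift) * w8_k
//   = scale * sum_k x_k * w8_k + shift * sum_k w8_k,
// so sum_k w8_k is precomputed per output channel at reorder time and the
// shift contribution is subtracted from the s32 GEMM result at execution time.
static int32_t quantize_weights(const std::vector<float> &w, float scale,
        std::vector<int8_t> &w8) {
    int32_t comp = 0;
    w8.resize(w.size());
    for (size_t k = 0; k < w.size(); ++k) {
        float q = std::nearbyint(w[k] * scale);
        q = std::min(127.f, std::max(-128.f, q));
        w8[k] = (int8_t)q;
        comp += (int32_t)w8[k];
    }
    return comp;
}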
+*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "rnn_utils.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace rnn_utils; +using namespace format_tag; +using namespace rnn_packed_format; +using namespace data_type; + +bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1 + && str[3] == dims[4] && str[1] == str[2] * dims[2] + && str[0] == str[1] * dims[1]; +}; + +bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1 + && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3] + && str[0] == str[1] * dims[1]; +}; + +void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d) { + rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + rnn.is_training = utils::one_of( + rd.prop_kind, prop_kind::forward_training, prop_kind::backward); + rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset; + + switch (rd.direction) { + case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break; + case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break; + case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break; + case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break; + default: break; + } + + if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(), + weights_layer_d.data_type())) + rnn.dt_conf = all_f32; + else if (dst_layer_d.data_type() == u8) { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8u8; + else + rnn.dt_conf = f32u8f32u8; + } else { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8f32; + else + rnn.dt_conf = f32u8f32f32; + } + + rnn.n_layer = weights_layer_d.dims()[0]; + rnn.n_iter = src_layer_d.dims()[0]; + rnn.n_dir = weights_layer_d.dims()[1]; + rnn.n_gates = weights_layer_d.dims()[3]; + rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc); + rnn.n_bias = rnn.n_gates + rnn.is_lbr; + rnn.mb = src_layer_d.dims()[1]; + rnn.sic = weights_iter_d.dims()[2]; + rnn.slc = weights_layer_d.dims()[2]; + rnn.dic = weights_layer_d.dims()[4]; + rnn.dlc = dst_layer_d.dims()[2]; + + rnn.gates_ld = rnn.dic * rnn.n_gates; + rnn.gates_nld = rnn.mb; + rnn.states_nld = rnn.mb; + + /* Set the correct number of weights parts */ + bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru; + rnn.n_parts_weights_layer = 1; + rnn.parts_weights_layer[0] = rnn.n_gates; + rnn.parts_weights_layer[1] = 0; + + rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1; + rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates; + rnn.parts_weights_iter[1] = is_orig_gru ? 
1 : 0; + + rnn.n_parts_bias = 1; + rnn.parts_bias[0] = rnn.n_bias; + rnn.parts_bias[1] = 0; + + /* Decide wich gemm implementation to use: packed/nonpacked jit/cblas + * and if to mergre gemm across iterations */ + bool is_int8 = rnn.dt_conf != all_f32; + rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) + || is_int8; + bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset); + rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8; + bool is_inference = !rnn.is_training; + + rnn.use_jit_gemm = !mayiuse(avx512_mic) + && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) + || (rnn.is_training && rnn.dic < 500)); + + /* Decide to copy bias */ + rnn.copy_bias = rnn.dt_conf != all_f32; + +#if USE_MKL_PACKED_GEMM + rnn.use_layer_packed_gemm + = (weights_layer_d.format_kind() == format_kind::any + && rnn.slc > 760 && rnn.dic > 760 && is_inference) + || is_int8; // packed gemm is the only supported option for int8 + rnn.use_iter_packed_gemm + = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760 + && rnn.dic > 760 && is_inference) + || is_int8; +#else + rnn.use_layer_packed_gemm = false; + rnn.use_iter_packed_gemm = false; +#endif + + /* Set packed gemm sizes */ + if (rnn.use_layer_packed_gemm) { + rnn.weights_layer_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_layer; p++) { + int m_p = rnn.is_fwd + ? (rnn.parts_weights_layer[p] * rnn.dic) + : rnn.slc; + int k_p = rnn.is_fwd + ? rnn.slc + : (rnn.parts_weights_layer[p] * rnn.dic); + int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_layer_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_layer_pack_size[p] = 0; +#endif + rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_layer_pack_size[p]; + } + rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size; + rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float); + } + + if (rnn.use_iter_packed_gemm) { + rnn.weights_iter_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_iter; p++) { + int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) : + rnn.sic; + int k_p = rnn.is_fwd ? rnn.sic : + (rnn.parts_weights_iter[p] * rnn.dic); + int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_iter_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_iter_pack_size[p] = 0; +#endif + rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_iter_pack_size[p]; + } + rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size; + rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 
0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float); + } + +} + +void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d) { + + /* Set leading dimensions for input weights arrays depending on input format + */ + rnn.weights_layer_is_packed + = weights_layer_d.format_kind() == format_kind::rnn_packed; + rnn.weights_iter_is_packed + = weights_iter_d.format_kind() == format_kind::rnn_packed; + + auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) { + ld = 0; nld = 0; + if (md.is_blocking_desc()) { + if (is_ldigo(md)) { + ld = (int)md.blocking_desc().strides[2]; + nld = md.dims()[2]; + } else if (is_ldgoi(md)) { + ld = (int)md.blocking_desc().strides[4]; + nld = md.dims()[3] * md.dims()[4]; + } else + assert(!"unsupported weights format"); + } + }; + set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); + set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); + if (!rnn.is_fwd) { + set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, + rnn.diff_weights_layer_nld); + set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, + rnn.diff_weights_iter_nld); + } + + int sizeof_states_dt + = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t); + rnn.states_ws_ld + = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)), + sizeof_states_dt); + rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float)); + + /* Set workspace sizes to store: + * states to copmute a pass + * diff states to copmute bwd pass (training only) + * intermediate results from the gates + */ + rnn.use_workspace = rnn.is_training; + rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir + * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt; + bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm; + rnn.ws_c_states_size = is_lstm + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb + * rnn.states_ws_ld * sizeof(float) + : 0; + rnn.ws_diff_states_size = rnn.is_training + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) + * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld + * sizeof(float) + : (size_t)0; + rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb + * rnn.gates_ws_ld * sizeof(float); + + /* set other sizes */ + rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float); + rnn.ws_cell_comp_size + = rnn.is_lbr || rnn.dt_conf != all_f32 + ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float) + : 0; + rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer + * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); + rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic + * sizeof(float); +} + +int rnn_utils::get_good_ld(int dim, int sizeof_dt) { + // we want matrices leading dimentions to be 64-byte aligned, + // and not divisible by 256 to avoid 4K aliasing effects + int ld = rnd_up(dim, 64 / sizeof_dt); + return (ld % 256 == 0) ? 
ld + 64 / sizeof_dt : ld; +} + +void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_states_offset, size_t &ws_c_states_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size) { + + const size_t page_size = 4096; // 2097152; + size_t current_offset; + /* Mandatory workspaces: go to workspace if use_workspace, scratchpad + * otherwise */ + current_offset = 0; // assumes the workspace base pointer is page aligned + ws_gates_offset = current_offset; + current_offset += rnn.ws_gates_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_states_offset = current_offset; + current_offset += rnn.ws_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_c_states_offset = current_offset; + current_offset += rnn.ws_c_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_diff_states_offset = current_offset; + current_offset += rnn.ws_diff_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_grid_comp_offset = current_offset; + current_offset += rnn.ws_grid_comp_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_cell_comp_offset = current_offset; + current_offset += rnn.ws_cell_comp_size; + + workspace_size = rnn.use_workspace ? current_offset : 0; + + /* Optional scratchpads */ + // Assumes the scratchpad base pointer is page aligned. + // If use_workspace, the following goes to scratchpad alone, + // otherwise, all goes to scratchpad and continue incrementing offset + current_offset = rnn.use_workspace ? 0 : current_offset; + + if (rnn.copy_bias) { + current_offset = utils::rnd_up(current_offset, page_size); + ws_bias_offset = current_offset; + current_offset += rnn.ws_bias_size; + } + + scratchpad_size = current_offset; +} + +void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size) { + size_t ws_gates_offset, ws_states_offset, ws_c_states_offset, + ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset; + set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset, + ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset, scratchpad_size, workspace_size); +} + +status_t rnn_utils::set_good_strides( + memory_desc_t &weights_md, format_tag_t tag) { + auto &strides = weights_md.format_desc.blocking.strides; + auto dims = weights_md.dims; + + if (tag == ldigo) { + strides[2] = rnn_utils::get_good_ld((int)strides[2], + (int)types::data_type_size(weights_md.data_type)); + strides[1] = dims[2] * strides[2]; + strides[0] = dims[1] * strides[1]; + } else if (tag == ldgoi) { + strides[4] = rnn_utils::get_good_ld((int)strides[4], + (int)types::data_type_size(weights_md.data_type)); + strides[3] = dims[4] * strides[4]; + strides[1] = dims[3] * strides[3]; + strides[0] = dims[1] * strides[1]; + } else + return status::unimplemented; + + return status::success; +} + +status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, + memory_desc_t &weights_md, bool is_iter) { + using namespace format_tag; + bool use_packed_gemm = is_iter + ? rnn.use_iter_packed_gemm + : rnn.use_layer_packed_gemm; + if (use_packed_gemm) { + weights_md.format_kind = format_kind::rnn_packed; + rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc; + rnn_pdata.format = rnn.is_fwd ? 
mkldnn_ldigo_p : mkldnn_ldgoi_p; + if (is_iter) { + rnn_pdata.n = rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_iter; + array_copy(rnn_pdata.parts, rnn.parts_weights_iter, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset; + rnn_pdata.size = rnn.weights_iter_pack_size; + } else { + rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_layer; + array_copy(rnn_pdata.parts, rnn.parts_weights_layer, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset; + rnn_pdata.size = rnn.weights_layer_pack_size; + } + } else { + CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + // Adjust strides for good leading dimension in GEMM + CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + } + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp new file mode 100644 index 000000000000..99eb787a64ed --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
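A small worked example of the get_good_ld() rule above (illustrative, assuming 4-byte f32 elements): the leading dimension is rounded up to a 64-byte multiple and then nudged off multiples of 256 elements.

#include <cassert>

static int good_ld(int dim, int sizeof_dt) {    // same rule as rnn_utils::get_good_ld
    int step = 64 / sizeof_dt;                  // elements per 64-byte cache line
    int ld = (dim + step - 1) / step * step;    // round up to a 64-byte multiple
    return (ld % 256 == 0) ? ld + step : ld;    // step off aliasing-prone strides
}

int main() {
    assert(good_ld(100, 4) == 112); // 100 -> 112: aligned, not a multiple of 256
    assert(good_ld(500, 4) == 528); // 500 -> 512, but 512 % 256 == 0, so bump to 528
    assert(good_ld(512, 4) == 528);
    return 0;
}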
+*******************************************************************************/ + +#ifndef RNN_UTILS_HPP +#define RNN_UTILS_HPP + +#include "mkldnn.h" + +#include "cpu_rnn_pd.hpp" + + +#define rnn_elemwise_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \ + src_data_t *states_t_l_, float *c_states_t_l_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_l_, float *diff_states_t_lp1_, \ + float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \ + float *ws_cell_) const + +#define rnn_cell_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \ + float *c_states_t_l_, float *diff_states_t_l_, \ + weights_data_t **w_layer_, weights_data_t **w_iter_, \ + float **bias_, src_data_t *states_t_lm1_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_lp1_, float *diff_states_tp1_l_, \ + float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \ + acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const + +#define rnn_grid_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \ + weights_data_t **weights_states_, float **bias_, \ + src_data_t *ws_states_, float *ws_c_states_, \ + float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \ + float *ws_grid_, float *diff_weights_layer_, \ + float *diff_weights_iter_, float *diff_bias_) const + +#define rnn_gemm_sig(f) \ + void f(const char transA, const char transB, int m, int n, int k, \ + const float alpha, const weights_data_t *a_, const int ldA, \ + const src_data_t *b_, const int ldB, const float beta, \ + acc_data_t *c_, const int ldC) const + +#define rnn_bias_prepare_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \ + float *scratch_bias_) const + +#define rnn_bias_finalize_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \ + const float *w_iter_comp, const float *w_layer_comp) const + +#define rnn_weights_assign_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \ + int ld, int OC_size, int IC_size, const int n_parts, \ + const int *gates_per_part, const size_t *part_weights_pack_size, \ + weights_data_t **weights_, const weights_data_t *w_, \ + float **bias_, const float *b_, float *scratch_bias_) const + + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace rnn_utils { + +using namespace mkldnn::impl::utils; + +enum execution_direction_t { + l2r, + r2l, + bi_concat, + bi_sum, +}; + +enum data_type_conf_t { + all_f32, + u8u8u8f32, + f32u8f32f32, + u8u8u8u8, + f32u8f32u8 +}; + +struct rnn_conf_t { + execution_direction_t exec_dir; + data_type_conf_t dt_conf; + int n_layer, n_iter, n_dir, n_gates, n_states; + int mb; + int slc, sic, dic, dlc; + int gates_ld, gates_nld, gates_ws_ld; + int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS]; + int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS]; + int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS], + part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + bool weights_layer_is_packed, weights_iter_is_packed; + /* Size of packed data in bytes */ + size_t weights_layer_comp_offset, weights_layer_pack_size, + weights_iter_comp_offset, weights_iter_pack_size; + + bool copy_bias; + int weights_layer_ld, weights_layer_nld; + int diff_weights_layer_ld, diff_weights_layer_nld; + int weights_iter_ld, 
weights_iter_nld; + int diff_weights_iter_ld, diff_weights_iter_nld; + int states_nld, states_ws_ld; + int weights_iter_compensation_size, weights_layer_compensation_size; + bool is_fwd, is_training, is_lbr; + bool use_workspace; + + /* Size of workspace for each tensor in bytes */ + size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size, + ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size; + bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm, + use_iter_packed_gemm; +}; + +bool is_ldigo(const memory_desc_wrapper &md); +bool is_ldgoi(const memory_desc_wrapper &md); + +int get_good_ld(int dim, int sizeof_dt); + +void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d); + +void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d); + +void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_h_state_offset, size_t &ws_c_state_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size); + +void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size); +status_t set_expected_desc( + rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter); +status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); + +template +struct ws_gates_aoc { + ws_gates_aoc(const rnn_conf_t &rnn, T *data) + : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {} + T &operator()(int batch, int gate, int dic) { + return gates_(batch, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator gates_; + int DIC_; +}; +using ws_gates_aoc_t = ws_gates_aoc; +using ws_gates_aoc_s32_t = ws_gates_aoc; + +struct bias_aoc_t { + bias_aoc_t(const rnn_conf_t &rnn, const float *data) + : bias_(data, rnn.n_bias, rnn.dic) {} + const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator bias_; +}; + +template +struct ws_states_aoc { + ws_states_aoc(const rnn_conf_t &rnn, T *data) + : state_(data, rnn.states_nld, rnn.states_ws_ld) {} + T &operator()(int batch, int dic) { return state_(batch, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator state_; +}; +using ws_states_aoc_t = ws_states_aoc; +using ws_states_aoc_u8_t = ws_states_aoc; + +struct ws_diff_states_aoc_t { + ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld, + rnn.states_ws_ld) {} + float &operator()(int state_n, int batch, int dic) { + return diff_states_(state_n, 0, batch, dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator diff_states_; +}; + +struct ws_diff_w_iter_aoc_t { + ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_weights_iter_( + data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) + , DIC_(rnn.dic) {} + float &operator()(int sic, int gate, int dic) { + return diff_weights_iter_(sic, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator 
diff_weights_iter_; + int DIC_; +}; +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp new file mode 100644 index 000000000000..0420f87aa52f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp @@ -0,0 +1,126 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_concat.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +template +status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { + auto scratchpad = this->scratchpad(ctx); + auto iptrs = scratchpad.template get(key_concat_iptrs); + auto optrs = scratchpad.template get(key_concat_optrs); + auto nelems_to_copy = scratchpad.template get(key_concat_nelems); + auto is = scratchpad.template get(key_concat_istrides); + + const int num_arrs = pd()->n_inputs(); + const int *perm = pd()->perm_, *iperm = pd()->iperm_; + const int concat_dim = pd()->concat_dim(); + auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + const memory_desc_wrapper o_d(pd()->src_image_md(a)); + + iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + optrs[a] = o_base_ptr + o_d.blk_off(0); + nelems_to_copy[a] = pd()->nelems_to_concat(i_d); + for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) { + if (i < perm[concat_dim]) + is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]); + else + is[a][i] = 0; + } + } + + const memory_desc_wrapper o_d(pd()->src_image_md(0)); + + strides_t os = { 0 }; + for (int i = 0; i < perm[concat_dim]; i++) + os[i] = o_d.blocking_desc().strides[iperm[i]]; + + dims_t phys_dims; + for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++) + phys_dims[i] = (i < (size_t)perm[concat_dim]) + ? 
o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1; + + if (perm[concat_dim] == 0) { + for (int a = 0; a < num_arrs; ++a) { + const data_t *i = &iptrs[a][0]; + data_t *o = &optrs[a][0]; + parallel_nd((ptrdiff_t)nelems_to_copy[a], + [&](ptrdiff_t e) { o[e] = i[e]; }); + } + } else { + parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], + phys_dims[4], num_arrs, + [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) { + // XXX: this code may access uninitialized values in is[*][0-4] -- + // that's why we have to set them to zero although this is + // probably benign + size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2 + + is[a][3] * n3 + is[a][4] * n4; + size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2 + + os[3] * n3 + os[4] * n4; + const data_t *i = &iptrs[a][in_off]; + data_t *o = &optrs[a][out_off]; +#if defined(__GNUC__) && !defined(__INTEL_COMPILER) + // The code below performs data copying: o[e] = i[e] + // and uses a workaround to make GNU compilers optimize it + uint8_t *ptro = reinterpret_cast(o); + const uint8_t *ptri = reinterpret_cast(i); + const dim_t main_part = + nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t); + const dim_t tail_part = + nelems_to_copy[a] % sizeof(data_t) / sizeof(uint32_t); + + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < main_part; ++e) { + *(reinterpret_cast(ptro)) + = *(reinterpret_cast(ptri)); + ptro += sizeof(uint32_t); + ptri += sizeof(uint32_t); + } + for (dim_t e = 0; e < tail_part; ++e) { + *ptro = *ptri; + ++ptro; + ++ptri; + } +#else + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e]; +#endif + }); + } + + return status::success; +} + +template struct simple_concat_t; +template struct simple_concat_t; +template struct simple_concat_t; +template struct simple_concat_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp new file mode 100644 index 000000000000..057cc3c4c7e1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp @@ -0,0 +1,155 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
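A stripped-down sketch (illustrative names) of the copy performed by simple_concat above: each input keeps its own element count and a precomputed offset into the destination, so the copies are independent and can be parallelized over inputs and physical dimensions.

#include <cstddef>
#include <vector>

// Concatenate 1-D buffers into dst; each source starts at its running offset,
// mirroring how optrs[a] above is o_base_ptr + o_d.blk_off(0) for each input.
static void concat_1d(const std::vector<std::vector<float>> &srcs,
        std::vector<float> &dst) {
    size_t total = 0;
    for (const auto &s : srcs) total += s.size();
    dst.resize(total);
    size_t off = 0;
    for (const auto &s : srcs) {
        for (size_t e = 0; e < s.size(); ++e)   // the o[e] = i[e] copy kernel
            dst[off + e] = s[e];
        off += s.size();
    }
}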
+*******************************************************************************/ + +#ifndef SIMPLE_CONCAT_HPP +#define SIMPLE_CONCAT_HPP + +#include "memory_tracking.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct simple_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + int ndims = rhs.dst_md_.ndims; + utils::array_copy(perm_, rhs.perm_, ndims); + utils::array_copy(iperm_, rhs.iperm_, ndims); + utils::array_copy(blocks_, rhs.blocks_, ndims); + } + + DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); + + status_t init() { + const memory_desc_wrapper dst_d(dst_md()); + bool ok = true + && cpu_concat_pd_t::init() == status::success + && dst_d.ndims() <= 6; + if (!ok) return status::unimplemented; + + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + const memory_desc_wrapper o_d(&src_image_mds_[i]); + + const int ignore_strides = 0; + + ok = ok + && utils::everyone_is(data_type, i_d.data_type(), + o_d.data_type()) + && utils::everyone_is(format_kind::blocked, + i_d.format_kind(), o_d.format_kind()) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + o_d.blocking_desc(), ignore_strides) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + dst_d.blocking_desc(), ignore_strides) + && !i_d.is_additional_buffer(); + if (!ok) return status::unimplemented; + } + + dst_d.compute_blocks(blocks_); + format_perm(); + + // start dim is the first dimension after which the concatenation + // would happen contiguously + const int start_dim = perm_[concat_dim()]; + + // check that contiguous part is indeed contiguous (i.e. dense) + if (nelems_to_concat(dst_d) != + dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] + * dst_d.blocking_desc().strides[concat_dim()]) + return status::unimplemented; + + // check that all inputs have the same strides for the + // contiguous part [concat_dim .. ndims] for the *major* dims. 
+ // the block part is already checked above + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + for (int d = start_dim; d < dst_d.ndims(); ++d) { + if (dst_d.blocking_desc().strides[iperm_[d]] + != i_d.blocking_desc().strides[iperm_[d]]) + return status::unimplemented; + } + } + + init_scratchpad(); + + return status::success; + } + + int perm_[MKLDNN_MAX_NDIMS] {}; + int iperm_[MKLDNN_MAX_NDIMS] {}; + dims_t blocks_ {}; + + dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { + const int ndims = data_d.ndims(); + + dim_t nelems = 1; + for (int i = perm_[concat_dim()]; i < ndims; i++) + nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; + for (int i = 0; i < ndims; i++) + nelems *= blocks_[i]; + + return nelems; + } + + private: + void format_perm() { + const memory_desc_wrapper dst_d(dst_md()); + const int ndims = dst_d.ndims(); + + strides_t strides; + utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); + for (int i = 0; i < ndims; i++) iperm_[i] = i; + + utils::simultaneous_sort(strides, iperm_, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i; + } + + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs()); + scratchpad.book(key_concat_istrides, + sizeof(strides_t) * n_inputs()); + } + }; + + simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + typedef typename prec_traits::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp new file mode 100644 index 000000000000..e6c3b8d7afcd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
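A minimal sketch (illustrative names) of the permutation trick used by simple_concat_t::pd_t::format_perm() above: dimensions are sorted by decreasing stride to recover the memory order, so iperm maps "position in memory order" to logical dimension and perm is its inverse.

#include <algorithm>
#include <numeric>
#include <vector>

static void format_perm(const std::vector<long> &strides,
        std::vector<int> &perm, std::vector<int> &iperm) {
    const int ndims = (int)strides.size();
    iperm.resize(ndims);
    perm.resize(ndims);
    std::iota(iperm.begin(), iperm.end(), 0);
    // Outermost (largest-stride) dimension first, as simultaneous_sort does above.
    std::stable_sort(iperm.begin(), iperm.end(),
            [&](int a, int b) { return strides[a] > strides[b]; });
    for (int i = 0; i < ndims; ++i) perm[iperm[i]] = i;
}

// Example: a plain nchw-like tensor with dims (2, 3, 4, 5) has strides
// (60, 20, 5, 1), giving iperm = {0,1,2,3} and perm = {0,1,2,3}; a transposed
// layout yields a non-identity permutation.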
+*******************************************************************************/ + +#ifndef CPU_SIMPLE_Q10N_HPP +#define CPU_SIMPLE_Q10N_HPP + +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::math; + +template +inline out_t round_and_saturate(float f) +{ return math::saturate(out_round(f)); } + +/* Quantization with alpha == 1 and beta == 0 */ +template +struct qz_a1b0 { + out_t operator()(in_t in) + { return round_and_saturate((float)in); } +}; + +template +struct qz_a1b0::value + && !is_subset::value + >::type> { + out_t operator()(in_t in) { return math::saturate(in); } +}; + +template +struct qz_a1b0::value>::type> { + out_t operator()(in_t in) { return (out_t)in; } +}; + +/* Quantization with alpha == 1 */ +template struct qz_a1 { + out_t operator()(in_t in, out_t out, float beta) + { return round_and_saturate((float)in + beta * out); } +}; + +template struct qz_a1 { + float operator()(in_t in, float out, float beta) + { return (float)in + beta * out; } +}; + +/* Quantization with beta == 0 */ +template struct qz_b0 { + out_t operator()(in_t in, float alpha) + { return round_and_saturate(alpha * in); } +}; + +template struct qz_b0 { + float operator()(in_t in, float alpha) { return alpha * in; } +}; + +/* Quantization */ +template struct qz { + out_t operator()(in_t in, out_t out, float alpha, float beta) { + return round_and_saturate( + alpha * in + (beta ? beta * out : 0)); + } +}; + +template struct qz { + float operator()(in_t in, float out, float alpha, float beta) + { return alpha * in + (beta ? beta * out : 0); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp new file mode 100644 index 000000000000..ff845f5bd36f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp @@ -0,0 +1,1022 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
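A brief worked example (illustrative standalone code) of the general qz functor defined above: out = saturate(round(alpha * in + beta * out_prev)), with the qz_a1, qz_b0 and qz_a1b0 variants simply fixing alpha = 1 and/or beta = 0.

#include <algorithm>
#include <cmath>
#include <cstdint>

static int8_t qz_example(float in, int8_t out_prev, float alpha, float beta) {
    float v = alpha * in + (beta ? beta * (float)out_prev : 0.f);
    v = std::nearbyint(v);                               // round to nearest
    return (int8_t)std::min(127.f, std::max(-128.f, v)); // saturate to s8 range
}

// qz_example(3.7f, 0, 2.f, 0.f)  == 7    (alpha scaling, 7.4 rounds to 7)
// qz_example(100.f, 0, 2.f, 0.f) == 127  (saturated)
// qz_example(1.f, 10, 1.f, 0.5f) == 6    (beta blends in the previous output)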
+*******************************************************************************/ + +#ifndef CPU_SIMPLE_REORDER_HPP +#define CPU_SIMPLE_REORDER_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "tag_traits.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_primitive.hpp" + +#include "simple_q10n.hpp" +#include "cpu_isa_traits.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::data_type; + +using bd = block_dim_t; +using ib = inner_blk_t; + +using namespace mkldnn::impl::utils; +using math::saturate; + +template +using data_t = typename prec_traits::type; + +template +using _qz_a1b0 = qz_a1b0, data_t>; + +template +using _qz = qz, data_t>; + +namespace fmt_order { + const bool keep = true; + const bool reverse = false; + const bool any = keep; +} + +namespace spec { +struct direct_copy {}; +struct direct_copy_except_dim_0 {}; +struct reference {}; +struct conv_s8s8 {}; +} + +#define SIMPLE_REORDER_TEMPL_DECL \ + impl::data_type_t type_i, impl::format_tag_t tag_i, \ + impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep +#define SIMPLE_REORDER_TEMPL_CALL \ + type_i, tag_i, type_o, tag_o, order_keep + +#define DECLARE_COMMON_PARAMS() \ + const memory_desc_wrapper &input_d = pd->src_md(); \ + const memory_desc_wrapper &output_d = pd->dst_md(); \ + const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \ + const float beta = pd->beta(); MAYBE_UNUSED(beta); + +/* specific reorders: common template */ +template +struct simple_reorder_impl {}; + +namespace { +inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i, + impl::format_tag_t tag_o, const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d) { + return input_d.matches_tag(order_keep ? tag_i : tag_o) + && output_d.matches_tag(order_keep ? tag_o : tag_i); +} +inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) { + if (many_scales_support) + return true; + return IMPLICATION(attr, attr->output_scales_.mask_ == 0); +} +} + +/* specific reorders: implementation */ +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = (input_d.dims()[tag_o == hwigo + 0]); + const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1; + + return output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = tag_o == hwigo; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + + const int G = w_groups ? 
dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int IC = dims[w_groups + 1]; + const int H = dims[w_groups + 2]; + const int W = dims[w_groups + 3]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W; + int32_t *cp = reinterpret_cast(output + offset); + + parallel_nd(G, OC, [&](int g, int oc) { + cp[g * OC + oc] = 0; + for (int ic = 0; ic < IC; ic++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = input[input_d.blk_off(g, oc, ic, h, w)]; + auto &o = output[output_d.blk_off(g, oc, ic, h, w)]; + const float s = scales[(D_mask == 1) ? 0 : g * OC + oc]; + + o = qz_b0, data_t>()( + i, s * adj_scale); + cp[g * OC + oc] -= (int32_t)o; + } + cp [g * OC + oc] *= 128; + }); + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + const int oc = (input_d.dims()[w_groups ? 1 : 0]); + const int g = w_groups ? input_d.dims()[0] : 1; + + return input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = + !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + constexpr int is_1d = + utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i); + constexpr int blksize = tag_traits::inner_blks == ib::_4b4c + ? 4 + : tag_traits::inner_blks == ib::_2c8b4c + ? 8 + : 16; + + const auto &_g_oihw_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int NB_OC = pdims[w_groups + 0] / blksize; + const int IC = dims[w_groups + 1]; + const int NB_IC = pdims[w_groups + 1] / blksize; + const int H = is_1d ? 1 : dims[w_groups + 2]; + const int W = dims[w_groups + 3 - is_1d]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? 
output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t *inp, data_t *out, + int32_t *c, const float *s, const int oc_block, const int ic_block) { +# define index AB_or_BC_blk_off::inner_blks> + + for (int ic = 0; ic < ic_block; ++ic) { + for (int oc = 0; oc < oc_block; ++oc) { + const auto _g_oihw_off = + oc * _g_oihw_d.blocking_desc().strides[w_groups + 0] + + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1]; + out[index(oc, ic)] + = qz_b0, data_t>()( + inp[_g_oihw_off], s[oc] * adj_scale); + c[oc] -= (128 * (int32_t)(out[index(oc, ic)])); + } + } +# undef index + }; + + constexpr int i_mult = blksize; + constexpr int o_mult = 1; + + size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W; + int32_t *cp = reinterpret_cast(output + offset); + parallel_nd(G * NB_OC * blksize, [&](int i) { + cp[i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) \ + : (md).blk_off(g, o, i, h, w)) + + parallel_nd(G, NB_OC, [&](int g, int O) { + for (int I = 0; I < NB_IC; I++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = &input[wei_blk_off( + input_d, g, i_mult * O, i_mult * I, h, w)]; + auto o = &output[wei_blk_off( + output_d, g, o_mult * O, o_mult * I, h, w)]; + const int oc_block = nstl::min(blksize, OC - O * blksize); + const int ic_block = nstl::min(blksize, IC - I * blksize); + + int _offset = (g * NB_OC + O) * blksize; + ker(i, o, (order_keep) ? &cp[_offset] : nullptr, + &scales[(D_mask == 1) ? 0 : _offset], + oc_block, ic_block); + } + }); + +# undef wei_blk_off + + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = input_d.dims()[1]; + const int g = input_d.dims()[0]; + + return true + && order_keep + && input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + constexpr bool is_1d = tag_i == goiw; + constexpr int blksize = 16; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + const int G = dims[0]; + const int Gp = pdims[0]; + const int OC = dims[1]; + const int IC = dims[2]; + const int H = is_1d ? 1 : dims[3]; + const int W = dims[4 - is_1d]; + + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + const float *scales = pd->attr()->output_scales_.scales_; + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? 
output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t *inp, data_t *out, + int32_t *cp, const float *s, const int g_block) { + PRAGMA_OMP_SIMD() + for (int g = 0; g < g_block; g++) { + const auto i_off = g * input_d.blocking_desc().strides[0]; + out[g] = qz_b0, data_t>()( + inp[i_off], s[g * OC] * adj_scale); + cp[g * OC] -= 128 * (int32_t)(out[g]); + } + }; + + size_t cp_offset = output_d.size() - output_d.additional_buffer_size(); + int32_t *cp = reinterpret_cast(output + cp_offset); + parallel_nd((Gp/blksize) * OC, [&](int ib) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; i++) + cp[ib * blksize + i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w)) + + parallel_nd(Gp/blksize, OC, [&](int gb, int O) { + for (int I = 0; I < IC; I++) { + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) + { + const int g_block = nstl::min(G - gb * blksize, blksize); + const auto inp = &input[wei_blk_off( + input_d, gb * blksize, O, I, h, w)]; + const auto out = &output[wei_blk_off( + output_d, gb, O, I, h, w)]; + int offset = gb * blksize + O; + ker(inp, out, &cp[offset], + &scales[(D_mask == 1) ? 0 : offset], g_block); + } + } + }); + +# undef wei_blk_off + + return success; + } +}; + +/* reorders with tail support */ + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + constexpr int is_1d = tag_i == nCw8c; + constexpr int is_3d = tag_i == nCdhw8c; + constexpr int blksize_16 = 16; + constexpr int blksize_8 = 8; + constexpr int ic_mult = order_keep ? 2 : 1; + constexpr int oc_mult = order_keep ? 1 : 2; + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep ? output_d.padded_dims() + : input_d.padded_dims(); + + const int C = dims[1]; + const int D = is_3d ? dims[2] : 1; + const int H = is_1d ? 1 : dims[2 + is_3d]; + const int W = dims[3 + is_3d - is_1d]; + + auto ker = [&](const data_t *i, data_t *o, + const int block_16) { + const int nb = (block_16 - 1) / blksize_8 + 1; + if (alpha == 1.0 && beta == 0.0) { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz_a1b0()( + i[i_off + c]); + } + } + } else { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz()(i[i_off + c], + o[o_off + c], alpha, beta); + } + } + } + }; + +# define data_blk_off(md, n, c, d, h, w) \ + ( is_1d ? (md).blk_off(n, c, w) \ + : is_3d ? 
(md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) + + parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, + [&](int n, int nb_c, int d, int h, int w) { + auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)]; + auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)]; + const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16); + ker(i, o, block_16); + }); + +# undef data_blk_off + + return success; + } +}; + +#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \ + static bool is_applicable(const memory_desc_wrapper &input_d, \ + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \ + return simple_attr_check(attr, false) && (order_keep \ + ? output_d.matches_tag(tag_o) && input_d.is_plain() \ + : input_d.matches_tag(tag_o) && output_d.is_plain()); \ + } + +template +struct simple_reorder_impl::block_dims == bd::_A + || tag_traits::block_dims == bd::_B) + && tag_traits::ndims >= 3 + && tag_traits::ndims <= 6 + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &block_d = order_keep ? output_d : input_d; + const auto &dims = input_d.dims(); + const auto &pdims = block_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + constexpr int blk_idx = tag_traits::block_dims == bd::_A ? 0 : 1; + + const dim_t H0 = dims[0]; + const dim_t H1 = dims[1]; + const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1; + const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1; + const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1; + const dim_t L = dims[ndims - 1]; + const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1]; + + constexpr int blksize = false ? 0 + : utils::one_of(tag_traits::inner_blks, ib::_4a, ib::_4b) ? 4 + : utils::one_of(tag_traits::inner_blks, ib::_8a, ib::_8b) ? 8 + : 16; + + auto ker = [&](const data_t *i, data_t *o, int block) { + if (alpha == 1.0 && beta == 0.0) { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz_a1b0()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0()( + i[l * l_blk_stride + blk]); + } + } + } else { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz()( + i[flat_off], o[l * blksize + blk], + alpha, beta); + } else { + o[flat_off] = _qz()( + i[l * l_blk_stride + blk], o[flat_off], + alpha, beta); + } + } + } + }; + +# define off(md, h0, h1, m0, m1, m2) \ + (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \ + : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \ + : ndims >= 4 ? (md).blk_off(h0, h1, m2) \ + : /* ndims >= 3 ? */ (md).blk_off(h0, h1)) + + constexpr int i_mult = order_keep ? blksize : 1; + constexpr int o_mult = order_keep ? 
1 : blksize; + + if (blk_idx == 0) { + const dim_t BH0 = pdims[0] / blksize; + parallel_nd(BH0, H1, M0, M1, M2, + [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)]; + auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)]; + const int block = nstl::min(blksize, H0 - bh0 * blksize); + ker(i, o, block); + }); + } else if (blk_idx == 1) { + const dim_t BH1 = pdims[1] / blksize; + parallel_nd(H0, BH1, M0, M1, M2, + [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)]; + auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)]; + const int block = nstl::min(blksize, H1 - bh1 * blksize); + ker(i, o, block); + }); + } else { + assert(!"unimplemented"); + } + +# undef off + + return success; + } +}; + +template +struct simple_reorder_impl::block_dims == bd::_AB + || tag_traits::block_dims == bd::_BC) + && IMPLICATION(tag_traits::block_dims == bd::_AB, + tag_traits::ndims >= 3 && tag_traits::ndims <= 5) + && IMPLICATION(tag_traits::block_dims == bd::_BC, + tag_traits::ndims >= 4 && tag_traits::ndims <= 6) + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + + static constexpr bool with_g = tag_traits::block_dims == bd::_BC; + const dim_t G = with_g ? dims[0] : 1; + + const dim_t H0 = dims[0 + with_g]; + const dim_t H1 = dims[1 + with_g]; + + const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1; + const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1; + const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1; + + constexpr int blksize_0 = false ? 0 + : utils::one_of(tag_traits::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b) + ? 4 + : utils::one_of(tag_traits::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c, + ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b, + ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c) + ? 16 : INT_MIN; + + constexpr int blksize_1 = utils::one_of(tag_traits::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c) + ? 16 + : utils::one_of(tag_traits::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b, + ib::_16a4b, ib::_16b4c) + ? 
4 + : INT_MIN; + + const dim_t NB_H0 = pdims[0 + with_g] / blksize_0; + const dim_t NB_H1 = pdims[1 + with_g] / blksize_1; + + auto ker = [&](const data_t *i, data_t *o, + const int block_h0, const int block_h1) { +# define blk_off AB_or_BC_blk_off::inner_blks> + + if (alpha == 1.0 && beta == 0.0) { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz_a1b0()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0()( + i[blk_off(h0, h1)]); + } + } + } else { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz()(i[flat_off], + o[blk_off(h0, h1)], alpha, beta); + } else { + o[flat_off] = _qz()(i[blk_off(h0, h1)], + o[flat_off], alpha, beta); + } + } + } + +# undef blk_off + }; + + constexpr int i_mult_0 = order_keep ? blksize_0 : 1; + constexpr int o_mult_0 = order_keep ? 1 : blksize_0; + + constexpr int i_mult_1 = order_keep ? blksize_1 : 1; + constexpr int o_mult_1 = order_keep ? 1 : blksize_1; + +# define off(md, g, h0, h1, m0, m1, m2) \ + (ndims >= 5 + with_g ? (md).blk_off(g, h0, h1, m0, m1, m2) \ + : ndims >= 4 + with_g ? (md).blk_off(g, h0, h1, m1, m2) \ + : /* ndims >= 3 + with_g ? */ (md).blk_off(g, h0, h1, m2)) + + parallel_nd(G, NB_H0, NB_H1, M0, M1, M2, + [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, + g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)]; + auto o = &output[off(output_d, + g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)]; + const int block_h0 = nstl::min(blksize_0, H0 - nb_h0 * blksize_0); + const int block_h1 = nstl::min(blksize_1, H1 - nb_h1 * blksize_1); + ker(i, o, block_h0, block_h1); + }); + +# undef off + + return success; + } +}; + +/* generic and direct-copy reorders */ + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* FIXME: is the formula correct? 
*/ + return input_d.similar_to(output_d, true, false, 0) + && input_d.is_dense() && output_d.is_dense() + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + assert(input_d.is_dense()); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const size_t nelems = input_d.nelems(); + + constexpr int block_size = 16; + const auto num_blocks = nelems / block_size; + const auto rem_elems = nelems % block_size; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(num_blocks, nthr, ithr, start, end); + start = start * block_size; + end = end * block_size; + + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1b0, data_t>() + (input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1, data_t>() + (input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_b0, data_t>() + (input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz, data_t>() + (input[e], output[e], alpha, beta); + } + } + + if (rem_elems != 0 && ithr == nthr - 1){ + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1b0, + data_t>()(input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1, + data_t>()(input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_b0, + data_t>()(input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz, data_t>() + (input[e], output[e], alpha, beta); + } + } + } + }); + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) { + return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d); + }; + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 1) + && is_dense_no_0(input_d) && is_dense_no_0(output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const int N = input_d.dims()[0]; + const dim_t is = input_d.blocking_desc().strides[0]; + const dim_t os = output_d.blocking_desc().strides[0]; + const dim_t nelems_no_d0 = nelems_no_dim_0(input_d); + const dim_t work_amount = N * nelems_no_d0; + + if (alpha == 1.0 && beta == 0.0) { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = dim1_s + work_rem > nelems_no_d0 + ? 
nelems_no_d0 : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e) { + output[os * n + e] = _qz_a1b0()( + input[is * n + e]); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } else { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = + dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0 + : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e){ + output[os * n + e] = _qz()( + input[is * n + e], output[os * n + e], alpha, + beta); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } + + return success; + } + +private: + static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) { + const int ndims = data_d.ndims(); + if (ndims <= 1) return 1; + return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1); + } + + static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) { + dims_t blocks; + data_d.compute_blocks(blocks); + + const auto &blk = data_d.blocking_desc(); + + dim_t blk_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) + blk_size *= blk.inner_blks[iblk]; + + dim_t max_size = blk_size; + for (int d = 1; d < data_d.ndims(); ++d) { + max_size = nstl::max(max_size, + data_d.padded_dims()[d] / blocks[d] * blk.strides[d]); + } + + return max_size; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* supported smask: 0x0...011..10...0, + * i.e. 1 should be contiguous */ + int smask = attr ? 
attr->output_scales_.mask_ : 0; + for (; smask > 0 && !(smask & 0x1); smask >>= 1); + for (; smask > 0 && smask & 0x1; smask >>= 1); + return true + && input_d.is_blocking_desc() + && output_d.is_blocking_desc() + && !output_d.is_additional_buffer() + && !input_d.is_additional_buffer() + && smask == 0; + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const size_t nelems = input_d.nelems(); + + int ndims_start = 0, ndims_mask = 0; + int smask = pd->attr()->output_scales_.mask_; + for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start; + for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask; + assert(smask == 0); + + const ptrdiff_t D_start + = utils::array_product(input_d.dims(), ndims_start); + const ptrdiff_t D_mask + = utils::array_product(input_d.dims() + ndims_start, ndims_mask); + const ptrdiff_t D_rest = nelems / D_start / D_mask; + + const float *scales = pd->attr()->output_scales_.scales_; + + parallel_nd(D_start, D_mask, D_rest, + [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) { + const float scale = scales[dm]; + + const size_t e = (ds * D_mask + dm) * D_rest + dr; + const auto &i = input[input_d.off_l(e)]; + auto &o = output[output_d.off_l(e)]; + + o = _qz()(i, o, scale, beta); + }); + + return success; + } +}; + + +/* high level class declaration */ + +template +struct simple_reorder_t: public cpu_primitive_t { + struct pd_t: public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("simple:any", simple_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + bool args_ok = true + && src_md->data_type == type_i + && dst_md->data_type == type_o + && simple_reorder_impl:: + is_applicable(src_md, dst_md, attr); + if (!args_ok) + return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign(*reorder_pd, _pd); + } + }; + + simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_TO); + simple_reorder_impl::execute( + pd(), input, output); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +#undef SIMPLE_REORDER_TEMPL_DECL +#undef SIMPLE_REORDER_TEMPL_CALL + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp new file mode 100644 index 000000000000..f0947573a932 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_sum.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +status_t simple_sum_t::execute(const exec_ctx_t &ctx) const { + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper o_d(pd()->dst_md()); + output += o_d.blk_off(0); + + const int num_arrs = pd()->n_inputs(); + const data_t *input_ptrs[max_num_arrs]; + const size_t nelems = o_d.nelems(); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + } + + const size_t block_size = 16 * 1024 / sizeof(data_type); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + + const auto scales = pd()->scales(); + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + }); + + return status::success; +} + +template struct simple_sum_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp new file mode 100644 index 000000000000..2a0187a1845f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef SIMPLE_SUM_HPP +#define SIMPLE_SUM_HPP + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct simple_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + DECLARE_SUM_PD_T("simple:any", simple_sum_t); + + status_t init() { + const int n = n_inputs(); + + bool ok = true + && cpu_sum_pd_t::init() == status::success + && n <= max_num_arrs; + if (!ok) return status::unimplemented; + + const memory_desc_wrapper o_d(dst_md()); + ok = ok + && o_d.data_type() == data_type + && o_d.is_dense(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n; ++i) { + const memory_desc_wrapper i_d(src_md(i)); + if (i_d != o_d) return status::unimplemented; + } + + return status::success; + } + }; + + simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + enum {max_num_arrs = 16 }; + typedef typename prec_traits::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp new file mode 100644 index 000000000000..c2082d7d62f4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp @@ -0,0 +1,376 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ + +#ifndef CPU_WINO_REORDER_HPP +#define CPU_WINO_REORDER_HPP + +#include "mkldnn_thread.hpp" + +#include "simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct wino_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_tag(utils::pick(id.ndims() - 4, + format_tag::oihw, format_tag::goihw)) + && od.format_kind() == format_kind::wino + && utils::one_of(od.wino_desc().wino_format, + mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + auto &o = memory_desc_wrapper(dst_md()).wino_desc(); + size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; + size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_reorder_wino_transform_space, + sizeof(in_data_t) * transform_space_size); + scratchpad.book(key_reorder_wino_plain, + sizeof(out_data_t) * plain_size); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + const int unsign_val_in_wino_domain_ = 5; + + wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + r_ = dst_d.wino_desc().r; + w_alpha_ = dst_d.wino_desc().alpha; + wino_format_ = dst_d.wino_desc().wino_format; + + const auto &in_dims = src_d.dims(); + int groups; + int groups_offset; + if (src_d.ndims() == 5) { + groups = in_dims[0]; + groups_offset = 1; + } else { + groups = 1; + groups_offset = 0; + } + assert(groups == 1); // groups are not supported now + MAYBE_UNUSED(groups); + + or_oc_ = in_dims[0 + groups_offset]; + or_ic_ = in_dims[1 + groups_offset]; + kh_ = in_dims[2 + groups_offset]; + kw_ = in_dims[3 + groups_offset]; + + oc_ = dst_d.wino_desc().oc; + ic_ = dst_d.wino_desc().ic; + oc_block_ = dst_d.wino_desc().oc_block; + ic_block_ = dst_d.wino_desc().ic_block; + assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); + nb_oc_ = oc_ / oc_block_; + nb_ic_ = ic_ / ic_block_; + ic2_block_ = 1; + if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + ic2_block_ = dst_d.wino_desc().ic2_block; + oc2_block_ = dst_d.wino_desc().oc2_block; + assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); + + adj_scale_ = dst_d.wino_desc().adj_scale; + + size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; + size_wspace_ = 
r_ * w_alpha_ * oc_block_; + } + + void transform(out_data_t *__restrict tmp_wei, + const in_data_t *__restrict input, + in_data_t *__restrict wspace) const { + const memory_desc_wrapper src_d(pd()->src_md()); + + const int smask = pd()->attr()->output_scales_.mask_; + const int ndims_mask = math::ilog2q(smask + 1); + const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); + const float *__restrict scales = pd()->attr()->output_scales_.scales_; + assert(D_mask == 1 || D_mask == (size_t)oc_); + + /* transform weights to winograd domain */ + const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, + { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; + + const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, + { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, + { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, + { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, + { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, + { 0.f, 0.f, 1.f } }; + + float *__restrict g; + if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) + g = (float *)G_2x2_3x3; + else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + g = (float *)G_4x4_3x3; + else { + assert("Unknown winograd weights target layout"); + return; + } + + int Z = oc_ * ic_; + assert(r_ == kh_ && r_ == kw_); + + for (int iic = 0; iic < ic_; iic++) { + for (int ob = 0; ob < nb_oc_; ob++) { + const in_data_t *__restrict _inp + = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; + out_data_t *__restrict _out + = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; + + for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); + + for_nd(0, 1, r_, w_alpha_, oc_block_, + [&](int ih, int j, int ioc) { + for (int iw = 0; iw < r_; ++iw) { + int inp_oc = ob * oc_block_ + ioc; + int inp_ic = iic; + in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) + ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] + : 0.f; + wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] + += inp_v * g[j * r_ + iw]; + } + }); + + for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, + [&](int i, int j, int ioc) { + float t = 0; + for (int k = 0; k < r_; ++k) + t += g[i * r_ + k] + * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; + if (type_o == data_type::s8) { + const float scale = (D_mask == 1) + ? 
scales[0] + : scales[ob * oc_block_ + ioc]; + _out[(i * w_alpha_ + j) * Z + ioc] + = qz_b0()( + (in_data_t)t, scale * adj_scale_); + } else { + _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; + } + }); + }} + } + + void reorder_to_aaOIoi(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int32_t *__restrict dst_bias = nullptr; + if (type_o == data_type::s8) { + const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; + const size_t bias_size = w_alpha_ * w_alpha_ * oc_; + + dst_bias = (int32_t *)(output + bias_shift); + utils::array_set((int32_t *)dst_bias, 0, bias_size); + } + int index = 0; + for (int u_h = 0; u_h < w_alpha_; u_h++) { + for (int u_w = 0; u_w < w_alpha_; u_w++) { + for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { + int u_h_shift = u_h * w_alpha_ * ic_ * oc_; + int u_w_shift = u_w * ic_ * oc_; + int u_h_shift_b = u_h * w_alpha_ * oc_; + int u_w_shift_b = u_w * oc_; + int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int _i = ib * ic_block_; + int _o = ob * oc_block_; + int ic_shift = (_i + i) * oc_; + int oc_shift = (_o + o); + int ic_block_shift = ib * oc_block_ * ic_block_ + i; + int src_offset = + u_h_shift + u_w_shift + ic_shift + oc_shift; + int dst_offset = u_h_shift + u_w_shift + oc_block_shift + + ic_block_shift; + + output[dst_offset] = tmp_wei[src_offset]; + if (type_o == data_type::s8) { + int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; + if (index != unsign_val_in_wino_domain_) + dst_bias[bias_offset] + -= (128 * (int32_t)output[dst_offset]); + else + dst_bias[bias_offset] = 0; + } + }} + }); + index++; + }} + } + + void reorder_to_aaOio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, + [&](int u_h, int u_w, int ob) { + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + for (int o = 0; o < oc_block_; o++) { + int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ + + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); + + int dst_offset + = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + ob * nb_ic_ * ic_block_ * oc_block_ + + ib * ic_block_ * oc_block_ + i * oc_block_ + o; + output[dst_offset] = tmp_wei[src_offset]; + }}} + }); + } + + void reorder_to_aaOBiOo(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, + [&](int u_h, int u_w, int occ) { + for (int ib = 0; ib < nb_ic_; ib++) { + out_data_t *__restrict wei_ptr = output + + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) + * oc2_block_ * ic_block_ * oc_block_; + int wei_offset = 0; + for (int i = 0; i < ic_block_; i++) { + for (int ob2 = 0; ob2 < oc2_block_; ob2++) { + for (int o = 0; o < oc_block_; o++) { + int icp = ib * ic_block_ + i; + int ocp = + occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + wei_ptr[wei_offset + o] = tmp_wei[src_offset]; + } + wei_offset += oc_block_; + }} + } + }); + } + + void reorder_to_OBaaIBOIio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int ic_chunks = nb_ic_ / ic2_block_; + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, + [&](int occ, int u_h, int u_w) { + for (int icc = 0; icc < 
ic_chunks; icc++) { + for (int ob = 0; ob < oc2_block_; ob++) { + int ocp = (occ * oc2_block_ + ob) * oc_block_; + for (int ib = 0; ib < ic2_block_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int icp = (icc * ic2_block_ + ib) * ic_block_ + i; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + int wei_offset + = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) + * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ + + ib) * ic_block_ + i) * oc_block_; + for (int o = 0; o < oc_block_; o++) + output[wei_offset + o] = tmp_wei[src_offset + o]; + }} + }} + }); + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + + auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_wino_transform_space); + auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_wino_plain); + + transform(tmp_wei, input, wspace); + + /* reorder to winograd domain */ + switch (wino_format_) { + case mkldnn_wino_wei_aaOIoi: + reorder_to_aaOIoi(output, tmp_wei); break; + case mkldnn_wino_wei_aaOio: + reorder_to_aaOio(output, tmp_wei); break; + case mkldnn_wino_wei_aaOBiOo: + reorder_to_aaOBiOo(output, tmp_wei); break; + case mkldnn_wino_wei_OBaaIBOIio: + reorder_to_OBaaIBOIio(output, tmp_wei); break; + default: assert("Unknown wino format"); break; + } + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int r_, w_alpha_; + int ic_, oc_, or_ic_, or_oc_, kh_, kw_; + int oc_block_, ic_block_, oc2_block_, ic2_block_; + float adj_scale_; + int nb_oc_, nb_ic_; + mkldnn_wino_memory_format_t wino_format_; + int size_wino_wei_; + int size_wspace_; +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT new file mode 100644 index 000000000000..66b6ea55d000 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT @@ -0,0 +1,47 @@ + +Copyright (c) 2007 MITSUNARI Shigeo +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. +Neither the name of the copyright owner nor the names of its contributors may +be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. +----------------------------------------------------------------------------- +ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た +す場合に限り、再頒布および使用が許可されます。 + +ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項 +を含めること。 +バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作 +権表示、本条件一覧、および下記免責条項を含めること。 +書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進 +に、著作権者の名前またはコントリビューターの名前を使用してはならない。 +本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ +れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性 +に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。 +著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを +問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で +あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、 +本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の +喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接 +損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、 +一切責任を負わないものとします。 diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h new file mode 100644 index 000000000000..cf5771332fc4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h @@ -0,0 +1,2658 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#pragma once +#ifndef XBYAK_XBYAK_H_ +#define XBYAK_XBYAK_H_ +/*! + @file xbyak.h + @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++ + @author herumi + @url https://github.com/herumi/xbyak + @note modified new BSD license + http://opensource.org/licenses/BSD-3-Clause +*/ +#ifndef XBYAK_NO_OP_NAMES + #if not +0 // trick to detect whether 'not' is operator or not + #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()." + #endif +#endif + +#include // for debug print +#include +#include +#include +#include +#ifndef NDEBUG +#include +#endif + +// #define XBYAK_DISABLE_AVX512 + +//#define XBYAK_USE_MMAP_ALLOCATOR +#if !defined(__GNUC__) || defined(__MINGW32__) + #undef XBYAK_USE_MMAP_ALLOCATOR +#endif + +#ifdef __GNUC__ + #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor)) +#else + #define XBYAK_GNUC_PREREQ(major, minor) 0 +#endif + +// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft. +#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\ + ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__))) + #include + #define XBYAK_STD_UNORDERED_SET std::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap + +/* + Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using + libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version). 
+*/ +#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#else + #include + #define XBYAK_STD_UNORDERED_SET std::set + #include + #define XBYAK_STD_UNORDERED_MAP std::map + #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap +#endif +#ifdef _WIN32 + #include + #include + #include +#elif defined(__GNUC__) + #include + #include + #include +#endif +#if !defined(_MSC_VER) || (_MSC_VER >= 1600) + #include +#endif + +#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__)) + #define XBYAK64_WIN +#elif defined(__x86_64__) + #define XBYAK64_GCC +#endif +#if !defined(XBYAK64) && !defined(XBYAK32) + #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN) + #define XBYAK64 + #else + #define XBYAK32 + #endif +#endif + +#if (__cplusplus >= 201103) || (_MSC_VER >= 1800) + #define XBYAK_VARIADIC_TEMPLATE +#endif + +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) /* remove inline function */ + #pragma warning(disable : 4786) /* identifier is too long */ + #pragma warning(disable : 4503) /* name is too long */ + #pragma warning(disable : 4127) /* constant expresison */ +#endif + +namespace Xbyak { + +enum { + DEFAULT_MAX_CODE_SIZE = 4096, + VERSION = 0x5760 /* 0xABCD = A.BC(D) */ +}; + +#ifndef MIE_INTEGER_TYPE_DEFINED +#define MIE_INTEGER_TYPE_DEFINED +#ifdef _MSC_VER + typedef unsigned __int64 uint64; + typedef __int64 sint64; +#else + typedef uint64_t uint64; + typedef int64_t sint64; +#endif +typedef unsigned int uint32; +typedef unsigned short uint16; +typedef unsigned char uint8; +#endif + +#ifndef MIE_ALIGN + #ifdef _MSC_VER + #define MIE_ALIGN(x) __declspec(align(x)) + #else + #define MIE_ALIGN(x) __attribute__((aligned(x))) + #endif +#endif +#ifndef MIE_PACK // for shufps + #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w)) +#endif + +enum { + ERR_NONE = 0, + ERR_BAD_ADDRESSING, + ERR_CODE_IS_TOO_BIG, + ERR_BAD_SCALE, + ERR_ESP_CANT_BE_INDEX, + ERR_BAD_COMBINATION, + ERR_BAD_SIZE_OF_REGISTER, + ERR_IMM_IS_TOO_BIG, + ERR_BAD_ALIGN, + ERR_LABEL_IS_REDEFINED, + ERR_LABEL_IS_TOO_FAR, + ERR_LABEL_IS_NOT_FOUND, + ERR_CODE_ISNOT_COPYABLE, + ERR_BAD_PARAMETER, + ERR_CANT_PROTECT, + ERR_CANT_USE_64BIT_DISP, + ERR_OFFSET_IS_TOO_BIG, + ERR_MEM_SIZE_IS_NOT_SPECIFIED, + ERR_BAD_MEM_SIZE, + ERR_BAD_ST_COMBINATION, + ERR_OVER_LOCAL_LABEL, // not used + ERR_UNDER_LOCAL_LABEL, + ERR_CANT_ALLOC, + ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW, + ERR_BAD_PROTECT_MODE, + ERR_BAD_PNUM, + ERR_BAD_TNUM, + ERR_BAD_VSIB_ADDRESSING, + ERR_CANT_CONVERT, + ERR_LABEL_ISNOT_SET_BY_L, + ERR_LABEL_IS_ALREADY_SET_BY_L, + ERR_BAD_LABEL_STR, + ERR_MUNMAP, + ERR_OPMASK_IS_ALREADY_SET, + ERR_ROUNDING_IS_ALREADY_SET, + ERR_K0_IS_INVALID, + ERR_EVEX_IS_INVALID, + ERR_SAE_IS_INVALID, + ERR_ER_IS_INVALID, + ERR_INVALID_BROADCAST, + ERR_INVALID_OPMASK_WITH_MEMORY, + ERR_INVALID_ZERO, + ERR_INVALID_RIP_IN_AUTO_GROW, + ERR_INVALID_MIB_ADDRESS, + ERR_INTERNAL, + ERR_X2APIC_IS_NOT_SUPPORTED +}; + +class Error : public 
std::exception { + int err_; +public: + explicit Error(int err) : err_(err) + { + if (err_ < 0 || err_ > ERR_INTERNAL) { + fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_); + //exit(1); + } + } + operator int() const { return err_; } + const char *what() const throw() + { + static const char *errTbl[] = { + "none", + "bad addressing", + "code is too big", + "bad scale", + "esp can't be index", + "bad combination", + "bad size of register", + "imm is too big", + "bad align", + "label is redefined", + "label is too far", + "label is not found", + "code is not copyable", + "bad parameter", + "can't protect", + "can't use 64bit disp(use (void*))", + "offset is too big", + "MEM size is not specified", + "bad mem size", + "bad st combination", + "over local label", + "under local label", + "can't alloc", + "T_SHORT is not supported in AutoGrow", + "bad protect mode", + "bad pNum", + "bad tNum", + "bad vsib addressing", + "can't convert", + "label is not set by L()", + "label is already set by L()", + "bad label string", + "err munmap", + "opmask is already set", + "rounding is already set", + "k0 is invalid", + "evex is invalid", + "sae(suppress all exceptions) is invalid", + "er(embedded rounding) is invalid", + "invalid broadcast", + "invalid opmask with memory", + "invalid zero", + "invalid rip in AutoGrow", + "invalid mib address", + "internal error", + "x2APIC is not supported" + }; + assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl)); + return errTbl[err_]; + } +}; + +inline const char *ConvertErrorToString(const Error& err) +{ + return err.what(); +} + +inline void *AlignedMalloc(size_t size, size_t alignment) +{ +#ifdef __MINGW32__ + return __mingw_aligned_malloc(size, alignment); +#elif defined(_WIN32) + return _aligned_malloc(size, alignment); +#else + void *p; + int ret = posix_memalign(&p, alignment, size); + return (ret == 0) ? 
p : 0; +#endif +} + +inline void AlignedFree(void *p) +{ +#ifdef __MINGW32__ + __mingw_aligned_free(p); +#elif defined(_MSC_VER) + _aligned_free(p); +#else + free(p); +#endif +} + +template +inline const To CastTo(From p) throw() +{ + return (const To)(size_t)(p); +} +namespace inner { + +static const size_t ALIGN_PAGE_SIZE = 4096; + +inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; } +inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; } + +inline uint32 VerifyInInt32(uint64 x) +{ +#ifdef XBYAK64 + if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + return static_cast(x); +} + +enum LabelMode { + LasIs, // as is + Labs, // absolute + LaddTop // (addr + top) for mov(reg, label) with AutoGrow +}; + +} // inner + +/* + custom allocator +*/ +struct Allocator { + virtual uint8 *alloc(size_t size) { return reinterpret_cast(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); } + virtual void free(uint8 *p) { AlignedFree(p); } + virtual ~Allocator() {} + /* override to return false if you call protect() manually */ + virtual bool useProtect() const { return true; } +}; + +#ifdef XBYAK_USE_MMAP_ALLOCATOR +class MmapAllocator : Allocator { + typedef XBYAK_STD_UNORDERED_MAP SizeList; + SizeList sizeList_; +public: + uint8 *alloc(size_t size) + { + const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1; + size = (size + alignedSizeM1) & ~alignedSizeM1; +#ifdef MAP_ANONYMOUS + const int mode = MAP_PRIVATE | MAP_ANONYMOUS; +#elif defined(MAP_ANON) + const int mode = MAP_PRIVATE | MAP_ANON; +#else + #error "not supported" +#endif + void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0); + if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC); + assert(p); + sizeList_[(uintptr_t)p] = size; + return (uint8*)p; + } + void free(uint8 *p) + { + if (p == 0) return; + SizeList::iterator i = sizeList_.find((uintptr_t)p); + if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER); + if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP); + sizeList_.erase(i); + } +}; +#endif + +class Address; +class Reg; + +class Operand { + static const uint8 EXT8BIT = 0x20; + unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil + unsigned int kind_:9; + unsigned int bit_:10; +protected: + unsigned int zero_:1; + unsigned int mask_:3; + unsigned int rounding_:3; + void setIdx(int idx) { idx_ = idx; } +public: + enum Kind { + NONE = 0, + MEM = 1 << 0, + REG = 1 << 1, + MMX = 1 << 2, + FPU = 1 << 3, + XMM = 1 << 4, + YMM = 1 << 5, + ZMM = 1 << 6, + OPMASK = 1 << 7, + BNDREG = 1 << 8 + }; + enum Code { +#ifdef XBYAK64 + RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15, + R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D, + R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W, + R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B, + SPL = 4, BPL, SIL, DIL, +#endif + EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI, + AX = 0, CX, DX, BX, SP, BP, SI, DI, + AL = 0, CL, DL, BL, AH, CH, DH, BH + }; + Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { } + Operand(int idx, Kind kind, int bit, bool ext8bit = 0) + : idx_(static_cast(idx | (ext8bit ? 
EXT8BIT : 0))) + , kind_(kind) + , bit_(bit) + , zero_(0), mask_(0), rounding_(0) + { + assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two + } + Kind getKind() const { return static_cast(kind_); } + int getIdx() const { return idx_ & (EXT8BIT - 1); } + bool isNone() const { return kind_ == 0; } + bool isMMX() const { return is(MMX); } + bool isXMM() const { return is(XMM); } + bool isYMM() const { return is(YMM); } + bool isZMM() const { return is(ZMM); } + bool isXMEM() const { return is(XMM | MEM); } + bool isYMEM() const { return is(YMM | MEM); } + bool isZMEM() const { return is(ZMM | MEM); } + bool isOPMASK() const { return is(OPMASK); } + bool isBNDREG() const { return is(BNDREG); } + bool isREG(int bit = 0) const { return is(REG, bit); } + bool isMEM(int bit = 0) const { return is(MEM, bit); } + bool isFPU() const { return is(FPU); } + bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; } + bool isExtIdx() const { return (getIdx() & 8) != 0; } + bool isExtIdx2() const { return (getIdx() & 16) != 0; } + bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); } + bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); } + bool hasZero() const { return zero_; } + int getOpmaskIdx() const { return mask_; } + int getRounding() const { return rounding_; } + void setKind(Kind kind) + { + if ((kind & (XMM|YMM|ZMM)) == 0) return; + kind_ = kind; + bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512; + } + void setBit(int bit) { bit_ = bit; } + void setOpmaskIdx(int idx, bool ignore_idx0 = false) + { + if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID); + if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET); + mask_ = idx; + } + void setRounding(int idx) + { + if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET); + rounding_ = idx; + } + void setZero() { zero_ = true; } + // ah, ch, dh, bh? + bool isHigh8bit() const + { + if (!isBit(8)) return false; + if (isExt8bit()) return false; + const int idx = getIdx(); + return AH <= idx && idx <= BH; + } + // any bit is accetable if bit == 0 + bool is(int kind, uint32 bit = 0) const + { + return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16) + } + bool isBit(uint32 bit) const { return (bit_ & bit) != 0; } + uint32 getBit() const { return bit_; } + const char *toString() const + { + const int idx = getIdx(); + if (kind_ == REG) { + if (isExt8bit()) { + static const char *tbl[4] = { "spl", "bpl", "sil", "dil" }; + return tbl[idx - 4]; + } + static const char *tbl[4][16] = { + { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" }, + { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" }, + { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" }, + { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" }, + }; + return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 
2 : 3][idx]; + } else if (isOPMASK()) { + static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" }; + return tbl[idx]; + } else if (isZMM()) { + static const char *tbl[32] = { + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31" + }; + return tbl[idx]; + } else if (isYMM()) { + static const char *tbl[32] = { + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", + "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31" + }; + return tbl[idx]; + } else if (isXMM()) { + static const char *tbl[32] = { + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", + "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31" + }; + return tbl[idx]; + } else if (isMMX()) { + static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" }; + return tbl[idx]; + } else if (isFPU()) { + static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" }; + return tbl[idx]; + } else if (isBNDREG()) { + static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" }; + return tbl[idx]; + } + throw Error(ERR_INTERNAL); + } + bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; } + bool operator==(const Operand& rhs) const; + bool operator!=(const Operand& rhs) const { return !operator==(rhs); } + const Address& getAddress() const; + const Reg& getReg() const; +}; + +class Label; + +struct Reg8; +struct Reg16; +struct Reg32; +#ifdef XBYAK64 +struct Reg64; +#endif +class Reg : public Operand { +public: + Reg() { } + Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { } + Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); } + uint8 getRexW() const { return isREG(64) ? 8 : 0; } + uint8 getRexR() const { return isExtIdx() ? 4 : 0; } + uint8 getRexX() const { return isExtIdx() ? 2 : 0; } + uint8 getRexB() const { return isExtIdx() ? 
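+	// REX bit helpers: W (0x08) selects 64-bit operand size, R (0x04) extends ModRM.reg,
+	// X (0x02) extends SIB.index and B (0x01) extends ModRM.rm/SIB.base, so r8-r15 and
+	// the extended byte registers need the matching extension bit.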
1 : 0; } + uint8 getRex(const Reg& base = Reg()) const + { + uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB(); + if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40; + return rex; + } + Reg8 cvt8() const; + Reg16 cvt16() const; + Reg32 cvt32() const; +#ifdef XBYAK64 + Reg64 cvt64() const; +#endif +}; + +inline const Reg& Operand::getReg() const +{ + assert(!isMEM()); + return static_cast(*this); +} + +struct Reg8 : public Reg { + explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { } +}; + +struct Reg16 : public Reg { + explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { } +}; + +struct Mmx : public Reg { + explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { } +}; + +struct EvexModifierRounding { + enum { + T_RN_SAE = 1, + T_RD_SAE = 2, + T_RU_SAE = 3, + T_RZ_SAE = 4, + T_SAE = 5 + }; + explicit EvexModifierRounding(int rounding) : rounding(rounding) {} + int rounding; +}; +struct EvexModifierZero{EvexModifierZero() {}}; + +struct Xmm : public Mmx { + explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { } + Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { } + Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; } + Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; } + Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; } +}; + +struct Ymm : public Xmm { + explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { } + Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Zmm : public Ymm { + explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { } + Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Opmask : public Reg { + explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {} +}; + +struct BoundsReg : public Reg { + explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {} +}; + +templateT operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; } +templateT operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; } +templateT operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; } + +struct Fpu : public Reg { + explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { } +}; + +struct Reg32e : public Reg { + explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {} +}; +struct Reg32 : public Reg32e { + explicit Reg32(int idx = 0) : Reg32e(idx, 32) {} +}; +#ifdef XBYAK64 +struct Reg64 : public Reg32e { + explicit Reg64(int idx = 0) : Reg32e(idx, 64) {} +}; +struct RegRip { + sint64 disp_; + const Label* label_; + bool isAddr_; + explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {} + friend const RegRip operator+(const RegRip& r, int disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, int disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, sint64 disp) { + return 
RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, const Label& label) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_, &label); + } + friend const RegRip operator+(const RegRip& r, const void *addr) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_ + (sint64)addr, 0, true); + } +}; +#endif + +inline Reg8 Reg::cvt8() const +{ + const int idx = getIdx(); + if (isBit(8)) return Reg8(idx, isExt8bit()); +#ifdef XBYAK32 + if (idx >= 4) throw Error(ERR_CANT_CONVERT); +#endif + return Reg8(idx, 4 <= idx && idx < 8); +} + +inline Reg16 Reg::cvt16() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg16(idx); +} + +inline Reg32 Reg::cvt32() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg32(idx); +} + +#ifdef XBYAK64 +inline Reg64 Reg::cvt64() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg64(idx); +} +#endif + +#ifndef XBYAK_DISABLE_SEGMENT +// not derived from Reg +class Segment { + int idx_; +public: + enum { + es, cs, ss, ds, fs, gs + }; + explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); } + int getIdx() const { return idx_; } + const char *toString() const + { + static const char tbl[][3] = { + "es", "cs", "ss", "ds", "fs", "gs" + }; + return tbl[idx_]; + } +}; +#endif + +class RegExp { +public: +#ifdef XBYAK64 + enum { i32e = 32 | 64 }; +#else + enum { i32e = 32 }; +#endif + RegExp(size_t disp = 0) : scale_(0), disp_(disp) { } + RegExp(const Reg& r, int scale = 1) + : scale_(scale) + , disp_(0) + { + if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (scale == 0) return; + if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE); + if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index + index_ = r; + } else { + base_ = r; + } + } + bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); } + RegExp optimize() const + { + RegExp exp = *this; + // [reg * 2] => [reg + reg] + if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) { + exp.base_ = index_; + exp.scale_ = 1; + } + return exp; + } + bool operator==(const RegExp& rhs) const + { + return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_; + } + const Reg& getBase() const { return base_; } + const Reg& getIndex() const { return index_; } + int getScale() const { return scale_; } + size_t getDisp() const { return disp_; } + void verify() const + { + if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (index_.getBit() && index_.getBit() <= 64) { + if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX); + if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER); + } + } + friend RegExp operator+(const RegExp& a, const RegExp& b); + friend RegExp operator-(const RegExp& e, size_t disp); + uint8 getRex() const + { + uint8 rex = index_.getRexX() | base_.getRexB(); + return rex ? 
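+	// 0x40 is the fixed REX marker; it is emitted only when an extended index or base
+	// register takes part. A RegExp is normally built from operator expressions, e.g.
+	// (illustrative) ptr[rax + rcx * 4 + 8] combines base, index*scale and displacement.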
uint8(rex | 0x40) : 0; + } +private: + /* + [base_ + index_ * scale_ + disp_] + base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm + */ + Reg base_; + Reg index_; + int scale_; + size_t disp_; +}; + +inline RegExp operator+(const RegExp& a, const RegExp& b) +{ + if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + RegExp ret = a; + if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; } + if (b.base_.getBit()) { + if (ret.base_.getBit()) { + if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + // base + base => base + index * 1 + ret.index_ = b.base_; + // [reg + esp] => [esp + reg] + if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_); + ret.scale_ = 1; + } else { + ret.base_ = b.base_; + } + } + ret.disp_ += b.disp_; + return ret; +} +inline RegExp operator*(const Reg& r, int scale) +{ + return RegExp(r, scale); +} +inline RegExp operator-(const RegExp& e, size_t disp) +{ + RegExp ret = e; + ret.disp_ -= disp; + return ret; +} + +// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc) +void *const AutoGrow = (void*)1; //-V566 +void *const DontSetProtectRWE = (void*)2; //-V566 + +class CodeArray { + enum Type { + USER_BUF = 1, // use userPtr(non alignment, non protect) + ALLOC_BUF, // use new(alignment, protect) + AUTO_GROW // automatically move and grow memory if necessary + }; + CodeArray(const CodeArray& rhs); + void operator=(const CodeArray&); + bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; } + struct AddrInfo { + size_t codeOffset; // position to write + size_t jmpAddr; // value to write + int jmpSize; // size of jmpAddr + inner::LabelMode mode; + AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode) + : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {} + uint64 getVal(const uint8 *top) const + { + uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top); + if (jmpSize == 4) disp = inner::VerifyInInt32(disp); + return disp; + } + }; + typedef std::list AddrInfoList; + AddrInfoList addrInfoList_; + const Type type_; +#ifdef XBYAK_USE_MMAP_ALLOCATOR + MmapAllocator defaultAllocator_; +#else + Allocator defaultAllocator_; +#endif + Allocator *alloc_; +protected: + size_t maxSize_; + uint8 *top_; + size_t size_; + bool isCalledCalcJmpAddress_; + + bool useProtect() const { return alloc_->useProtect(); } + /* + allocate new memory and copy old data to the new area + */ + void growMemory() + { + const size_t newSize = (std::max)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2); + uint8 *newTop = alloc_->alloc(newSize); + if (newTop == 0) throw Error(ERR_CANT_ALLOC); + for (size_t i = 0; i < size_; i++) newTop[i] = top_[i]; + alloc_->free(top_); + top_ = newTop; + maxSize_ = newSize; + } + /* + calc jmp address for AutoGrow mode + */ + void calcJmpAddress() + { + if (isCalledCalcJmpAddress_) return; + for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) { + uint64 disp = i->getVal(top_); + rewrite(i->codeOffset, disp, i->jmpSize); + } + isCalledCalcJmpAddress_ = true; + } +public: + enum ProtectMode { + PROTECT_RW = 0, // read/write + PROTECT_RWE = 1, // read/write/exec + PROTECT_RE = 2 // read/exec + }; + explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0) + : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? 
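+	// Buffer type selection: USER_BUF wraps caller-supplied memory (no alignment or
+	// protection changes), ALLOC_BUF uses the allocator and is made RWE here unless
+	// DontSetProtectRWE was passed, and AUTO_GROW regrows the buffer on demand with
+	// jump offsets patched later by calcJmpAddress().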
ALLOC_BUF : USER_BUF) + , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_) + , maxSize_(maxSize) + , top_(type_ == USER_BUF ? reinterpret_cast(userPtr) : alloc_->alloc((std::max)(maxSize, 1))) + , size_(0) + , isCalledCalcJmpAddress_(false) + { + if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC); + if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) { + alloc_->free(top_); + throw Error(ERR_CANT_PROTECT); + } + } + virtual ~CodeArray() + { + if (isAllocType()) { + if (useProtect()) setProtectModeRW(false); + alloc_->free(top_); + } + } + bool setProtectMode(ProtectMode mode, bool throwException = true) + { + bool isOK = protect(top_, maxSize_, mode); + if (isOK) return true; + if (throwException) throw Error(ERR_CANT_PROTECT); + return false; + } + bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); } + bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); } + void resetSize() + { + size_ = 0; + addrInfoList_.clear(); + isCalledCalcJmpAddress_ = false; + } + void db(int code) + { + if (size_ >= maxSize_) { + if (type_ == AUTO_GROW) { + growMemory(); + } else { + throw Error(ERR_CODE_IS_TOO_BIG); + } + } + top_[size_++] = static_cast(code); + } + void db(const uint8 *code, size_t codeSize) + { + for (size_t i = 0; i < codeSize; i++) db(code[i]); + } + void db(uint64 code, size_t codeSize) + { + if (codeSize > 8) throw Error(ERR_BAD_PARAMETER); + for (size_t i = 0; i < codeSize; i++) db(static_cast(code >> (i * 8))); + } + void dw(uint32 code) { db(code, 2); } + void dd(uint32 code) { db(code, 4); } + void dq(uint64 code) { db(code, 8); } + const uint8 *getCode() const { return top_; } + template + const F getCode() const { return reinterpret_cast(top_); } + const uint8 *getCurr() const { return &top_[size_]; } + template + const F getCurr() const { return reinterpret_cast(&top_[size_]); } + size_t getSize() const { return size_; } + void setSize(size_t size) + { + if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG); + size_ = size; + } + void dump() const + { + const uint8 *p = getCode(); + size_t bufSize = getSize(); + size_t remain = bufSize; + for (int i = 0; i < 4; i++) { + size_t disp = 16; + if (remain < 16) { + disp = remain; + } + for (size_t j = 0; j < 16; j++) { + if (j < disp) { + printf("%02X", p[i * 16 + j]); + } + } + putchar('\n'); + remain -= disp; + if (remain == 0) { + break; + } + } + } + /* + @param offset [in] offset from top + @param disp [in] offset from the next of jmp + @param size [in] write size(1, 2, 4, 8) + */ + void rewrite(size_t offset, uint64 disp, size_t size) + { + assert(offset < maxSize_); + if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER); + uint8 *const data = top_ + offset; + for (size_t i = 0; i < size; i++) { + data[i] = static_cast(disp >> (i * 8)); + } + } + void save(size_t offset, size_t val, int size, inner::LabelMode mode) + { + addrInfoList_.push_back(AddrInfo(offset, val, size, mode)); + } + bool isAutoGrow() const { return type_ == AUTO_GROW; } + bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; } + /** + change exec permission of memory + @param addr [in] buffer address + @param size [in] buffer size + @param protectMode [in] mode(RW/RWE/RE) + @return true(success), false(failure) + */ + static inline bool protect(const void *addr, size_t size, int protectMode) + { +#if defined(_WIN32) + const DWORD 
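+	// Maps the portable ProtectMode to OS page flags. For a W^X style workflow the
+	// buffer can be created with DontSetProtectRWE, filled while still RW, and switched
+	// with setProtectModeRE() before execution (a sketch of one intended call order,
+	// not a requirement of the API).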
c_rw = PAGE_READWRITE; + const DWORD c_rwe = PAGE_EXECUTE_READWRITE; + const DWORD c_re = PAGE_EXECUTE_READ; + DWORD mode; +#else + const int c_rw = PROT_READ | PROT_WRITE; + const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC; + const int c_re = PROT_READ | PROT_EXEC; + int mode; +#endif + switch (protectMode) { + case PROTECT_RW: mode = c_rw; break; + case PROTECT_RWE: mode = c_rwe; break; + case PROTECT_RE: mode = c_re; break; + default: + return false; + } +#if defined(_WIN32) + DWORD oldProtect; + return VirtualProtect(const_cast(addr), size, mode, &oldProtect) != 0; +#elif defined(__GNUC__) + size_t pageSize = sysconf(_SC_PAGESIZE); + size_t iaddr = reinterpret_cast(addr); + size_t roundAddr = iaddr & ~(pageSize - static_cast(1)); +#ifndef NDEBUG + if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize); +#endif + return mprotect(reinterpret_cast(roundAddr), size + (iaddr - roundAddr), mode) == 0; +#else + return true; +#endif + } + /** + get aligned memory pointer + @param addr [in] address + @param alignedSize [in] power of two + @return aligned addr by alingedSize + */ + static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16) + { + return reinterpret_cast((reinterpret_cast(addr) + alignedSize - 1) & ~(alignedSize - static_cast(1))); + } +}; + +class Address : public Operand { +public: + enum Mode { + M_ModRM, + M_64bitDisp, + M_rip, + M_ripAddr + }; + Address(uint32 sizeBit, bool broadcast, const RegExp& e) + : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast) + { + e_.verify(); + } +#ifdef XBYAK64 + explicit Address(size_t disp) + : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ } + Address(uint32 sizeBit, bool broadcast, const RegRip& addr) + : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { } +#endif + RegExp getRegExp(bool optimize = true) const + { + return optimize ? 
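+	// optimize() rewrites [reg * 2] as [reg + reg * 1], avoiding the index-only SIB
+	// form that would otherwise force a 32-bit displacement.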
e_.optimize() : e_; + } + Mode getMode() const { return mode_; } + bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; } + bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax + size_t getDisp() const { return e_.getDisp(); } + uint8 getRex() const + { + if (mode_ != M_ModRM) return 0; + return getRegExp().getRex(); + } + bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset + bool isBroadcast() const { return broadcast_; } + const Label* getLabel() const { return label_; } + bool operator==(const Address& rhs) const + { + return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_; + } + bool operator!=(const Address& rhs) const { return !operator==(rhs); } + bool isVsib() const { return e_.isVsib(); } +private: + RegExp e_; + const Label* label_; + Mode mode_; + bool broadcast_; +}; + +inline const Address& Operand::getAddress() const +{ + assert(isMEM()); + return static_cast(*this); +} + +inline bool Operand::operator==(const Operand& rhs) const +{ + if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress(); + return isEqualIfNotInherited(rhs); +} + +class AddressFrame { + void operator=(const AddressFrame&); + AddressFrame(const AddressFrame&); +public: + const uint32 bit_; + const bool broadcast_; + explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { } + Address operator[](const RegExp& e) const + { + return Address(bit_, broadcast_, e); + } + Address operator[](const void *disp) const + { + return Address(bit_, broadcast_, RegExp(reinterpret_cast(disp))); + } +#ifdef XBYAK64 + Address operator[](uint64 disp) const { return Address(disp); } + Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); } +#endif +}; + +struct JmpLabel { + size_t endOfJmp; /* offset from top to the end address of jmp */ + int jmpSize; + inner::LabelMode mode; + size_t disp; // disp for [rip + disp] + explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0) + : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp) + { + } +}; + +class LabelManager; + +class Label { + mutable LabelManager *mgr; + mutable int id; + friend class LabelManager; +public: + Label() : mgr(0), id(0) {} + Label(const Label& rhs); + Label& operator=(const Label& rhs); + ~Label(); + void clear() { mgr = 0; id = 0; } + int getId() const { return id; } + const uint8 *getAddress() const; + + // backward compatibility + static inline std::string toStr(int num) + { + char buf[16]; +#if defined(_MSC_VER) && (_MSC_VER < 1900) + _snprintf_s +#else + snprintf +#endif + (buf, sizeof(buf), ".%08x", num); + return buf; + } +}; + +class LabelManager { + // for string label + struct SlabelVal { + size_t offset; + SlabelVal(size_t offset) : offset(offset) {} + }; + typedef XBYAK_STD_UNORDERED_MAP SlabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP SlabelUndefList; + struct SlabelState { + SlabelDefList defList; + SlabelUndefList undefList; + }; + typedef std::list StateList; + // for Label class + struct ClabelVal { + ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {} + size_t offset; + int refCount; + }; + typedef XBYAK_STD_UNORDERED_MAP ClabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP ClabelUndefList; + typedef XBYAK_STD_UNORDERED_SET LabelPtrList; + + CodeArray *base_; + // global : stateList_.front(), local : 
stateList_.back() + StateList stateList_; + mutable int labelId_; + ClabelDefList clabelDefList_; + ClabelUndefList clabelUndefList_; + LabelPtrList labelPtrList_; + + int getId(const Label& label) const + { + if (label.id == 0) label.id = labelId_++; + return label.id; + } + template + void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset) + { + // add label + typename DefList::value_type item(labelId, addrOffset); + std::pair ret = defList.insert(item); + if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED); + // search undefined label + for (;;) { + typename UndefList::iterator itr = undefList.find(labelId); + if (itr == undefList.end()) break; + const JmpLabel *jmp = &itr->second; + const size_t offset = jmp->endOfJmp - jmp->jmpSize; + size_t disp; + if (jmp->mode == inner::LaddTop) { + disp = addrOffset; + } else if (jmp->mode == inner::Labs) { + disp = size_t(base_->getCurr()); + } else { + disp = addrOffset - jmp->endOfJmp + jmp->disp; +#ifdef XBYAK64 + if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR); + } + if (base_->isAutoGrow()) { + base_->save(offset, disp, jmp->jmpSize, jmp->mode); + } else { + base_->rewrite(offset, disp, jmp->jmpSize); + } + undefList.erase(itr); + } + } + template + bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const + { + typename DefList::const_iterator i = defList.find(label); + if (i == defList.end()) return false; + *offset = i->second.offset; + return true; + } + friend class Label; + void incRefCount(int id, Label *label) + { + clabelDefList_[id].refCount++; + labelPtrList_.insert(label); + } + void decRefCount(int id, Label *label) + { + labelPtrList_.erase(label); + ClabelDefList::iterator i = clabelDefList_.find(id); + if (i == clabelDefList_.end()) return; + if (i->second.refCount == 1) { + clabelDefList_.erase(id); + } else { + --i->second.refCount; + } + } + template + bool hasUndefinedLabel_inner(const T& list) const + { +#ifndef NDEBUG + for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) { + std::cerr << "undefined label:" << i->first << std::endl; + } +#endif + return !list.empty(); + } + // detach all labels linked to LabelManager + void resetLabelPtrList() + { + for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) { + (*i)->clear(); + } + labelPtrList_.clear(); + } +public: + LabelManager() + { + reset(); + } + ~LabelManager() + { + resetLabelPtrList(); + } + void reset() + { + base_ = 0; + labelId_ = 1; + stateList_.clear(); + stateList_.push_back(SlabelState()); + stateList_.push_back(SlabelState()); + clabelDefList_.clear(); + clabelUndefList_.clear(); + resetLabelPtrList(); + } + void enterLocal() + { + stateList_.push_back(SlabelState()); + } + void leaveLocal() + { + if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL); + if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND); + stateList_.pop_back(); + } + void set(CodeArray *base) { base_ = base; } + void defineSlabel(std::string label) + { + if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR); + if (label == "@@") { + SlabelDefList& defList = stateList_.front().defList; + SlabelDefList::iterator i = defList.find("@f"); + if (i != defList.end()) { + defList.erase(i); + label = "@b"; + } else { + i = defList.find("@b"); + if (i != defList.end()) { + 
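+		// "@@" alternates the two anonymous slots: defining it retires the pending "@f"
+		// entry (if any) as the new "@b", so "@b" always refers to the most recent "@@"
+		// and "@f" to the next one.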
defList.erase(i); + } + label = "@f"; + } + } + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + define_inner(st.defList, st.undefList, label, base_->getSize()); + } + void defineClabel(Label& label) + { + define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize()); + label.mgr = this; + labelPtrList_.insert(&label); + } + void assign(Label& dst, const Label& src) + { + ClabelDefList::const_iterator i = clabelDefList_.find(src.id); + if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L); + define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset); + dst.mgr = this; + labelPtrList_.insert(&dst); + } + bool getOffset(size_t *offset, std::string& label) const + { + const SlabelDefList& defList = stateList_.front().defList; + if (label == "@b") { + if (defList.find("@f") != defList.end()) { + label = "@f"; + } else if (defList.find("@b") == defList.end()) { + throw Error(ERR_LABEL_IS_NOT_FOUND); + } + } else if (label == "@f") { + if (defList.find("@f") != defList.end()) { + label = "@b"; + } + } + const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + return getOffset_inner(st.defList, offset, label); + } + bool getOffset(size_t *offset, const Label& label) const + { + return getOffset_inner(clabelDefList_, offset, getId(label)); + } + void addUndefinedLabel(const std::string& label, const JmpLabel& jmp) + { + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + st.undefList.insert(SlabelUndefList::value_type(label, jmp)); + } + void addUndefinedLabel(const Label& label, const JmpLabel& jmp) + { + clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp)); + } + bool hasUndefSlabel() const + { + for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) { + if (hasUndefinedLabel_inner(i->undefList)) return true; + } + return false; + } + bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); } + const uint8 *getCode() const { return base_->getCode(); } + bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); } +}; + +inline Label::Label(const Label& rhs) +{ + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); +} +inline Label& Label::operator=(const Label& rhs) +{ + if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L); + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); + return *this; +} +inline Label::~Label() +{ + if (id && mgr) mgr->decRefCount(id, this); +} +inline const uint8* Label::getAddress() const +{ + if (mgr == 0 || !mgr->isReady()) return 0; + size_t offset; + if (!mgr->getOffset(&offset, *this)) return 0; + return mgr->getCode() + offset; +} + +class CodeGenerator : public CodeArray { +public: + enum LabelType { + T_SHORT, + T_NEAR, + T_AUTO // T_SHORT if possible + }; +private: + CodeGenerator operator=(const CodeGenerator&); // don't call +#ifdef XBYAK64 + enum { i32e = 32 | 64, BIT = 64 }; + static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788; + typedef Reg64 NativeReg; +#else + enum { i32e = 32, BIT = 32 }; + static const size_t dummyAddr = 0x12345678; + typedef Reg32 NativeReg; +#endif + // (XMM, XMM|MEM) + static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isXMM() || op2.isMEM()); + } + // (MMX, MMX|MEM) or (XMM, XMM|MEM) + static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2) + { + return 
(op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2); + } + // (XMM, MMX|MEM) + static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isMMX() || op2.isMEM()); + } + // (MMX, XMM|MEM) + static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isMMX() && (op2.isXMM() || op2.isMEM()); + } + // (XMM, REG32|MEM) + static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM()); + } + // (REG32, XMM|MEM) + static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM()); + } + // (REG32, REG32|MEM) + static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM()); + } + void rex(const Operand& op1, const Operand& op2 = Operand()) + { + uint8 rex = 0; + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM()) std::swap(p1, p2); + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isMEM()) { + const Address& addr = p2->getAddress(); + if (BIT == 64 && addr.is32bit()) db(0x67); + rex = addr.getRex() | p1->getReg().getRex(); + } else { + // ModRM(reg, base); + rex = op2.getReg().getRex(op1.getReg()); + } + // except movsx(16bit, 32/64bit) + if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66); + if (rex) db(rex); + } + enum AVXtype { + // low 3 bit + T_N1 = 1, + T_N2 = 2, + T_N4 = 3, + T_N8 = 4, + T_N16 = 5, + T_N32 = 6, + T_NX_MASK = 7, + // + T_N_VL = 1 << 3, // N * (1, 2, 4) for VL + T_DUP = 1 << 4, // N = (8, 32, 64) + T_66 = 1 << 5, + T_F3 = 1 << 6, + T_F2 = 1 << 7, + T_0F = 1 << 8, + T_0F38 = 1 << 9, + T_0F3A = 1 << 10, + T_L0 = 1 << 11, + T_L1 = 1 << 12, + T_W0 = 1 << 13, + T_W1 = 1 << 14, + T_EW0 = 1 << 15, + T_EW1 = 1 << 16, + T_YMM = 1 << 17, // support YMM, ZMM + T_EVEX = 1 << 18, + T_ER_X = 1 << 19, // xmm{er} + T_ER_Y = 1 << 20, // ymm{er} + T_ER_Z = 1 << 21, // zmm{er} + T_SAE_X = 1 << 22, // xmm{sae} + T_SAE_Y = 1 << 23, // ymm{sae} + T_SAE_Z = 1 << 24, // zmm{sae} + T_MUST_EVEX = 1 << 25, // contains T_EVEX + T_B32 = 1 << 26, // m32bcst + T_B64 = 1 << 27, // m64bcst + T_M_K = 1 << 28, // mem{k} + T_VSIB = 1 << 29, + T_MEM_EVEX = 1 << 30, // use evex if mem + T_XXX + }; + void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false) + { + int w = (type & T_W1) ? 1 : 0; + bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM(); + bool r = reg.isExtIdx(); + bool b = base.isExtIdx(); + int idx = v ? v->getIdx() : 0; + if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION); + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp; + if (!b && !x && !w && (type & T_0F)) { + db(0xC5); db((r ? 0 : 0x80) | vvvv); + } else { + uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 
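+		// The two-byte VEX prefix (0xC5) is only possible when X, B and W are clear and
+		// the opcode map is 0F; otherwise the three-byte 0xC4 form is emitted with the
+		// inverted R/X/B bits plus the map-select field mmmm.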
0 : 0x20) | mmmm); db((w << 7) | vvvv); + } + db(code); + } + void verifySAE(const Reg& r, int type) const + { + if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return; + throw Error(ERR_SAE_IS_INVALID); + } + void verifyER(const Reg& r, int type) const + { + if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return; + throw Error(ERR_ER_IS_INVALID); + } + // (a, b, c) contains non zero two or three values then err + int verifyDuplicate(int a, int b, int c, int err) + { + int v = a | b | c; + if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err); + return v; + } + int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false) + { + if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID); + int w = (type & T_EW1) ? 1 : 0; + uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + + int idx = v ? v->getIdx() : 0; + uint32 vvvv = ~idx; + + bool R = !reg.isExtIdx(); + bool X = x ? false : !base.isExtIdx2(); + bool B = !base.isExtIdx(); + bool Rp = !reg.isExtIdx2(); + int LL; + int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET); + int disp8N = 1; + if (rounding) { + if (rounding == EvexModifierRounding::T_SAE) { + verifySAE(base, type); LL = 0; + } else { + verifyER(base, type); LL = rounding - 1; + } + b = true; + } else { + if (v) VL = (std::max)(VL, v->getBit()); + VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL); + LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0; + if (b) { + disp8N = (type & T_B32) ? 4 : 8; + } else if (type & T_DUP) { + disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64; + } else { + if ((type & (T_NX_MASK | T_N_VL)) == 0) { + type |= T_N16 | T_N_VL; // default + } + int low = type & T_NX_MASK; + if (low > 0) { + disp8N = 1 << (low - 1); + if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1); + } + } + } + bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx); + bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false); + if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET); + db(0x62); + db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3)); + db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3)); + db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 
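+		// Last EVEX byte: z (merge/zero), L'L (vector length or rounding), b
+		// (broadcast/SAE), V' and the opmask field aaa. The returned disp8N is the
+		// compressed-disp8 scale handed to opAddr() for memory operands.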
8 : 0) | (aaa & 7)); + db(code); + return disp8N; + } + void setModRM(int mod, int r1, int r2) + { + db(static_cast((mod << 6) | ((r1 & 7) << 3) | (r2 & 7))); + } + void setSIB(const RegExp& e, int reg, int disp8N = 0) + { + size_t disp64 = e.getDisp(); +#ifdef XBYAK64 + size_t high = disp64 >> 32; + if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + uint32 disp = static_cast(disp64); + const Reg& base = e.getBase(); + const Reg& index = e.getIndex(); + const int baseIdx = base.getIdx(); + const int baseBit = base.getBit(); + const int indexBit = index.getBit(); + enum { + mod00 = 0, mod01 = 1, mod10 = 2 + }; + int mod = mod10; // disp32 + if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) { + mod = mod00; + } else { + if (disp8N == 0) { + if (inner::IsInDisp8(disp)) { + mod = mod01; + } + } else { + // disp must be casted to signed + uint32 t = static_cast(static_cast(disp) / disp8N); + if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) { + disp = t; + mod = mod01; + } + } + } + const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP; + /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */ + bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP; +#ifdef XBYAK64 + if (!baseBit && !indexBit) hasSIB = true; +#endif + if (hasSIB) { + setModRM(mod, reg, Operand::ESP); + /* SIB = [2:3:3] = [SS:index:base(=rm)] */ + const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP; + const int scale = e.getScale(); + const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0; + setModRM(SS, idx, newBaseIdx); + } else { + setModRM(mod, reg, newBaseIdx); + } + if (mod == mod01) { + db(disp); + } else if (mod == mod10 || (mod == mod00 && !baseBit)) { + dd(disp); + } + } + LabelManager labelMgr_; + bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; } + void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE) + { + rex(reg2, reg1); + db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + setModRM(3, reg1.getIdx(), reg2.getIdx()); + } + void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + rex(addr, reg); + db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + opAddr(addr, reg.getIdx(), immSize); + } + void opMIB(const Address& addr, const Reg& reg, int code0, int code1) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS); + if (BIT == 64 && addr.is32bit()) db(0x67); + const RegExp& regExp = addr.getRegExp(false); + uint8 rex = regExp.getRex(); + if (rex) db(rex); + db(code0); db(code1); + setSIB(regExp, reg.getIdx()); + } + void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + const int shortJmpSize = 2; + const int longHeaderSize = longPref ? 
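+		// A short jump is opcode + rel8 (2 bytes); the near form is rel32 behind either
+		// a one-byte opcode (jmp/call) or a 0x0F-prefixed two-byte opcode (jcc), which is
+		// what longHeaderSize distinguishes.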
2 : 1; + const int longJmpSize = longHeaderSize + 4; + if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) { + db(shortCode); db(disp - shortJmpSize); + } else { + if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR); + if (longPref) db(longPref); + db(longCode); dd(disp - longJmpSize); + } + } + template + void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */ + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { /* label exists */ + makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref); + } else { + int jmpSize = 0; + if (type == T_NEAR) { + jmpSize = 4; + if (longPref) db(longPref); + db(longCode); dd(0); + } else { + jmpSize = 1; + db(shortCode); db(0); + } + JmpLabel jmp(size_, jmpSize, inner::LasIs); + labelMgr_.addUndefinedLabel(label, jmp); + } + } + void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0) + { + if (isAutoGrow()) { + if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW); + if (size_ + 16 >= maxSize_) growMemory(); + if (longPref) db(longPref); + db(longCode); + dd(0); + save(size_ - 4, size_t(addr) - size_, 4, inner::Labs); + } else { + makeJmp(inner::VerifyInInt32(reinterpret_cast(addr) - getCurr()), type, shortCode, longCode, longPref); + } + + } + // reg is reg field of ModRM + // immSize is the size for immediate value + // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement + void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false) + { + if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING); + if (addr.getMode() == Address::M_ModRM) { + setSIB(addr.getRegExp(), reg, disp8N); + } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) { + setModRM(0, reg, 5); + if (addr.getLabel()) { // [rip + Label] + putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize); + } else { + size_t disp = addr.getDisp(); + if (addr.getMode() == Address::M_ripAddr) { + if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW); + disp -= (size_t)getCurr() + 4 + immSize; + } + dd(inner::VerifyInInt32(disp)); + } + } + } + /* preCode is for SSSE3/SSE4 */ + void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE) + { + if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION); + if (pref != NONE) db(pref); + if (op.isMEM()) { + opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0); + } else { + opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code); + } + if (imm8 != NONE) db(imm8); + } + void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext) + { + if (mmx.isXMM()) db(0x66); + opModR(Reg32(ext), mmx, 0x0F, code); + db(imm8); + } + void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE) + { + opGen(mmx, op, code, mmx.isXMM() ? 
pref : NONE, isXMMorMMX_MEM, imm8, preCode); + } + void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref) + { + if (pref != NONE) db(pref); + if (op1.isXMM() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), 0x0F, code); + } else if (op1.isMEM() && op2.isXMM()) { + opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false) + { + if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */ + if (mmx.isXMM()) db(0x66); + opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm); + } else { + opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A); + } + } + void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0) + { + int opBit = op.getBit(); + if (disableRex && opBit == 64) opBit = 32; + if (op.isREG(bit)) { + opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2); + } else if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShift(const Operand& op, int imm, int ext) + { + verifyMemHasSize(op); + opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0); + if (imm != 1) db(imm); + } + void opShift(const Operand& op, const Reg8& _cl, int ext) + { + if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opR_ModM(op, 0, ext, 0xD2); + } + void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (condR) { + opModR(op1.getReg(), op2.getReg(), code0, code1, code2); + } else if (condM) { + opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0) + { + if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1); + if (!_cl) db(imm); + } + // (REG, REG|MEM), (MEM, REG) + void opRM_RM(const Operand& op1, const Operand& op2, int code) + { + if (op1.isREG() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), code | 2); + } else { + opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code); + } + } + // (REG|MEM, IMM) + void opRM_I(const Operand& op, uint32 imm, int code, int ext) + { + verifyMemHasSize(op); + uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32; + if (op.isBit(8)) immBit = 8; + if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG); + if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */ + if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al + rex(op); + db(code | 4 | (immBit == 8 ? 0 : 1)); + } else { + int tmp = immBit < (std::min)(op.getBit(), 32U) ? 
2 : 0; + opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8); + } + db(imm, immBit / 8); + } + void opIncDec(const Operand& op, int code, int ext) + { + verifyMemHasSize(op); +#ifndef XBYAK64 + if (op.isREG() && !op.isBit(8)) { + rex(op); db(code | op.getIdx()); + return; + } +#endif + code = 0xFE; + if (op.isREG()) { + opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code); + } else { + opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code); + } + } + void opPushPop(const Operand& op, int code, int ext, int alt) + { + int bit = op.getBit(); + if (bit == 16 || bit == BIT) { + if (bit == 16) db(0x66); + if (op.isREG()) { + if (op.getReg().getIdx() >= 8) db(0x41); + db(alt | (op.getIdx() & 7)); + return; + } + if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code); + return; + } + } + throw Error(ERR_BAD_COMBINATION); + } + void verifyMemHasSize(const Operand& op) const + { + if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED); + } + /* + mov(r, imm) = db(imm, mov_imm(r, imm)) + */ + int mov_imm(const Reg& reg, size_t imm) + { + int bit = reg.getBit(); + const int idx = reg.getIdx(); + int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3); + if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) { + rex(Reg32(idx)); + bit = 32; + } else { + rex(reg); + if (bit == 64 && inner::IsInInt32(imm)) { + db(0xC7); + code = 0xC0; + bit = 32; + } + } + db(code | (idx & 7)); + return bit / 8; + } + template + void putL_inner(T& label, bool relative = false, size_t disp = 0) + { + const int jmpSize = relative ? 4 : (int)sizeof(size_t); + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { + if (relative) { + db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize); + } else if (isAutoGrow()) { + db(uint64(0), jmpSize); + save(size_ - jmpSize, offset, jmpSize, inner::LaddTop); + } else { + db(size_t(top_) + offset, jmpSize); + } + return; + } + db(uint64(0), jmpSize); + JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp); + labelMgr_.addUndefinedLabel(label, jmp); + } + void opMovxx(const Reg& reg, const Operand& op, uint8 code) + { + if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION); + int w = op.isBit(16); +#ifdef XBYAK64 + if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION); +#endif + bool cond = reg.isREG() && (reg.getBit() > op.getBit()); + opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w); + } + void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0; + if (!code) throw Error(ERR_BAD_MEM_SIZE); + if (m64ext && addr.isBit(64)) ext = m64ext; + + rex(addr, st0); + db(code); + opAddr(addr, ext); + } + // use code1 if reg1 == st0 + // use code2 if reg1 != st0 && reg2 == st0 + void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2) + { + uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? 
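+		// Two-register x87 forms require st0 on one side: code1 encodes "op st0, sti"
+		// and code2 encodes "op sti, st0"; any other pairing is rejected below.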
code2 : 0; + if (!code) throw Error(ERR_BAD_ST_COMBINATION); + db(uint8(code >> 8)); + db(uint8(code | (reg1.getIdx() | reg2.getIdx()))); + } + void opFpu(const Fpu& reg, uint8 code1, uint8 code2) + { + db(code1); db(code2 | reg.getIdx()); + } + void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE) + { + if (op2.isMEM()) { + const Address& addr = op2.getAddress(); + const RegExp& regExp = addr.getRegExp(); + const Reg& base = regExp.getBase(); + const Reg& index = regExp.getIndex(); + if (BIT == 64 && addr.is32bit()) db(0x67); + int disp8N = 0; + bool x = index.isExtIdx(); + if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) { + int aaa = addr.getOpmaskIdx(); + if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY); + bool b = false; + if (addr.isBroadcast()) { + if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST); + b = true; + } + int VL = regExp.isVsib() ? index.getBit() : 0; + disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2()); + } else { + vex(r, base, p1, type, code, x); + } + opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0); + } else { + const Reg& base = op2.getReg(); + if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) { + evex(r, base, p1, type, code); + } else { + vex(r, base, p1, type, code); + } + setModRM(3, r.getIdx(), base.getIdx()); + } + if (imm8 != NONE) db(imm8); + } + // (r, r, r/m) if isR_R_RM + // (r, r/m, r) + void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE) + { + const Operand *p1 = &op1; + const Operand *p2 = &op2; + if (!isR_R_RM) std::swap(p1, p2); + const unsigned int bit = r.getBit(); + if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION); + type |= (bit == 64) ? T_W1 : T_W0; + opVex(r, p1, *p2, type, code, imm8); + } + void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE) + { + const Xmm *x2 = static_cast(&op1); + const Operand *op = &op2; + if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1) + x2 = &x1; + op = &op1; + } + // (x1, x2, op) + if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION); + opVex(x1, x2, *op, type, code0, imm8); + } + void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE) + { + if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION); + opVex(k, &x2, op3, type, code0, imm8); + } + // (x, x/m), (y, x/m256), (z, y/m) + void checkCvt1(const Operand& x, const Operand& op) const + { + if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION); + } + // (x, x/m), (x, y/m256), (y, z/m) + void checkCvt2(const Xmm& x, const Operand& op) const + { + if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION); + } + void opCvt2(const Xmm& x, const Operand& op, int type, int code) + { + checkCvt2(x, op); + Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? 
Operand::YMM : Operand::XMM) : Operand::ZMM; + opVex(x.copyAndSetKind(kind), &xm0, op, type, code); + } + void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code) + { + if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER); + Xmm x(op.getIdx()); + const Operand *p = op.isREG() ? &x : &op; + opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code); + } + const Xmm& cvtIdx0(const Operand& x) const + { + return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0; + } + // support (x, x/m, imm), (y, y/m, imm) + void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE) + { + opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8); + } + // QQQ:need to refactor + void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1) + { + if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM()); + if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); + if (is16bit) db(0x66); + db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1); + } + void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode) + { + const RegExp& regExp = addr.getRegExp(); + if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING); + const int y_vx_y = 0; + const int y_vy_y = 1; +// const int x_vy_x = 2; + const bool isAddrYMM = regExp.getIndex().getBit() == 256; + if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) { + bool isOK = false; + if (mode == y_vx_y) { + isOK = x1.isYMM() && !isAddrYMM && x2.isYMM(); + } else if (mode == y_vy_y) { + isOK = x1.isYMM() && isAddrYMM && x2.isYMM(); + } else { // x_vy_x + isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM(); + } + if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING); + } + opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? 
Ymm(x2.getIdx()) : x2, addr, type, code); + } + enum { + xx_yy_zz = 0, + xx_yx_zy = 1, + xx_xy_yz = 2 + }; + void checkGather2(const Xmm& x1, const Reg& x2, int mode) const + { + if (x1.isXMM() && x2.isXMM()) return; + switch (mode) { + case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return; + break; + case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return; + break; + case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return; + break; + } + throw Error(ERR_BAD_VSIB_ADDRESSING); + } + void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode) + { + if (x.hasZero()) throw Error(ERR_INVALID_ZERO); + checkGather2(x, addr.getRegExp().getIndex(), mode); + opVex(x, 0, addr, type, code); + } + /* + xx_xy_yz ; mode = true + xx_xy_xz ; mode = false + */ + void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode) + { + if (mode) { + if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION); + } else { + if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); + } + opVex(x, 0, op, type, code); + } + void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind) + { + if (addr.hasZero()) throw Error(ERR_INVALID_ZERO); + if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING); + opVex(x, 0, addr, type, code); + } +public: + unsigned int getVersion() const { return VERSION; } + using CodeArray::db; + const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7; + const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7; + const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7; + const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7; + const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7; + const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi; + const Reg16 ax, cx, dx, bx, sp, bp, si, di; + const Reg8 al, cl, dl, bl, ah, ch, dh, bh; + const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is same as oword of NASM + const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b} + const Fpu st0, st1, st2, st3, st4, st5, st6, st7; + const Opmask k0, k1, k2, k3, k4, k5, k6, k7; + const BoundsReg bnd0, bnd1, bnd2, bnd3; + const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae} + const EvexModifierZero T_z; // {z} +#ifdef XBYAK64 + const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15; + const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d; + const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w; + const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b; + const Reg8 spl, bpl, sil, dil; + const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15; + const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23; + const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31; + const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15; + const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23; + const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15; + const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23; + const Zmm zmm24, zmm25, 
zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience + const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23; + const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31; + const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15; + const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23; + const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31; + const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15; + const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23; + const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31; + const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT + const Segment es, cs, ss, ds, fs, gs; +#endif + void L(const std::string& label) { labelMgr_.defineSlabel(label); } + void L(Label& label) { labelMgr_.defineClabel(label); } + Label L() { Label label; L(label); return label; } + void inLocalLabel() { labelMgr_.enterLocal(); } + void outLocalLabel() { labelMgr_.leaveLocal(); } + /* + assign src to dst + require + dst : does not used by L() + src : used by L() + */ + void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); } + /* + put address of label to buffer + @note the put size is 4(32-bit), 8(64-bit) + */ + void putL(std::string label) { putL_inner(label); } + void putL(const Label& label) { putL_inner(label); } + + void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); } + void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); } + void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); } + + void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); } + // call(string label), not const std::string& + void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + void call(const char *label) { call(std::string(label)); } + void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + // call(function pointer) +#ifdef XBYAK_VARIADIC_TEMPLATE + template + void call(Ret(*func)(Params...)) { call(reinterpret_cast(func)); } +#endif + void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); } + + void test(const Operand& op, const Reg& reg) + { + opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84); + } + void test(const Operand& op, uint32 imm) + { + verifyMemHasSize(op); + int immSize = (std::min)(op.getBit() / 8, 4U); + if (op.isREG() && op.getIdx() == 0) { // al, ax, eax + rex(op); + db(0xA8 | (op.isBit(8) ? 0 : 1)); + } else { + opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize); + } + db(imm, immSize); + } + void imul(const Reg& reg, const Operand& op) + { + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF); + } + void imul(const Reg& reg, const Operand& op, int imm) + { + int s = inner::IsInDisp8(imm) ? 1 : 0; + int immSize = s ? 1 : reg.isREG(16) ? 
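+		// imm8 operands use the sign-extended 0x6B encoding (s == 1); otherwise 0x69
+		// takes a full imm16/imm32 matching the destination width, selected via (s << 1)
+		// in the opModRM call below.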
2 : 4; + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize); + db(imm, immSize); + } + void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); } + void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); } + void push(const AddressFrame& af, uint32 imm) + { + if (af.bit_ == 8 && inner::IsInDisp8(imm)) { + db(0x6A); db(imm); + } else if (af.bit_ == 16 && isInDisp16(imm)) { + db(0x66); db(0x68); dw(imm); + } else { + db(0x68); dd(imm); + } + } + /* use "push(word, 4)" if you want "push word 4" */ + void push(uint32 imm) + { + if (inner::IsInDisp8(imm)) { + push(byte, imm); + } else { + push(dword, imm); + } + } + void mov(const Operand& reg1, const Operand& reg2) + { + const Reg *reg = 0; + const Address *addr = 0; + uint8 code = 0; + if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp] + reg = &reg1.getReg(); + addr= &reg2.getAddress(); + code = 0xA0; + } else + if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al + reg = &reg2.getReg(); + addr= &reg1.getAddress(); + code = 0xA2; + } +#ifdef XBYAK64 + if (addr && addr->is64bitDisp()) { + if (code) { + rex(*reg); + db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3); + db(addr->getDisp(), 8); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } else +#else + if (code && addr->isOnlyDisp()) { + rex(*reg, *addr); + db(code | (reg->isBit(8) ? 0 : 1)); + dd(static_cast<uint32>(addr->getDisp())); + } else +#endif + { + opRM_RM(reg1, reg2, 0x88); + } + } + void mov(const Operand& op, size_t imm) + { + if (op.isREG()) { + const int size = mov_imm(op.getReg(), imm); + db(imm, size); + } else if (op.isMEM()) { + verifyMemHasSize(op); + int immSize = op.getBit() / 8; + if (immSize <= 4) { + sint64 s = sint64(imm) >> (immSize * 8); + if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG); + } else { + if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG); + immSize = 4; + } + opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize); + db(static_cast<uint32>(imm), immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void mov(const NativeReg& reg, const char *label) // can't use std::string + { + if (label == 0) { + mov(static_cast<const Operand&>(reg), 0); // call imm + return; + } + mov_imm(reg, dummyAddr); + putL(label); + } + void mov(const NativeReg& reg, const Label& label) + { + mov_imm(reg, dummyAddr); + putL(label); + } + void xchg(const Operand& op1, const Operand& op2) + { + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) { + p1 = &op2; p2 = &op1; + } + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0) +#ifdef XBYAK64 + && (p2->getIdx() != 0 || !p1->isREG(32)) +#endif + ) { + rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7)); + return; + } + opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ?
0 : 1)); + } + +#ifndef XBYAK_DISABLE_SEGMENT + void push(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x06); break; + case Segment::cs: db(0x0E); break; + case Segment::ss: db(0x16); break; + case Segment::ds: db(0x1E); break; + case Segment::fs: db(0x0F); db(0xA0); break; + case Segment::gs: db(0x0F); db(0xA8); break; + default: + assert(0); + } + } + void pop(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x07); break; + case Segment::cs: throw Error(ERR_BAD_COMBINATION); + case Segment::ss: db(0x17); break; + case Segment::ds: db(0x1F); break; + case Segment::fs: db(0x0F); db(0xA1); break; + case Segment::gs: db(0x0F); db(0xA9); break; + default: + assert(0); + } + } + void putSeg(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x2E); break; + case Segment::cs: db(0x36); break; + case Segment::ss: db(0x3E); break; + case Segment::ds: db(0x26); break; + case Segment::fs: db(0x64); break; + case Segment::gs: db(0x65); break; + default: + assert(0); + } + } + void mov(const Operand& op, const Segment& seg) + { + opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C); + } + void mov(const Segment& seg, const Operand& op) + { + opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E); + } +#endif + + enum { NONE = 256 }; + // constructor + CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0) + : CodeArray(maxSize, userPtr, allocator) + , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7) + , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7) + , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7) + , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7) + // for my convenience + , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7) + , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7) + , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7) + + , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI) + , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI) + , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH) + , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512) + , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true) + , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7) + , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7) + , bnd0(0), bnd1(1), bnd2(2), bnd3(3) + , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE) + , T_z() +#ifdef XBYAK64 + , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15) + , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15) + , r8w(8), r9w(9), 
r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15) + , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15) + , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true) + , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15) + , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23) + , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31) + , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15) + , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23) + , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31) + , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15) + , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23) + , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31) + // for my convenience + , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15) + , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23) + , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31) + , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15) + , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23) + , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31) + , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15) + , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23) + , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31) + , rip() +#endif +#ifndef XBYAK_DISABLE_SEGMENT + , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs) +#endif + { + labelMgr_.set(this); + } + void reset() + { + resetSize(); + labelMgr_.reset(); + labelMgr_.set(this); + } + bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); } + /* + MUST call ready() to complete generating code if you use AutoGrow mode. + It is not necessary for the other mode if hasUndefinedLabel() is true. 
+ */ + void ready(ProtectMode mode = PROTECT_RWE) + { + if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND); + if (isAutoGrow()) { + calcJmpAddress(); + if (useProtect()) setProtectMode(mode); + } + } + // set read/exec + void readyRE() { return ready(PROTECT_RE); } +#ifdef XBYAK_TEST + void dump(bool doClear = true) + { + CodeArray::dump(); + if (doClear) size_ = 0; + } +#endif + +#ifdef XBYAK_UNDEF_JNL + #undef jnl +#endif + + /* + use single byte nop if useMultiByteNop = false + */ + void nop(size_t size = 1, bool useMultiByteNop = true) + { + if (!useMultiByteNop) { + for (size_t i = 0; i < size; i++) { + db(0x90); + } + return; + } + /* + Intel Architectures Software Developer's Manual Volume 2 + recommended multi-byte sequence of NOP instruction + AMD and Intel seem to agree on the same sequences for up to 9 bytes: + https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf + */ + static const uint8 nopTbl[9][9] = { + {0x90}, + {0x66, 0x90}, + {0x0F, 0x1F, 0x00}, + {0x0F, 0x1F, 0x40, 0x00}, + {0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00}, + {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + }; + const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]); + while (size > 0) { + size_t len = (std::min)(n, size); + const uint8 *seq = nopTbl[len - 1]; + db(seq, len); + size -= len; + } + } + +#ifndef XBYAK_DONT_READ_LIST +#include "xbyak_mnemonic.h" + /* + use single byte nop if useMultiByteNop = false + */ + void align(size_t x = 16, bool useMultiByteNop = true) + { + if (x == 1) return; + if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN); + if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x); + size_t remain = size_t(getCurr()) % x; + if (remain) { + nop(x - remain, useMultiByteNop); + } + } +#endif +}; + +namespace util { +static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7); +static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7); +static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7); +static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7); +static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI); +static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI); +static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH); +static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512); +static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true); +static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7); +static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7); +static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3); +static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE); +static const EvexModifierZero T_z; +#ifdef XBYAK64 +static const 
Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15); +static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15); +static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15); +static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true); +static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15); +static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23); +static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31); +static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15); +static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23); +static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31); +static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15); +static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23); +static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31); +static const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT +static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs); +#endif +} // util + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +} // end of namespace + +#endif // XBYAK_XBYAK_H_ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h new file mode 100644 index 000000000000..a22e5224c3cb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h @@ -0,0 +1,303 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. 
+* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +enum { + B00000000= 0, + B00000001= 1, + B00000010= 2, + B00000011= 3, + B00000100= 4, + B00000101= 5, + B00000110= 6, + B00000111= 7, + B00001000= 8, + B00001001= 9, + B00001010= 10, + B00001011= 11, + B00001100= 12, + B00001101= 13, + B00001110= 14, + B00001111= 15, + B00010000= 16, + B00010001= 17, + B00010010= 18, + B00010011= 19, + B00010100= 20, + B00010101= 21, + B00010110= 22, + B00010111= 23, + B00011000= 24, + B00011001= 25, + B00011010= 26, + B00011011= 27, + B00011100= 28, + B00011101= 29, + B00011110= 30, + B00011111= 31, + B00100000= 32, + B00100001= 33, + B00100010= 34, + B00100011= 35, + B00100100= 36, + B00100101= 37, + B00100110= 38, + B00100111= 39, + B00101000= 40, + B00101001= 41, + B00101010= 42, + B00101011= 43, + B00101100= 44, + B00101101= 45, + B00101110= 46, + B00101111= 47, + B00110000= 48, + B00110001= 49, + B00110010= 50, + B00110011= 51, + B00110100= 52, + B00110101= 53, + B00110110= 54, + B00110111= 55, + B00111000= 56, + B00111001= 57, + B00111010= 58, + B00111011= 59, + B00111100= 60, + B00111101= 61, + B00111110= 62, + B00111111= 63, + B01000000= 64, + B01000001= 65, + B01000010= 66, + B01000011= 67, + B01000100= 68, + B01000101= 69, + B01000110= 70, + B01000111= 71, + B01001000= 72, + B01001001= 73, + B01001010= 74, + B01001011= 75, + B01001100= 76, + B01001101= 77, + B01001110= 78, + B01001111= 79, + B01010000= 80, + B01010001= 81, + B01010010= 82, + B01010011= 83, + B01010100= 84, + B01010101= 85, + B01010110= 86, + B01010111= 87, + B01011000= 88, + B01011001= 89, + B01011010= 90, + B01011011= 91, + B01011100= 92, + B01011101= 93, + B01011110= 94, + B01011111= 95, + B01100000= 96, + B01100001= 97, + B01100010= 98, + B01100011= 99, + B01100100= 100, + B01100101= 101, + B01100110= 102, + B01100111= 103, + B01101000= 104, + B01101001= 105, + B01101010= 106, + B01101011= 107, + B01101100= 108, + B01101101= 109, + B01101110= 110, + B01101111= 111, + B01110000= 112, + B01110001= 113, + B01110010= 114, + B01110011= 115, + B01110100= 116, + B01110101= 117, + B01110110= 118, + B01110111= 119, + B01111000= 120, + B01111001= 121, + B01111010= 122, + B01111011= 123, + B01111100= 124, + B01111101= 125, + B01111110= 126, + B01111111= 127, + B10000000= 128, + B10000001= 129, + B10000010= 130, + B10000011= 
131, + B10000100= 132, + B10000101= 133, + B10000110= 134, + B10000111= 135, + B10001000= 136, + B10001001= 137, + B10001010= 138, + B10001011= 139, + B10001100= 140, + B10001101= 141, + B10001110= 142, + B10001111= 143, + B10010000= 144, + B10010001= 145, + B10010010= 146, + B10010011= 147, + B10010100= 148, + B10010101= 149, + B10010110= 150, + B10010111= 151, + B10011000= 152, + B10011001= 153, + B10011010= 154, + B10011011= 155, + B10011100= 156, + B10011101= 157, + B10011110= 158, + B10011111= 159, + B10100000= 160, + B10100001= 161, + B10100010= 162, + B10100011= 163, + B10100100= 164, + B10100101= 165, + B10100110= 166, + B10100111= 167, + B10101000= 168, + B10101001= 169, + B10101010= 170, + B10101011= 171, + B10101100= 172, + B10101101= 173, + B10101110= 174, + B10101111= 175, + B10110000= 176, + B10110001= 177, + B10110010= 178, + B10110011= 179, + B10110100= 180, + B10110101= 181, + B10110110= 182, + B10110111= 183, + B10111000= 184, + B10111001= 185, + B10111010= 186, + B10111011= 187, + B10111100= 188, + B10111101= 189, + B10111110= 190, + B10111111= 191, + B11000000= 192, + B11000001= 193, + B11000010= 194, + B11000011= 195, + B11000100= 196, + B11000101= 197, + B11000110= 198, + B11000111= 199, + B11001000= 200, + B11001001= 201, + B11001010= 202, + B11001011= 203, + B11001100= 204, + B11001101= 205, + B11001110= 206, + B11001111= 207, + B11010000= 208, + B11010001= 209, + B11010010= 210, + B11010011= 211, + B11010100= 212, + B11010101= 213, + B11010110= 214, + B11010111= 215, + B11011000= 216, + B11011001= 217, + B11011010= 218, + B11011011= 219, + B11011100= 220, + B11011101= 221, + B11011110= 222, + B11011111= 223, + B11100000= 224, + B11100001= 225, + B11100010= 226, + B11100011= 227, + B11100100= 228, + B11100101= 229, + B11100110= 230, + B11100111= 231, + B11101000= 232, + B11101001= 233, + B11101010= 234, + B11101011= 235, + B11101100= 236, + B11101101= 237, + B11101110= 238, + B11101111= 239, + B11110000= 240, + B11110001= 241, + B11110010= 242, + B11110011= 243, + B11110100= 244, + B11110101= 245, + B11110110= 246, + B11110111= 247, + B11111000= 248, + B11111001= 249, + B11111010= 250, + B11111011= 251, + B11111100= 252, + B11111101= 253, + B11111110= 254, + B11111111= 255 +}; diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h new file mode 100644 index 000000000000..28d2d222f9f9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h @@ -0,0 +1,2017 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. 
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +const char *getVersionString() const { return "5.76"; } +void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); } +void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); } +void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); } +void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); } +void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); } +void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); } +void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); } +void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); } +void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); } +void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); } +void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); } +void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); } +void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); } +void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); 
} +void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); } +void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); } +void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); } +void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); } +void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); } +void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); } +void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); } +void bnd() { db(0xF2); } +void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); } +void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); } +void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); } +void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); } +void bsf(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); } +void bsr(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); } +void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); } +void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); } +void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); } +void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); } +void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); } +void bts(const Operand& op, uint8 imm) {
opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); } +void cbw() { db(0x66); db(0x98); } +void cdq() { db(0x99); } +void clc() { db(0xF8); } +void cld() { db(0xFC); } +void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); } +void cli() { db(0xFA); } +void cmc() { db(0xF5); } +void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524 +void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524 +void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 
0x40 | 0); }//-V524 +void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524 +void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); } +void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); } +void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); } +void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); } +void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); } +void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); } +void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); } +void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); } +void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); } +void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); } +void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); } +void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); } +void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); } +void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); } +void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); } +void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); } +void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); } +void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); } +void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); } +void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); } +void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); } +void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); } +void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); } +void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); } +void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); } +void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); } +void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); } +void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); } +void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); } +void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); } +void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); } +void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); } +void cmpsb() { db(0xA6); } +void cmpsd() { db(0xA7); } +void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); } +void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); } +void cmpsw() { db(0x66); db(0xA7); } +void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); } +void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); } +void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); } +void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); } 
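The cmp{eq,lt,le,unord,neq,nlt,nle,ord}{pd,ps,sd,ss} helpers above are thin wrappers that forward to cmppd/cmpps/cmpsd/cmpss with the SSE comparison predicate in the immediate byte (0 = EQ, 1 = LT, 2 = LE, 3 = UNORD, 4 = NEQ, 5 = NLT, 6 = NLE, 7 = ORD), as the forwarded arguments show. A minimal usage sketch of the generated mnemonic API, assuming the usual xbyak pattern of subclassing Xbyak::CodeGenerator; the kernel name and include path below are illustrative only and not part of this patch:

    #include "xbyak/xbyak.h"
    // Emits "cmpps xmm0, xmm1, 1" (i.e. cmpltps) followed by ret.
    struct CmpLtKernel : Xbyak::CodeGenerator {
        CmpLtKernel() {
            cmpltps(xmm0, xmm1); // same encoding as cmpps(xmm0, xmm1, 1)
            ret();
        }
    };
    // CmpLtKernel kernel;
    // const Xbyak::uint8 *code = kernel.getCode(); // points at the emitted machine code
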
+void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); } +void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); } +void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); } +void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); } +void cpuid() { db(0x0F); db(0xA2); } +void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); } +void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); } +void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); } +void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); } +void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); } +void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); } +void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); } +void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); } +void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); } +void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); } +void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); } +void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); } +void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); } +void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); } +void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); } +void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); } +void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); } +void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); } +void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); } +void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); } +void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); } +void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); } +void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); } +void cwd() { db(0x66); db(0x99); } +void cwde() { db(0x98); } +void dec(const Operand& op) { opIncDec(op, 0x48, 1); } +void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); } +void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); } +void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); } +void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); } +void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); } +void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, 
isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void emms() { db(0x0F); db(0x77); } +void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); } +void f2xm1() { db(0xD9); db(0xF0); } +void fabs() { db(0xD9); db(0xE1); } +void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); } +void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); } +void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); } +void faddp() { db(0xDE); db(0xC1); } +void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); } +void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); } +void fchs() { db(0xD9); db(0xE0); } +void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); } +void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); } +void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); } +void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); } +void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); } +void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); } +void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); } +void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); } +void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); } +void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); } +void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); } +void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); } +void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); } +void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); } +void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); } +void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); } +void fcom() { db(0xD8); db(0xD1); } +void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); } +void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); } +void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); } +void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); } +void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); } +void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); } +void fcomp() { db(0xD8); db(0xD9); } +void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); } +void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); } +void fcompp() { db(0xDE); db(0xD9); } +void fcos() { db(0xD9); db(0xFF); } +void fdecstp() { db(0xD9); db(0xF6); } +void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); } +void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); } +void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); } +void fdivp() { db(0xDE); db(0xF9); } +void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); } +void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); } +void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); } +void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); } +void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 
0xD8F8, 0xDCF0); } +void fdivrp() { db(0xDE); db(0xF1); } +void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); } +void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); } +void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); } +void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); } +void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); } +void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); } +void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); } +void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); } +void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); } +void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); } +void fincstp() { db(0xD9); db(0xF7); } +void finit() { db(0x9B); db(0xDB); db(0xE3); } +void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); } +void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); } +void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); } +void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); } +void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); } +void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); } +void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); } +void fld1() { db(0xD9); db(0xE8); } +void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); } +void fldl2e() { db(0xD9); db(0xEA); } +void fldl2t() { db(0xD9); db(0xE9); } +void fldlg2() { db(0xD9); db(0xEC); } +void fldln2() { db(0xD9); db(0xED); } +void fldpi() { db(0xD9); db(0xEB); } +void fldz() { db(0xD9); db(0xEE); } +void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); } +void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); } +void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); } +void fmulp() { db(0xDE); db(0xC9); } +void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); } +void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); } +void fninit() { db(0xDB); db(0xE3); } +void fnop() { db(0xD9); db(0xD0); } +void fpatan() { db(0xD9); db(0xF3); } +void fprem() { db(0xD9); db(0xF8); } +void fprem1() { db(0xD9); db(0xF5); } +void fptan() { db(0xD9); db(0xF2); } +void frndint() { db(0xD9); db(0xFC); } +void fscale() { db(0xD9); db(0xFD); } +void fsin() { db(0xD9); db(0xFE); } +void fsincos() { db(0xD9); db(0xFB); } +void fsqrt() { db(0xD9); db(0xFA); } +void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); } +void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); } +void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); } +void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); } +void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); } +void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); } +void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); } +void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); } +void fsubp() { db(0xDE); db(0xE9); } +void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); } +void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); } +void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); } +void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); } +void fsubr(const Fpu& 
reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); } +void fsubrp() { db(0xDE); db(0xE1); } +void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); } +void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); } +void ftst() { db(0xD9); db(0xE4); } +void fucom() { db(0xDD); db(0xE1); } +void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); } +void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); } +void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); } +void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); } +void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); } +void fucomp() { db(0xDD); db(0xE9); } +void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); } +void fucompp() { db(0xDA); db(0xE9); } +void fwait() { db(0x9B); } +void fxam() { db(0xD9); db(0xE5); } +void fxch() { db(0xD9); db(0xC9); } +void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); } +void fxtract() { db(0xD9); db(0xF4); } +void fyl2x() { db(0xD9); db(0xF1); } +void fyl2xp1() { db(0xD9); db(0xF9); } +void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); } +void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); } +void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); } +void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); } +void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); } +void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); } +void inc(const Operand& op) { opIncDec(op, 0x40, 0); } +void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524 +void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524 +void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524 +void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524 +void jbe(const 
void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524 +void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524 +void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524 +void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524 +void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524 +void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524 +void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524 +void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524 +void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524 +void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnb(std::string label, LabelType type = 
T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524 +void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524 +void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524 +void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524 +void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524 +void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524 +void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524 +void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524 +void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524 +void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524 +void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jns(const Label& 
label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524 +void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524 +void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524 +void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524 +void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524 +void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524 +void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524 +void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524 +void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524 +void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524 +void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524 +void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void lahf() { db(0x9F); } +void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); } +void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); } +void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); } +void lfence() { db(0x0F); db(0xAE); db(0xE8); } +void lock() { db(0xF0); } +void lzcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); } +void maskmovdqu(const Xmm&
reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); } +void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); } +void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); } +void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); } +void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); } +void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); } +void mfence() { db(0x0F); db(0xAE); db(0xF0); } +void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); } +void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); } +void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); } +void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); } +void monitor() { db(0x0F); db(0x01); db(0xC8); } +void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); } +void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); } +void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); } +void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); } +void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); } +void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); } +void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); } +void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); } +void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); } +void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); } +void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); } +void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); } +void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); } +void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); } +void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); } +void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); } +void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); } +void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); } +void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); } +void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); } +void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); } +void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); } +void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, 
reg, 0x0F, 0xC3); } +void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); } +void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); } +void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); } +void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); } +void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); } +void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); } +void movsb() { db(0xA4); } +void movsd() { db(0xA5); } +void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); } +void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); } +void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); } +void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); } +void movsw() { db(0x66); db(0xA5); } +void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); } +void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); } +void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); } +void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); } +void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); } +void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); } +void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); } +void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); } +void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); } +void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); } +void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); } +void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); } +void mwait() { db(0x0F); db(0x01); db(0xC9); } +void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); } +void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); } +void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); } +void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); } +void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); } +void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); } +void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); } +void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); } +void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); } +void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); } +void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); } +void packusdw(const Xmm& xmm,
const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); } +void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); } +void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); } +void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); } +void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); } +void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); } +void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); } +void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); } +void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); } +void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast<uint8>(imm), 0x3a); } +void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); } +void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); } +void pause() { db(0xF3); db(0x90); } +void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); } +void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); } +void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); } +void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); } +void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); } +void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); } +void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); } +void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); } +void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); } +void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); } +void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); } +void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); } +void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); } +void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); } +void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); } +void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); } +void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); }
+void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); } +void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); } +void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); } +void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); } +void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); } +void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); } +void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); } +void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); } +void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); } +void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); } +void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); } +void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); } +void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); } +void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); } +void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxdq(const Xmm& xmm, 
const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); } +void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); } +void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); } +void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); } +void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); } +void popcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); } +void popf() { db(0x9D); } +void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); } +void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); } +void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); } +void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); } +void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); } +void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); } +void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); } +void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); } +void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); } +void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); } +void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); } +void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); } +void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); } +void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); } +void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); } +void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); } +void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); } +void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); } +void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); } +void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); } +void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); } +void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); } +void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); } +void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); } +void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); } +void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); } +void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); } +void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); } +void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); } +void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); } +void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); } +void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8,
0x73, 2); } +void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); } +void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); } +void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); } +void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); } +void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); } +void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); } +void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); } +void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); } +void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); } +void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); } +void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); } +void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); } +void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); } +void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); } +void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); } +void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); } +void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); } +void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); } +void pushf() { db(0x9C); } +void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); } +void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); } +void rcl(const Operand& op, int imm) { opShift(op, imm, 2); } +void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); } +void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); } +void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); } +void rcr(const Operand& op, int imm) { opShift(op, imm, 3); } +void rdmsr() { db(0x0F); db(0x32); } +void rdpmc() { db(0x0F); db(0x33); } +void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdtsc() { db(0x0F); db(0x31); } +void rdtscp() { db(0x0F); db(0x01); db(0xF9); } +void rep() { db(0xF3); } +void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } } +void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); } +void rol(const Operand& op, int imm) { opShift(op, imm, 0); } +void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); } +void ror(const Operand& op, int imm) { opShift(op, imm, 1); } +void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); } +void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void rsqrtps(const Xmm&
xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); } +void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); } +void sahf() { db(0x9E); } +void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void sal(const Operand& op, int imm) { opShift(op, imm, 4); } +void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); } +void sar(const Operand& op, int imm) { opShift(op, imm, 7); } +void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); } +void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); } +void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); } +void scasb() { db(0xAE); } +void scasd() { db(0xAF); } +void scasw() { db(0x66); db(0xAF); } +void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524 +void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524 +void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524 +void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524 +void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void sfence() { db(0x0F); db(0xAE); db(0xF8); } +void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); } 
+void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); } +void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void shl(const Operand& op, int imm) { opShift(op, imm, 4); } +void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); } +void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); } +void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); } +void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); } +void shr(const Operand& op, int imm) { opShift(op, imm, 5); } +void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); } +void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); } +void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); } +void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); } +void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); } +void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); } +void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); } +void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); } +void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); } +void stac() { db(0x0F); db(0x01); db(0xCB); } +void stc() { db(0xF9); } +void std() { db(0xFD); } +void sti() { db(0xFB); } +void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); } +void stosb() { db(0xAA); } +void stosd() { db(0xAB); } +void stosw() { db(0x66); db(0xAB); } +void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); } +void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); } +void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); } +void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); } +void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); } +void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); } +void tzcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); } +void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); } +void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); } +void ud2() { db(0x0F); db(0x0B); } +void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); } +void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); } +void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); } +void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); } +void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 =
Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); } +void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); } +void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); } +void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); } +void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); } +void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); } +void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); } +void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); } +void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); } +void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); } +void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); } +void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); } +void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); } +void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); } +void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); } +void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); } +void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); } +void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); } +void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); } +void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); } +void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); } +void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); } +void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); } +void vbroadcastss(const Xmm& 
x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); } +void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); } +void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); } +void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); } +void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); } +void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); } +void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); } +void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); } +void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); } +void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); } +void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); } +void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); } +void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); } +void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); } +void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); } +void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); } +void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); } +void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); } +void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); } +void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); } +void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); } +void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); } +void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); } +void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); } +void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); } +void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); } +void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); } +void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); } +void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); } +void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); } +void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); } +void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); } +void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); } +void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); } +void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); } +void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); } +void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, 
x2, op, 30); } +void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); } +void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); } +void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); } +void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); } +void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); } +void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); } +void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); } +void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); } +void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); } +void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); } +void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); } +void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); } +void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); } +void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); } +void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); } +void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); } +void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); } +void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); } +void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); } +void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); } +void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); } +void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); } +void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); } +void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); } +void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); } +void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); } +void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); } +void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); } +void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); } +void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); } +void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); } +void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); } +void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); } +void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); } +void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); } +void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); } +void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); } +void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) 
{ vcmpps(x1, x2, op, 25); } +void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); } +void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); } +void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); } +void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); } +void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); } +void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); } +void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); } +void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); } +void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); } +void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); } +void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); } +void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); } +void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); } +void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); } +void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); } +void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); } +void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); } +void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); } +void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); } +void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); } +void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); } +void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); } +void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); } +void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); } +void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); } +void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); } +void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); } +void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); } +void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); } +void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); } +void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); } +void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); } +void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); } +void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); } +void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); } +void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); } +void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); } +void vcmpordss(const Xmm& x1, const Xmm& x2, 
const Operand& op) { vcmpss(x1, x2, op, 7); } +void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); } +void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); } +void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); } +void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); } +void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); } +void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); } +void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); } +void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); } +void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); } +void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); } +void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); } +void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); } +void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); } +void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); } +void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); } +void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); } +void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); } +void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); } +void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); } +void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); } +void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); } +void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); } +void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); } +void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); } +void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); } +void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); } +void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); } +void vcvtsd2si(const Reg32& r, 
const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); } +void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); } +void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); } +void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); } +void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); } +void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); } +void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); } +void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); } +void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); } +void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); } +void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); } +void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); } +void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); } +void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); } +void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); } +void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); } +void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); } +void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); } +void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); } +void vfmadd132ss(const Xmm& x1, const 
Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); } +void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); } +void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); } +void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); } +void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); } +void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); } +void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); } +void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); } +void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); } +void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); } +void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); } +void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); } +void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); } +void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); } +void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); } +void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); } +void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); } +void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 
| T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); } +void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); } +void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); } +void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); } +void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); } +void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); } +void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); } +void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); } +void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); } +void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); } +void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); } +void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); } +void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); } +void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); } +void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); } +void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); } +void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); } +void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | 
T_ER_X, 0xBD); } +void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); } +void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); } +void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); } +void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); } +void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); } +void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); } +void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); } +void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); } +void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); } +void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); } +void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); } +void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); } +void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); } +void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); } +void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); } +void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); } +void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); } +void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); } +void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { 
opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); } +void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); } +void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); } +void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); } +void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); } +void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); } +void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); } +void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); } +void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); } +void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); } +void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); } +void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); } +void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); } +void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); } +void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); } +void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); } +void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); } +void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); } +void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); } +void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); } +void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); } +void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); } +void 
vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); } +void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); } +void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); } +void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); } +void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); } +void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); } +void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); } +void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); } +void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); } +void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); } +void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); } +void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); } +void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); } +void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); } +void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); } +void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); } +void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); } +void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); } +void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); } +void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); } +void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); } +void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); } +void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 
0xD6 : 0x7E); } +void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); } +void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); } +void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); } +void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); } +void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); } +void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); } +void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); } +void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); } +void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); } +void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); } +void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); } +void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); } +void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); } +void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); } +void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); } +void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); } +void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); } +void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); } 
+void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); } +void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); } +void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); } +void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); } +void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); } +void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); } +void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); } +void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); } +void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); } +void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); } +void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); } +void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); } +void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); } +void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); } +void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); } +void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); } +void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); } +void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); } +void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); } +void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); } +void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); } +void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); } +void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); } +void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); } +void vpclmulqdq(const Xmm& x1, const Xmm& 
x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); } +void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); } +void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); } +void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); } +void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); } +void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); } +void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); } +void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); } +void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); } +void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); } +void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); } +void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); } +void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); } +void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); } +void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); } +void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); } +void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); } +void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); } +void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); } +void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); } +void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); } +void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); } +void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); } +void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); } +void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); } +void vpextrb(const Operand& op, 
const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); } +void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); } +void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); } +void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } } +void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); } +void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); } +void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); } +void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); } +void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); } +void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); } +void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); } +void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); } +void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); } +void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); } +void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); } +void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); } +void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); } +void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); } +void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); } +void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); } +void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, 
x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); } +void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); } +void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); } +void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); } +void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); } +void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); } +void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); } +void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); } +void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); } +void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); } +void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); } +void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); } +void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); } +void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); } +void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); } +void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); } +void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); } +void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? 
Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); } +void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); } +void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); } +void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); } +void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); } +void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); } +void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); } +void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); } +void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); } +void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); } +void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); } +void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); } +void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); } +void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); } +void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); } +void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); } +void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); } +void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); } +void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); } +void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); } +void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); } +void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); } +void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); } +void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); } +void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); } +void vpsignd(const 
Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); } +void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); } +void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); } +void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); } +void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); } +void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); } +void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); } +void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); } +void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); } +void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); } +void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); } +void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); } +void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); } +void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); } +void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { 
opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); } +void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); } +void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); } +void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); } +void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); } +void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); } +void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); } +void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); } +void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); } +void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); } +void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); } +void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); } +void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); } +void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); } +void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); } +void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); } +void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); } +void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); } +void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); } +void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); } +void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); } +void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); } +void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); } +void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); } +void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); } +void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); } +void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, 
T_F3 | T_0F, 0x52); } +void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); } +void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); } +void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); } +void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); } +void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); } +void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); } +void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); } +void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); } +void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); } +void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); } +void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); } +void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); } +void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); } +void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); } +void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); } +void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); } +void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); } +void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); } +void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); } +void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); } +void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); } +void vzeroall() { db(0xC5); db(0xFC); db(0x77); } +void vzeroupper() { db(0xC5); db(0xF8); db(0x77); } +void wait() { db(0x9B); } +void wbinvd() { db(0x0F); db(0x09); } +void wrmsr() { db(0x0F); db(0x30); } +void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 
0 : 1)); } +void xgetbv() { db(0x0F); db(0x01); db(0xD0); } +void xlatb() { db(0xD7); } +void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); } +void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); } +void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); } +void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); } +#ifdef XBYAK_ENABLE_OMITTED_OPERAND +void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); } +void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); } +void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); } +void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); } +void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); } +void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); } +void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); } +void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); } +void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); } +void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); } +void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); } +void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); } +void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); } +void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); } +void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); } +void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); } +void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); } +void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); } +void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); } +void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); } +void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); } +void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); } +void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); } +void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); } +void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); } +void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); } +void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); } +void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); } +void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); } +void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); } +void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); } +void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); } +void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); } +void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); } +void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); } +void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); } +void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); } +void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); } +void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); } +void 
vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); } +void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); } +void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); } +void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); } +void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); } +void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); } +void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); } +void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); } +void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); } +void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); } +void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); } +void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); } +void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); } +void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); } +void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); } +void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); } +void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); } +void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); } +void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); } +void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); } +void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); } +void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); } +void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); } +void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); } +void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); } +void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); } +void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); } +void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); } +void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); } +void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); } +void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); } +void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); } +void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); } +void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); } +void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); } +void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); } +void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); } +void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); } +void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); } +void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); } +void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); } +void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); } +void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); } +void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); } +void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); } +void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); } +void vcmpngt_uqps(const Xmm& x, const 
Operand& op) { vcmpngt_uqps(x, x, op); } +void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); } +void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); } +void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); } +void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); } +void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); } +void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); } +void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); } +void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); } +void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); } +void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); } +void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); } +void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); } +void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); } +void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); } +void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); } +void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); } +void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); } +void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); } +void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); } +void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); } +void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); } +void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); } +void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); } +void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); } +void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); } +void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); } +void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); } +void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); } +void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); } +void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); } +void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); } +void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); } +void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); } +void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); } +void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); } +void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); } +void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); } +void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); } +void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); } +void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); } +void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); } +void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); } +void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); } +void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); } +void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); } 
+void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); } +void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); } +void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); } +void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); } +void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); } +void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); } +void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); } +void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); } +void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); } +void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); } +void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); } +void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); } +void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); } +void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); } +void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); } +void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); } +void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); } +void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); } +void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); } +void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); } +void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); } +void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); } +void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); } +void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); } +void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); } +void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); } +void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); } +void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); } +void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); } +void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); } +void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); } +void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); } +void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); } +void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); } +void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); } +void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); } +void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); } +void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); } +void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); } +void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); } +void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); } +void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); } +void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); } +void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); } +void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); } +void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); } +void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); } +void 
vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); } +void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); } +void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); } +void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); } +void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); } +void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); } +void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); } +void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); } +void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); } +void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); } +void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); } +void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); } +void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); } +void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); } +void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); } +void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); } +void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); } +void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); } +void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); } +void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); } +void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); } +void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); } +void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); } +void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); } +void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); } +void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); } +void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); } +void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); } +void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); } +void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); } +void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); } +void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); } +void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); } +void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); } +void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); } +void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); } +void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); } +void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); } +void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); } +void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); } +void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); } +void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); } +void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); } +void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); } +void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); } +void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); } +void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); } +void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); } +void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); } +void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); } +void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); } +void vpsubq(const 
Xmm& x, const Operand& op) { vpsubq(x, x, op); } +void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); } +void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); } +void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); } +void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); } +void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); } +void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); } +void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); } +void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); } +void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); } +void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); } +void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); } +void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); } +void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); } +void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); } +void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); } +void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); } +void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); } +void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); } +void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); } +void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); } +void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); } +void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); } +void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); } +void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); } +void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); } +void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); } +#endif +#ifdef XBYAK64 +void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void cdqe() { db(0x48); db(0x98); } +void cqo() { db(0x48); db(0x99); } +void cmpsq() { db(0x48); db(0xA7); } +void movsq() { db(0x48); db(0xA5); } +void scasq() { db(0x48); db(0xAF); } +void stosq() { db(0x48); db(0xAB); } +void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); } +void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); } +void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); } +void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); } +void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); } +void vcvttss2si(const 
Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); } +void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); } +void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); } +void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); } +void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); } +#else +void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void aaa() { db(0x37); } +void aad() { db(0xD5); db(0x0A); } +void aam() { db(0xD4); db(0x0A); } +void aas() { db(0x3F); } +void daa() { db(0x27); } +void das() { db(0x2F); } +void popad() { db(0x61); } +void popfd() { db(0x9D); } +void pusha() { db(0x60); } +void pushad() { db(0x60); } +void pushfd() { db(0x9C); } +void popa() { db(0x61); } +#endif +#ifndef XBYAK_NO_OP_NAMES +void and(const Operand& op1, const Operand& op2) { and_(op1, op2); } +void and(const Operand& op, uint32 imm) { and_(op, imm); } +void or(const Operand& op1, const Operand& op2) { or_(op1, op2); } +void or(const Operand& op, uint32 imm) { or_(op, imm); } +void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); } +void xor(const Operand& op, uint32 imm) { xor_(op, imm); } +void not(const Operand& op) { not_(op); } +#endif +#ifndef XBYAK_DISABLE_AVX512 +void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); } +void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); } +void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); } +void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); } +void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); } +void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); } +void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); } +void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); } +void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); } +void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); } +void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); } +void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); } +void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); } +void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); } +void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 
0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); } +void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); } +void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); } +void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); } +void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); } +void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); } +void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); } +void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); } +void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); } +void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); } +void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); } +void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); } +void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); } +void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); } +void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); } +void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); } +void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); } +void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); } +void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); } +void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); } +void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); } +void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); } +void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); } +void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); } +void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); } +void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); } +void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); } +void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); } +void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); } +void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); } +void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); } +void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); } +void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); } +void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); } +void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 
0x99); } +void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); } +void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); } +void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); } +void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); } +void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); } +void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); } +void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); } +void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); } +void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); } +void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); } +void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); } +void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); } +void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); } +void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); } +void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); } +void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); } +void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); } +void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); } +void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); } +void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); } +void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); } +void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); } +void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); } +void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); } +void 
vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); } +void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); } +void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); } +void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); } +void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); } +void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); } +void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); } +void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); } +void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); } +void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); } +void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); } +void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); } +void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); } +void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); } +void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? 
T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); } +void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); } +void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); } +void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); } +void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); } +void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); } +void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); } +void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); } +void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); } +void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); } +void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); } +void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); } +void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); } +void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM 
| Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); } +void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); } +void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); } +void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); } +void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); } +void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); } +void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); } +void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); } +void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 
0xC7, Operand::ZMM); } +void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); } +void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); } +void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); } +void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); } +void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); } +void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); } +void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw 
Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); } +void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); } +void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); } +void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); } +void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); } +void vpandnq(const Xmm& x1, const Xmm& 
x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); } +void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); } +void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); } +void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); } +void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); } +void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); } +void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); } +void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); } +void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); } +void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); } +void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); } +void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); } +void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); } +void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); } +void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); } +void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); } +void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); } +void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); } +void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); } +void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); } +void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); } +void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | 
T_MUST_EVEX | T_B64, 0x1E, imm); } +void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); } +void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); } +void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); } +void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); } +void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); } +void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); } +void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); } +void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); } +void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); } +void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); } +void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); } +void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); } +void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); } +void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermw(const Xmm& x1, const Xmm& x2, const 
Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); } +void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); } +void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); } +void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); } +void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); } +void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); } +void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); } +void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); } +void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); } +void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); } +void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); } +void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); } +void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); } +void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); } +void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); } +void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); } +void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); } +void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); } +void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); } +void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); } +void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); } +void vpmovqb(const Operand& op, 
const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); } +void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); } +void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); } +void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); } +void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); } +void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); } +void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); } +void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); } +void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); } +void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); } +void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); } +void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); } +void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); } +void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); } +void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); } +void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); } +void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); } +void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); } +void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); } +void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); } +void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); } +void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); } +void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); } +void vprold(const Xmm& x, const Operand& op, uint8 
imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); } +void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); } +void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); } +void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); } +void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); } +void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); } +void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); } +void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); } +void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); } +void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); } +void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); } +void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); } +void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); } +void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); } +void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); } +void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); } +void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); } +void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); } +void vpshrdvw(const Xmm& x1, 
const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); } +void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); } +void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); } +void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); } +void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); } +void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); } +void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); } +void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); } +void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); } +void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); } +void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); } +void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); } +void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); } +void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); } +void vrangesd(const Xmm& x1, 
const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); } +void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); } +void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); } +void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); } +void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); } +void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); } +void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); } +void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); } +void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); } +void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); } +void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); } +void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); } +void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); } +void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); } +void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); } +void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, 
op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); } +void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); } +void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); } +void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); } +void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); } +void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); } +void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); } +void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); } +void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); } +void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); } +void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); } 
+#ifdef XBYAK64 +void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); } +void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); } +void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); } +#endif +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h new file mode 100644 index 000000000000..8ef076e680ca --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h @@ -0,0 +1,772 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#ifndef XBYAK_XBYAK_UTIL_H_ +#define XBYAK_XBYAK_UTIL_H_ + +/** + utility class and functions for Xbyak + Xbyak::util::Clock ; rdtsc timer + Xbyak::util::Cpu ; detect CPU + @note this header is UNDER CONSTRUCTION! 
+*/ +#include "xbyak.h" + +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + #define XBYAK_INTEL_CPU_SPECIFIC +#endif + +#ifdef XBYAK_INTEL_CPU_SPECIFIC +#ifdef _MSC_VER + #if (_MSC_VER < 1400) && defined(XBYAK32) + static inline __declspec(naked) void __cpuid(int[4], int) + { + __asm { + push ebx + push esi + mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn + cpuid + mov esi, dword ptr [esp + 4 * 2 + 4] // data + mov dword ptr [esi], eax + mov dword ptr [esi + 4], ebx + mov dword ptr [esi + 8], ecx + mov dword ptr [esi + 12], edx + pop esi + pop ebx + ret + } + } + #else + #include // for __cpuid + #endif +#else + #ifndef __GNUC_PREREQ + #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor))) + #endif + #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__) + #include + #else + #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm' + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #else + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #endif + #endif +#endif +#endif + +namespace Xbyak { namespace util { + +typedef enum { + SmtLevel = 1, + CoreLevel = 2 +} IntelCpuTopologyLevel; + +/** + CPU detection class +*/ +class Cpu { + uint64 type_; + //system topology + bool x2APIC_supported_; + static const size_t maxTopologyLevels = 2; + unsigned int numCores_[maxTopologyLevels]; + + static const unsigned int maxNumberCacheLevels = 10; + unsigned int dataCacheSize_[maxNumberCacheLevels]; + unsigned int coresSharignDataCache_[maxNumberCacheLevels]; + unsigned int dataCacheLevels_; + + unsigned int get32bitAsBE(const char *x) const + { + return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); + } + unsigned int mask(int n) const + { + return (1U << n) - 1; + } + void setFamily() + { + unsigned int data[4] = {}; + getCpuid(1, data); + stepping = data[0] & mask(4); + model = (data[0] >> 4) & mask(4); + family = (data[0] >> 8) & mask(4); + // type = (data[0] >> 12) & mask(2); + extModel = (data[0] >> 16) & mask(4); + extFamily = (data[0] >> 20) & mask(8); + if (family == 0x0f) { + displayFamily = family + extFamily; + } else { + displayFamily = family; + } + if (family == 6 || family == 0x0f) { + displayModel = (extModel << 4) + model; + } else { + displayModel = model; + } + } + unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end) + { + return (val >> base) & ((1u << (end - base)) - 1); + } + void setNumCores() + { + if ((type_ & tINTEL) == 0) return; + + unsigned int data[4] = {}; + + /* CAUTION: These numbers are configuration as shipped by Intel. 
*/ + getCpuidEx(0x0, 0, data); + if (data[0] >= 0xB) { + /* + if leaf 11 exists(x2APIC is supported), + we use it to get the number of smt cores and cores on socket + + leaf 0xB can be zeroed-out by a hypervisor + */ + x2APIC_supported_ = true; + for (unsigned int i = 0; i < maxTopologyLevels; i++) { + getCpuidEx(0xB, i, data); + IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15); + if (level == SmtLevel || level == CoreLevel) { + numCores_[level - 1] = extractBit(data[1], 0, 15); + } + } + } else { + /* + Failed to deremine num of cores without x2APIC support. + TODO: USE initial APIC ID to determine ncores. + */ + numCores_[SmtLevel - 1] = 0; + numCores_[CoreLevel - 1] = 0; + } + + } + void setCacheHierarchy() + { + if ((type_ & tINTEL) == 0) return; + const unsigned int NO_CACHE = 0; + const unsigned int DATA_CACHE = 1; +// const unsigned int INSTRUCTION_CACHE = 2; + const unsigned int UNIFIED_CACHE = 3; + unsigned int smt_width = 0; + unsigned int logical_cores = 0; + unsigned int data[4] = {}; + + if (x2APIC_supported_) { + smt_width = numCores_[0]; + logical_cores = numCores_[1]; + } + + /* + Assumptions: + the first level of data cache is not shared (which is the + case for every existing architecture) and use this to + determine the SMT width for arch not supporting leaf 11. + when leaf 4 reports a number of core less than numCores_ + on socket reported by leaf 11, then it is a correct number + of cores not an upperbound. + */ + for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) { + getCpuidEx(0x4, i, data); + unsigned int cacheType = extractBit(data[0], 0, 4); + if (cacheType == NO_CACHE) break; + if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) { + unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1; + if (logical_cores != 0) { // true only if leaf 0xB is supported and valid + actual_logical_cores = (std::min)(actual_logical_cores, logical_cores); + } + assert(actual_logical_cores != 0); + dataCacheSize_[dataCacheLevels_] = + (extractBit(data[1], 22, 31) + 1) + * (extractBit(data[1], 12, 21) + 1) + * (extractBit(data[1], 0, 11) + 1) + * (data[2] + 1); + if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores; + assert(smt_width != 0); + // FIXME: check and fix number of cores sharing L3 cache for different configurations + // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket) + coresSharignDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u); + dataCacheLevels_++; + } + } + } + +public: + int model; + int family; + int stepping; + int extModel; + int extFamily; + int displayFamily; // family + extFamily + int displayModel; // model + extModel + + unsigned int getNumCores(IntelCpuTopologyLevel level) { + if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER); + if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED); + return (level == CoreLevel) + ? 
numCores_[level - 1] / numCores_[SmtLevel - 1] + : numCores_[level - 1]; + } + + unsigned int getDataCacheLevels() const { return dataCacheLevels_; } + unsigned int getCoresSharingDataCache(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return coresSharignDataCache_[i]; + } + unsigned int getDataCacheSize(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return dataCacheSize_[i]; + } + + /* + data[] = { eax, ebx, ecx, edx } + */ + static inline void getCpuid(unsigned int eaxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuid(reinterpret_cast(data), eaxIn); + #else + __cpuid(eaxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)data; +#endif + } + static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuidex(reinterpret_cast(data), eaxIn, ecxIn); + #else + __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)ecxIn; + (void)data; +#endif + } + static inline uint64 getXfeature() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return _xgetbv(0); + #else + unsigned int eax, edx; + // xgetvb is not support on gcc 4.2 +// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0)); + __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0)); + return ((uint64)edx << 32) | eax; + #endif +#else + return 0; +#endif + } + typedef uint64 Type; + + static const Type NONE = 0; + static const Type tMMX = 1 << 0; + static const Type tMMX2 = 1 << 1; + static const Type tCMOV = 1 << 2; + static const Type tSSE = 1 << 3; + static const Type tSSE2 = 1 << 4; + static const Type tSSE3 = 1 << 5; + static const Type tSSSE3 = 1 << 6; + static const Type tSSE41 = 1 << 7; + static const Type tSSE42 = 1 << 8; + static const Type tPOPCNT = 1 << 9; + static const Type tAESNI = 1 << 10; + static const Type tSSE5 = 1 << 11; + static const Type tOSXSAVE = 1 << 12; + static const Type tPCLMULQDQ = 1 << 13; + static const Type tAVX = 1 << 14; + static const Type tFMA = 1 << 15; + + static const Type t3DN = 1 << 16; + static const Type tE3DN = 1 << 17; + static const Type tSSE4a = 1 << 18; + static const Type tRDTSCP = 1 << 19; + static const Type tAVX2 = 1 << 20; + static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt + static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx + static const Type tLZCNT = 1 << 23; + + static const Type tINTEL = 1 << 24; + static const Type tAMD = 1 << 25; + + static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb + static const Type tRDRAND = 1 << 27; + static const Type tADX = 1 << 28; // adcx, adox + static const Type tRDSEED = 1 << 29; // rdseed + static const Type tSMAP = 1 << 30; // stac + static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest + static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort + static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph + static const Type tMOVBE = uint64(1) << 34; // mobve + static const Type tAVX512F = uint64(1) << 35; + static const Type tAVX512DQ = uint64(1) << 36; + static const Type tAVX512_IFMA = uint64(1) << 37; + static const Type tAVX512IFMA = tAVX512_IFMA; + static const Type tAVX512PF = uint64(1) << 38; + static const Type tAVX512ER = uint64(1) << 39; + static const Type tAVX512CD = uint64(1) << 40; + static const Type 
tAVX512BW = uint64(1) << 41; + static const Type tAVX512VL = uint64(1) << 42; + static const Type tAVX512_VBMI = uint64(1) << 43; + static const Type tAVX512VBMI = tAVX512_VBMI; // changed by Intel's manual + static const Type tAVX512_4VNNIW = uint64(1) << 44; + static const Type tAVX512_4FMAPS = uint64(1) << 45; + static const Type tPREFETCHWT1 = uint64(1) << 46; + static const Type tPREFETCHW = uint64(1) << 47; + static const Type tSHA = uint64(1) << 48; + static const Type tMPX = uint64(1) << 49; + static const Type tAVX512_VBMI2 = uint64(1) << 50; + static const Type tGFNI = uint64(1) << 51; + static const Type tVAES = uint64(1) << 52; + static const Type tVPCLMULQDQ = uint64(1) << 53; + static const Type tAVX512_VNNI = uint64(1) << 54; + static const Type tAVX512_BITALG = uint64(1) << 55; + static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56; + + Cpu() + : type_(NONE) + , x2APIC_supported_(false) + , numCores_() + , dataCacheSize_() + , coresSharignDataCache_() + , dataCacheLevels_(0) + { + unsigned int data[4] = {}; + const unsigned int& EAX = data[0]; + const unsigned int& EBX = data[1]; + const unsigned int& ECX = data[2]; + const unsigned int& EDX = data[3]; + getCpuid(0, data); + const unsigned int maxNum = EAX; + static const char intel[] = "ntel"; + static const char amd[] = "cAMD"; + if (ECX == get32bitAsBE(amd)) { + type_ |= tAMD; + getCpuid(0x80000001, data); + if (EDX & (1U << 31)) type_ |= t3DN; + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 30)) type_ |= tE3DN; + if (EDX & (1U << 22)) type_ |= tMMX2; + if (EDX & (1U << 27)) type_ |= tRDTSCP; + } + if (ECX == get32bitAsBE(intel)) { + type_ |= tINTEL; + getCpuid(0x80000001, data); + if (EDX & (1U << 27)) type_ |= tRDTSCP; + if (ECX & (1U << 5)) type_ |= tLZCNT; + if (ECX & (1U << 8)) type_ |= tPREFETCHW; + } + getCpuid(1, data); + if (ECX & (1U << 0)) type_ |= tSSE3; + if (ECX & (1U << 9)) type_ |= tSSSE3; + if (ECX & (1U << 19)) type_ |= tSSE41; + if (ECX & (1U << 20)) type_ |= tSSE42; + if (ECX & (1U << 22)) type_ |= tMOVBE; + if (ECX & (1U << 23)) type_ |= tPOPCNT; + if (ECX & (1U << 25)) type_ |= tAESNI; + if (ECX & (1U << 1)) type_ |= tPCLMULQDQ; + if (ECX & (1U << 27)) type_ |= tOSXSAVE; + if (ECX & (1U << 30)) type_ |= tRDRAND; + if (ECX & (1U << 29)) type_ |= tF16C; + + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 23)) type_ |= tMMX; + if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE; + if (EDX & (1U << 26)) type_ |= tSSE2; + + if (type_ & tOSXSAVE) { + // check XFEATURE_ENABLED_MASK[2:1] = '11b' + uint64 bv = getXfeature(); + if ((bv & 6) == 6) { + if (ECX & (1U << 28)) type_ |= tAVX; + if (ECX & (1U << 12)) type_ |= tFMA; + if (((bv >> 5) & 7) == 7) { + getCpuidEx(7, 0, data); + if (EBX & (1U << 16)) type_ |= tAVX512F; + if (type_ & tAVX512F) { + if (EBX & (1U << 17)) type_ |= tAVX512DQ; + if (EBX & (1U << 21)) type_ |= tAVX512_IFMA; + if (EBX & (1U << 26)) type_ |= tAVX512PF; + if (EBX & (1U << 27)) type_ |= tAVX512ER; + if (EBX & (1U << 28)) type_ |= tAVX512CD; + if (EBX & (1U << 30)) type_ |= tAVX512BW; + if (EBX & (1U << 31)) type_ |= tAVX512VL; + if (ECX & (1U << 1)) type_ |= tAVX512_VBMI; + if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2; + if (ECX & (1U << 8)) type_ |= tGFNI; + if (ECX & (1U << 9)) type_ |= tVAES; + if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ; + if (ECX & (1U << 11)) type_ |= tAVX512_VNNI; + if (ECX & (1U << 12)) type_ |= tAVX512_BITALG; + if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ; + if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW; + if (EDX & (1U << 3)) type_ |= 
tAVX512_4FMAPS; + } + } + } + } + if (maxNum >= 7) { + getCpuidEx(7, 0, data); + if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2; + if (EBX & (1U << 3)) type_ |= tBMI1; + if (EBX & (1U << 8)) type_ |= tBMI2; + if (EBX & (1U << 9)) type_ |= tENHANCED_REP; + if (EBX & (1U << 18)) type_ |= tRDSEED; + if (EBX & (1U << 19)) type_ |= tADX; + if (EBX & (1U << 20)) type_ |= tSMAP; + if (EBX & (1U << 4)) type_ |= tHLE; + if (EBX & (1U << 11)) type_ |= tRTM; + if (EBX & (1U << 14)) type_ |= tMPX; + if (EBX & (1U << 29)) type_ |= tSHA; + if (ECX & (1U << 0)) type_ |= tPREFETCHWT1; + } + setFamily(); + setNumCores(); + setCacheHierarchy(); + } + void putFamily() const + { + printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n", + family, model, stepping, extFamily, extModel); + printf("display:family=%X, model=%X\n", displayFamily, displayModel); + } + bool has(Type type) const + { + return (type & type_) != 0; + } +}; + +class Clock { +public: + static inline uint64 getRdtsc() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return __rdtsc(); + #else + unsigned int eax, edx; + __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx)); + return ((uint64)edx << 32) | eax; + #endif +#else + // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu + return 0; +#endif + } + Clock() + : clock_(0) + , count_(0) + { + } + void begin() + { + clock_ -= getRdtsc(); + } + void end() + { + clock_ += getRdtsc(); + count_++; + } + int getCount() const { return count_; } + uint64 getClock() const { return clock_; } + void clear() { count_ = 0; clock_ = 0; } +private: + uint64 clock_; + int count_; +}; + +#ifdef XBYAK64 +const int UseRCX = 1 << 6; +const int UseRDX = 1 << 7; + +class Pack { + static const size_t maxTblNum = 15; + const Xbyak::Reg64 *tbl_[maxTblNum]; + size_t n_; +public: + Pack() : tbl_(), n_(0) {} + Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); } + Pack(const Pack& rhs) + : n_(rhs.n_) + { + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + } + Pack& operator=(const Pack& rhs) + { + n_ = rhs.n_; + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + return *this; + } + Pack(const Xbyak::Reg64& t0) + { n_ = 1; tbl_[0] = &t0; } + Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; } + Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; } + Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; } + Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; } + Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; } + Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; } + Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const 
Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; } + Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; } + Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; } + Pack& append(const Xbyak::Reg64& t) + { + if (n_ == maxTblNum) { + fprintf(stderr, "ERR Pack::can't append\n"); + throw Error(ERR_BAD_PARAMETER); + } + tbl_[n_++] = &t; + return *this; + } + void init(const Xbyak::Reg64 *tbl, size_t n) + { + if (n > maxTblNum) { + fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n); + throw Error(ERR_BAD_PARAMETER); + } + n_ = n; + for (size_t i = 0; i < n; i++) { + tbl_[i] = &tbl[i]; + } + } + const Xbyak::Reg64& operator[](size_t n) const + { + if (n >= n_) { + fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_); + throw Error(ERR_BAD_PARAMETER); + } + return *tbl_[n]; + } + size_t size() const { return n_; } + /* + get tbl[pos, pos + num) + */ + Pack sub(size_t pos, size_t num = size_t(-1)) const + { + if (num == size_t(-1)) num = n_ - pos; + if (pos + num > n_) { + fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num); + throw Error(ERR_BAD_PARAMETER); + } + Pack pack; + pack.n_ = num; + for (size_t i = 0; i < num; i++) { + pack.tbl_[i] = tbl_[pos + i]; + } + return pack; + } + void put() const + { + for (size_t i = 0; i < n_; i++) { + printf("%s ", tbl_[i]->toString()); + } + printf("\n"); + } +}; + +class StackFrame { +#ifdef XBYAK64_WIN + static const int noSaveNum = 6; + static const int rcxPos = 0; + static const int rdxPos = 1; +#else + static const int noSaveNum = 8; + static const int rcxPos = 3; + static const int rdxPos = 2; +#endif + static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax + Xbyak::CodeGenerator *code_; + int pNum_; + int tNum_; + bool useRcx_; + bool useRdx_; + int saveNum_; + int P_; + bool makeEpilog_; + Xbyak::Reg64 pTbl_[4]; + Xbyak::Reg64 tTbl_[maxRegNum]; + Pack p_; + Pack t_; + StackFrame(const StackFrame&); + void operator=(const StackFrame&); +public: + const Pack& p; + const Pack& t; + /* + make stack frame + @param sf [in] this + @param pNum [in] num of function parameter(0 <= pNum <= 4) + @param tNum [in] num of temporary register(0 <= tNum, with UseRCX, UseRDX) #{pNum + tNum [+rcx] + [rdx]} <= 14 + @param stackSizeByte [in] local stack size + @param makeEpilog [in] automatically call close() if true + + you can use + rax + gp0, ..., gp(pNum - 1) + gt0, ..., gt(tNum-1) + rcx if tNum & UseRCX + rdx if tNum & UseRDX + rsp[0..stackSizeByte - 1] + */ + StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true) + : code_(code) + , pNum_(pNum) + , tNum_(tNum & ~(UseRCX | UseRDX)) + , useRcx_((tNum & UseRCX) != 0) + , useRdx_((tNum & UseRDX) != 0) + , saveNum_(0) + , P_(0) + , 
makeEpilog_(makeEpilog) + , p(p_) + , t(t_) + { + using namespace Xbyak; + if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM); + const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0); + if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM); + const Reg64& _rsp = code->rsp; + saveNum_ = (std::max)(0, allRegNum - noSaveNum); + const int *tbl = getOrderTbl() + noSaveNum; + for (int i = 0; i < saveNum_; i++) { + code->push(Reg64(tbl[i])); + } + P_ = (stackSizeByte + 7) / 8; + if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // (rsp % 16) == 8, then increment P_ for 16 byte alignment + P_ *= 8; + if (P_ > 0) code->sub(_rsp, P_); + int pos = 0; + for (int i = 0; i < pNum; i++) { + pTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + for (int i = 0; i < tNum_; i++) { + tTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx); + if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx); + p_.init(pTbl_, pNum); + t_.init(tTbl_, tNum_); + } + /* + make epilog manually + @param callRet [in] call ret() if true + */ + void close(bool callRet = true) + { + using namespace Xbyak; + const Reg64& _rsp = code_->rsp; + const int *tbl = getOrderTbl() + noSaveNum; + if (P_ > 0) code_->add(_rsp, P_); + for (int i = 0; i < saveNum_; i++) { + code_->pop(Reg64(tbl[saveNum_ - 1 - i])); + } + + if (callRet) code_->ret(); + } + ~StackFrame() + { + if (!makeEpilog_) return; + try { + close(); + } catch (std::exception& e) { + printf("ERR:StackFrame %s\n", e.what()); + //exit(1); + } + } +private: + const int *getOrderTbl() const + { + using namespace Xbyak; + static const int tbl[] = { +#ifdef XBYAK64_WIN + Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI, +#else + Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, +#endif + Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15 + }; + return &tbl[0]; + } + int getRegIdx(int& pos) const + { + assert(pos < maxRegNum); + using namespace Xbyak; + const int *tbl = getOrderTbl(); + int r = tbl[pos++]; + if (useRcx_) { + if (r == Operand::RCX) { return Operand::R10; } + if (r == Operand::R10) { r = tbl[pos++]; } + } + if (useRdx_) { + if (r == Operand::RDX) { return Operand::R11; } + if (r == Operand::R11) { return tbl[pos++]; } + } + return r; + } +}; +#endif + +} } // end of util +#endif diff --git a/thirdparty/oidn/patches/godot-changes-c58c5216.patch b/thirdparty/oidn/patches/godot-changes-c58c5216.patch new file mode 100644 index 000000000000..c01f00187b99 --- /dev/null +++ b/thirdparty/oidn/patches/godot-changes-c58c5216.patch @@ -0,0 +1,337 @@ +diff --git a/common/platform.h b/common/platform.h +index be14bc7..9373b61 100644 +--- a/common/platform.h ++++ b/common/platform.h +@@ -19,7 +19,7 @@ + #if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX +- #include ++ #include + #elif defined(__APPLE__) + #include + #endif +@@ -129,4 +129,3 @@ namespace oidn { + std::string getBuildName(); + + } // namespace oidn +- +diff --git a/core/autoencoder.cpp b/core/autoencoder.cpp +index d6915e6..d8da684 100644 +--- a/core/autoencoder.cpp ++++ b/core/autoencoder.cpp +@@ -90,13 +90,19 @@ namespace oidn { + if (!dirty) + return; + +- device->executeTask([&]() +- { ++ // -- GODOT start -- ++ //device->executeTask([&]() ++ //{ ++ // GODOT end -- ++ + if (mayiuse(avx512_common)) + net = buildNet<16>(); + else + net = buildNet<8>(); +- }); 
++ ++ // GODOT start -- ++ //}); ++ // GODOT end -- + + dirty = false; + } +@@ -108,9 +114,10 @@ namespace oidn { + + if (!net) + return; +- +- device->executeTask([&]() +- { ++ // -- GODOT start -- ++ //device->executeTask([&]() ++ //{ ++ // -- GODOT end -- + Progress progress; + progress.func = progressFunc; + progress.userPtr = progressUserPtr; +@@ -156,7 +163,9 @@ namespace oidn { + tileIndex++; + } + } +- }); ++ // -- GODOT start -- ++ //}); ++ // -- GODOT end -- + } + + void AutoencoderFilter::computeTileSize() +@@ -464,6 +473,11 @@ namespace oidn { + return std::make_shared(); + } + ++// -- GODOT start -- ++// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. ++#if 0 ++// -- GODOT end -- ++ + // -------------------------------------------------------------------------- + // RTFilter + // -------------------------------------------------------------------------- +@@ -491,6 +505,9 @@ namespace oidn { + weightData.hdr_alb = weights::rt_hdr_alb; + weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; + } ++// -- GODOT start -- ++#endif ++// -- GODOT end -- + + // -------------------------------------------------------------------------- + // RTLightmapFilter +diff --git a/core/autoencoder.h b/core/autoencoder.h +index c199052..98b6108 100644 +--- a/core/autoencoder.h ++++ b/core/autoencoder.h +@@ -93,11 +93,18 @@ namespace oidn { + // RTFilter - Generic ray tracing denoiser + // -------------------------------------------------------------------------- + ++// -- GODOT start -- ++// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. ++#if 0 ++// -- GODOT end -- + class RTFilter : public AutoencoderFilter + { + public: + explicit RTFilter(const Ref& device); + }; ++// -- GODOT start -- ++#endif ++// -- GODOT end -- + + // -------------------------------------------------------------------------- + // RTLightmapFilter - Ray traced lightmap denoiser +diff --git a/core/common.h b/core/common.h +index a3a7e8a..a35dd90 100644 +--- a/core/common.h ++++ b/core/common.h +@@ -27,7 +27,9 @@ + #include "common/ref.h" + #include "common/exception.h" + #include "common/thread.h" +-#include "common/tasking.h" ++// -- GODOT start -- ++//#include "common/tasking.h" ++// -- GODOT end -- + #include "math.h" + + namespace oidn { +diff --git a/core/device.cpp b/core/device.cpp +index c455695..3cd658b 100644 +--- a/core/device.cpp ++++ b/core/device.cpp +@@ -29,7 +29,9 @@ namespace oidn { + + Device::~Device() + { +- observer.reset(); ++ // -- GODOT start -- ++ //observer.reset(); ++ // -- GODOT end -- + } + + void Device::setError(Device* device, Error code, const std::string& message) +@@ -141,6 +143,9 @@ namespace oidn { + if (isCommitted()) + throw Exception(Error::InvalidOperation, "device can be committed only once"); + ++ // -- GODOT start -- ++ #if 0 ++ // -- GODOT end -- + // Get the optimal thread affinities + if (setAffinity) + { +@@ -157,7 +162,10 @@ namespace oidn { + // Automatically set the thread affinities + if (affinity) + observer = std::make_shared(affinity, *arena); +- ++ // -- GODOT start -- ++ #endif ++ numThreads = 1; ++ // -- GODOT end -- + dirty = false; + + if (isVerbose()) +@@ -191,9 +199,17 @@ namespace oidn { + + Ref filter; + ++// -- GODOT start -- ++// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. 
++#if 0 ++// -- GODOT end -- + if (type == "RT") + filter = makeRef(Ref(this)); +- else if (type == "RTLightmap") ++// -- GODOT start -- ++// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. ++#endif ++ if (type == "RTLightmap") ++// -- GODOT end -- + filter = makeRef(Ref(this)); + else + throw Exception(Error::InvalidArgument, "unknown filter type"); +@@ -210,11 +226,12 @@ namespace oidn { + std::cout << " Build : " << getBuildName() << std::endl; + std::cout << " Platform: " << getPlatformName() << std::endl; + +- std::cout << " Tasking :"; +- std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR; +- std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version(); +- std::cout << std::endl; +- ++// -- GODOT start -- ++// std::cout << " Tasking :"; ++// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR; ++// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version(); ++// std::cout << std::endl; ++// -- GODOT end -- + std::cout << std::endl; + } + +diff --git a/core/device.h b/core/device.h +index c2df714..d9cfd85 100644 +--- a/core/device.h ++++ b/core/device.h +@@ -41,10 +41,12 @@ namespace oidn { + ErrorFunction errorFunc = nullptr; + void* errorUserPtr = nullptr; + +- // Tasking +- std::shared_ptr arena; +- std::shared_ptr observer; +- std::shared_ptr affinity; ++// -- GODOT start -- ++// // Tasking ++// std::shared_ptr arena; ++// std::shared_ptr observer; ++// std::shared_ptr affinity; ++// -- GODOT end -- + + // Parameters + int numThreads = 0; // autodetect by default +@@ -66,17 +68,19 @@ namespace oidn { + + void commit(); + +- template +- void executeTask(F& f) +- { +- arena->execute(f); +- } ++// -- GODOT start -- ++// template ++// void executeTask(F& f) ++// { ++// arena->execute(f); ++// } + +- template +- void executeTask(const F& f) +- { +- arena->execute(f); +- } ++// template ++// void executeTask(const F& f) ++// { ++// arena->execute(f); ++// } ++// -- GODOT end -- + + Ref newBuffer(size_t byteSize); + Ref newBuffer(void* ptr, size_t byteSize); +@@ -86,7 +90,10 @@ namespace oidn { + __forceinline std::mutex& getMutex() { return mutex; } + + private: +- bool isCommitted() const { return bool(arena); } ++// -- GODOT start -- ++ //bool isCommitted() const { return bool(arena); } ++ bool isCommitted() const { return false; } ++// -- GODOT end -- + void checkCommitted(); + + void print(); +diff --git a/core/network.cpp b/core/network.cpp +index 8c2de09..ed8328c 100644 +--- a/core/network.cpp ++++ b/core/network.cpp +@@ -17,6 +17,9 @@ + #include "upsample.h" + #include "weights_reorder.h" + #include "network.h" ++// -- GODOT start -- ++#include ++// -- GODOT end -- + + namespace oidn { + +diff --git a/core/transfer_function.cpp b/core/transfer_function.cpp +index 601f814..ce5deca 100644 +--- a/core/transfer_function.cpp ++++ b/core/transfer_function.cpp +@@ -38,16 +38,24 @@ namespace oidn { + // Compute the average log luminance of the downsampled image + using Sum = std::pair; + +- Sum sum = +- tbb::parallel_reduce( +- tbb::blocked_range2d(0, HK, 0, WK), +- Sum(0.f, 0), +- [&](const tbb::blocked_range2d& r, Sum sum) -> Sum ++ // -- GODOT start -- ++ // Sum sum = ++ // tbb::parallel_reduce( ++ // tbb::blocked_range2d(0, HK, 0, WK), ++ // Sum(0.f, 0), ++ // [&](const tbb::blocked_range2d& r, Sum sum) -> Sum ++ // { ++ // // Iterate over blocks ++ // for (int i = 
r.rows().begin(); i != r.rows().end(); ++i) ++ // { ++ // for (int j = r.cols().begin(); j != r.cols().end(); ++j) ++ // { ++ ++ Sum sum = Sum(0.0f, 0); ++ ++ for (int i = 0; i != HK; ++i) + { +- // Iterate over blocks +- for (int i = r.rows().begin(); i != r.rows().end(); ++i) +- { +- for (int j = r.cols().begin(); j != r.cols().end(); ++j) ++ for (int j = 0; j != WK; ++j) + { + // Compute the average luminance in the current block + const int beginH = int(ptrdiff_t(i) * H / HK); +@@ -82,11 +90,12 @@ namespace oidn { + } + } + +- return sum; +- }, +- [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, +- tbb::static_partitioner() +- ); ++ // return sum; ++ // }, ++ // [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, ++ // tbb::static_partitioner() ++ // ); ++ // -- GODOT end -- + + return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f; + } diff --git a/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch b/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch new file mode 100644 index 000000000000..50d94ebffa85 --- /dev/null +++ b/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch @@ -0,0 +1,45 @@ +Rediffed by @akien-mga to match oidn 1.1.0 source. + +From 1e42e6db81e1a5270ecc0191c5385ce7e7d978e9 Mon Sep 17 00:00:00 2001 +From: Jeremy Wong +Date: Wed, 11 Sep 2019 04:46:53 +0800 +Subject: [PATCH] src: initialize members in some structures to prevent compile + errors with VS2017 + +addresses "error C3615: constexpr function '...' cannot result in a constant expression" with VS2017 +--- + src/cpu/rnn/rnn_reorders.hpp | 2 +- + src/cpu/simple_concat.hpp | 6 +++--- + src/cpu/simple_sum.hpp | 2 +- + 3 files changed, 5 insertions(+), 5 deletions(-) + +diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp +index 597c63e3f8..ae1551390a 100644 +--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp ++++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp +@@ -131,7 +131,7 @@ struct rnn_weights_reorder_t : public cpu_primitive_t { + return status::success; + } + +- format_tag_t itag_; ++ format_tag_t itag_ = mkldnn_format_tag_undef; + + private: + void init_scratchpad() { +diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp +index 5177275452..057cc3c4c7 100644 +--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp ++++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp +@@ -96,9 +96,9 @@ struct simple_concat_t: public cpu_primitive_t { + return status::success; + } + +- int perm_[MKLDNN_MAX_NDIMS]; +- int iperm_[MKLDNN_MAX_NDIMS]; +- dims_t blocks_; ++ int perm_[MKLDNN_MAX_NDIMS] {}; ++ int iperm_[MKLDNN_MAX_NDIMS] {}; ++ dims_t blocks_ {}; + + dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { + const int ndims = data_d.ndims(); diff --git a/thirdparty/oidn/weights/LICENSE.txt b/thirdparty/oidn/weights/LICENSE.txt new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/thirdparty/oidn/weights/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/thirdparty/oidn/weights/rtlightmap_hdr.tza b/thirdparty/oidn/weights/rtlightmap_hdr.tza new file mode 100644 index 000000000000..12459a33bc88 Binary files /dev/null and b/thirdparty/oidn/weights/rtlightmap_hdr.tza differ diff --git a/thirdparty/stb_rect_pack/stb_rect_pack.h b/thirdparty/stb_rect_pack/stb_rect_pack.h new file mode 100644 index 000000000000..3336fe7395bb --- /dev/null +++ b/thirdparty/stb_rect_pack/stb_rect_pack.h @@ -0,0 +1,629 @@ +// stb_rect_pack.h - v1.00 - public domain - rectangle packing +// Sean Barrett 2014 +// +// Useful for e.g. packing rectangular textures into an atlas. +// Does not do rotation. +// +// Not necessarily the awesomest packing method, but better than +// the totally naive one in stb_truetype (which is primarily what +// this is meant to replace). +// +// Has only had a few tests run, may have issues. +// +// More docs to come. +// +// No memory allocations; uses qsort() and assert() from stdlib. +// Can override those by defining STBRP_SORT and STBRP_ASSERT. +// +// This library currently uses the Skyline Bottom-Left algorithm. +// +// Please note: better rectangle packers are welcome! Please +// implement them to the same API, but with a different init +// function. +// +// Credits +// +// Library +// Sean Barrett +// Minor features +// Martins Mozeiko +// github:IntellectualKitty +// +// Bugfixes / warning fixes +// Jeremy Jaussaud +// Fabian Giesen +// +// Version history: +// +// 1.00 (2019-02-25) avoid small space waste; gracefully fail too-wide rectangles +// 0.99 (2019-02-07) warning fixes +// 0.11 (2017-03-03) return packing success/fail result +// 0.10 (2016-10-25) remove cast-away-const to avoid warnings +// 0.09 (2016-08-27) fix compiler warnings +// 0.08 (2015-09-13) really fix bug with empty rects (w=0 or h=0) +// 0.07 (2015-09-13) fix bug with empty rects (w=0 or h=0) +// 0.06 (2015-04-15) added STBRP_SORT to allow replacing qsort +// 0.05: added STBRP_ASSERT to allow replacing assert +// 0.04: fixed minor bug in STBRP_LARGE_RECTS support +// 0.01: initial release +// +// LICENSE +// +// See end of file for license information. + +////////////////////////////////////////////////////////////////////////////// +// +// INCLUDE SECTION +// + +#ifndef STB_INCLUDE_STB_RECT_PACK_H +#define STB_INCLUDE_STB_RECT_PACK_H + +#define STB_RECT_PACK_VERSION 1 + +#ifdef STBRP_STATIC +#define STBRP_DEF static +#else +#define STBRP_DEF extern +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct stbrp_context stbrp_context; +typedef struct stbrp_node stbrp_node; +typedef struct stbrp_rect stbrp_rect; + +#ifdef STBRP_LARGE_RECTS +typedef int stbrp_coord; +#else +typedef unsigned short stbrp_coord; +#endif + +STBRP_DEF int stbrp_pack_rects (stbrp_context *context, stbrp_rect *rects, int num_rects); +// Assign packed locations to rectangles. The rectangles are of type +// 'stbrp_rect' defined below, stored in the array 'rects', and there +// are 'num_rects' many of them. +// +// Rectangles which are successfully packed have the 'was_packed' flag +// set to a non-zero value and 'x' and 'y' store the minimum location +// on each axis (i.e. bottom-left in cartesian coordinates, top-left +// if you imagine y increasing downwards). Rectangles which do not fit +// have the 'was_packed' flag set to 0. +// +// You should not try to access the 'rects' array from another thread +// while this function is running, as the function temporarily reorders +// the array while it executes. 
+// +// To pack into another rectangle, you need to call stbrp_init_target +// again. To continue packing into the same rectangle, you can call +// this function again. Calling this multiple times with multiple rect +// arrays will probably produce worse packing results than calling it +// a single time with the full rectangle array, but the option is +// available. +// +// The function returns 1 if all of the rectangles were successfully +// packed and 0 otherwise. + +struct stbrp_rect +{ + // reserved for your use: + int id; + + // input: + stbrp_coord w, h; + + // output: + stbrp_coord x, y; + int was_packed; // non-zero if valid packing + +}; // 16 bytes, nominally + + +STBRP_DEF void stbrp_init_target (stbrp_context *context, int width, int height, stbrp_node *nodes, int num_nodes); +// Initialize a rectangle packer to: +// pack a rectangle that is 'width' by 'height' in dimensions +// using temporary storage provided by the array 'nodes', which is 'num_nodes' long +// +// You must call this function every time you start packing into a new target. +// +// There is no "shutdown" function. The 'nodes' memory must stay valid for +// the following stbrp_pack_rects() call (or calls), but can be freed after +// the call (or calls) finish. +// +// Note: to guarantee best results, either: +// 1. make sure 'num_nodes' >= 'width' +// or 2. call stbrp_allow_out_of_mem() defined below with 'allow_out_of_mem = 1' +// +// If you don't do either of the above things, widths will be quantized to multiples +// of small integers to guarantee the algorithm doesn't run out of temporary storage. +// +// If you do #2, then the non-quantized algorithm will be used, but the algorithm +// may run out of temporary storage and be unable to pack some rectangles. + +STBRP_DEF void stbrp_setup_allow_out_of_mem (stbrp_context *context, int allow_out_of_mem); +// Optionally call this function after init but before doing any packing to +// change the handling of the out-of-temp-memory scenario, described above. +// If you call init again, this will be reset to the default (false). + + +STBRP_DEF void stbrp_setup_heuristic (stbrp_context *context, int heuristic); +// Optionally select which packing heuristic the library should use. Different +// heuristics will produce better/worse results for different data sets. +// If you call init again, this will be reset to the default. 
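For reference, a minimal sketch of how the packing functions documented above are driven as a standalone C program; the 256x256 atlas size, the three rectangle sizes, and the main() wrapper are arbitrary example values chosen here for illustration, not part of the upstream header or of this changeset:

// Minimal stb_rect_pack usage sketch (example values only, assumed layout).
#define STB_RECT_PACK_IMPLEMENTATION
#include "stb_rect_pack.h"
#include <stdio.h>

int main(void) {
    enum { ATLAS_W = 256, ATLAS_H = 256, NUM_RECTS = 3 };
    stbrp_node nodes[ATLAS_W];          // num_nodes >= width, so no width quantization
    stbrp_rect rects[NUM_RECTS] = {
        { 0, 64, 64 },                  // id, w, h (x, y, was_packed start at 0)
        { 1, 128, 32 },
        { 2, 200, 200 },
    };

    stbrp_context ctx;
    stbrp_init_target(&ctx, ATLAS_W, ATLAS_H, nodes, ATLAS_W);
    int all_packed = stbrp_pack_rects(&ctx, rects, NUM_RECTS);

    for (int i = 0; i < NUM_RECTS; i++) {
        if (rects[i].was_packed)
            printf("rect %d -> (%d, %d)\n", rects[i].id, rects[i].x, rects[i].y);
        else
            printf("rect %d did not fit\n", rects[i].id);
    }
    return all_packed ? 0 : 1;          // stbrp_pack_rects returns 1 if everything fit
}

The nodes array doubles as the packer's only working storage, which is why sizing it to the atlas width (or calling stbrp_setup_allow_out_of_mem) matters for packing quality, as the comments above explain.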
+ +enum +{ + STBRP_HEURISTIC_Skyline_default=0, + STBRP_HEURISTIC_Skyline_BL_sortHeight = STBRP_HEURISTIC_Skyline_default, + STBRP_HEURISTIC_Skyline_BF_sortHeight +}; + + +////////////////////////////////////////////////////////////////////////////// +// +// the details of the following structures don't matter to you, but they must +// be visible so you can handle the memory allocations for them + +struct stbrp_node +{ + stbrp_coord x,y; + stbrp_node *next; +}; + +struct stbrp_context +{ + int width; + int height; + int align; + int init_mode; + int heuristic; + int num_nodes; + stbrp_node *active_head; + stbrp_node *free_head; + stbrp_node extra[2]; // we allocate two extra nodes so optimal user-node-count is 'width' not 'width+2' +}; + +#ifdef __cplusplus +} +#endif + +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION SECTION +// + +#ifdef STB_RECT_PACK_IMPLEMENTATION +#ifndef STBRP_SORT +#include +#define STBRP_SORT qsort +#endif + +#ifndef STBRP_ASSERT +#include +#define STBRP_ASSERT assert +#endif + +#ifdef _MSC_VER +#define STBRP__NOTUSED(v) (void)(v) +#else +#define STBRP__NOTUSED(v) (void)sizeof(v) +#endif + +enum +{ + STBRP__INIT_skyline = 1 +}; + +STBRP_DEF void stbrp_setup_heuristic(stbrp_context *context, int heuristic) +{ + switch (context->init_mode) { + case STBRP__INIT_skyline: + STBRP_ASSERT(heuristic == STBRP_HEURISTIC_Skyline_BL_sortHeight || heuristic == STBRP_HEURISTIC_Skyline_BF_sortHeight); + context->heuristic = heuristic; + break; + default: + STBRP_ASSERT(0); + } +} + +STBRP_DEF void stbrp_setup_allow_out_of_mem(stbrp_context *context, int allow_out_of_mem) +{ + if (allow_out_of_mem) + // if it's ok to run out of memory, then don't bother aligning them; + // this gives better packing, but may fail due to OOM (even though + // the rectangles easily fit). @TODO a smarter approach would be to only + // quantize once we've hit OOM, then we could get rid of this parameter. + context->align = 1; + else { + // if it's not ok to run out of memory, then quantize the widths + // so that num_nodes is always enough nodes. + // + // I.e. 
num_nodes * align >= width + // align >= width / num_nodes + // align = ceil(width/num_nodes) + + context->align = (context->width + context->num_nodes-1) / context->num_nodes; + } +} + +STBRP_DEF void stbrp_init_target(stbrp_context *context, int width, int height, stbrp_node *nodes, int num_nodes) +{ + int i; +#ifndef STBRP_LARGE_RECTS + STBRP_ASSERT(width <= 0xffff && height <= 0xffff); +#endif + + for (i=0; i < num_nodes-1; ++i) + nodes[i].next = &nodes[i+1]; + nodes[i].next = NULL; + context->init_mode = STBRP__INIT_skyline; + context->heuristic = STBRP_HEURISTIC_Skyline_default; + context->free_head = &nodes[0]; + context->active_head = &context->extra[0]; + context->width = width; + context->height = height; + context->num_nodes = num_nodes; + stbrp_setup_allow_out_of_mem(context, 0); + + // node 0 is the full width, node 1 is the sentinel (lets us not store width explicitly) + context->extra[0].x = 0; + context->extra[0].y = 0; + context->extra[0].next = &context->extra[1]; + context->extra[1].x = (stbrp_coord) width; +#ifdef STBRP_LARGE_RECTS + context->extra[1].y = (1<<30); +#else + context->extra[1].y = 65535; +#endif + context->extra[1].next = NULL; +} + +// find minimum y position if it starts at x1 +static int stbrp__skyline_find_min_y(stbrp_context *c, stbrp_node *first, int x0, int width, int *pwaste) +{ + stbrp_node *node = first; + int x1 = x0 + width; + int min_y, visited_width, waste_area; + + STBRP__NOTUSED(c); + + STBRP_ASSERT(first->x <= x0); + + #if 0 + // skip in case we're past the node + while (node->next->x <= x0) + ++node; + #else + STBRP_ASSERT(node->next->x > x0); // we ended up handling this in the caller for efficiency + #endif + + STBRP_ASSERT(node->x <= x0); + + min_y = 0; + waste_area = 0; + visited_width = 0; + while (node->x < x1) { + if (node->y > min_y) { + // raise min_y higher. 
+         // we've accounted for all waste up to min_y,
+         // but we'll now add more waste for everything we've visited
+         waste_area += visited_width * (node->y - min_y);
+         min_y = node->y;
+         // the first time through, visited_width might be reduced
+         if (node->x < x0)
+            visited_width += node->next->x - x0;
+         else
+            visited_width += node->next->x - node->x;
+      } else {
+         // add waste area
+         int under_width = node->next->x - node->x;
+         if (under_width + visited_width > width)
+            under_width = width - visited_width;
+         waste_area += under_width * (min_y - node->y);
+         visited_width += under_width;
+      }
+      node = node->next;
+   }
+
+   *pwaste = waste_area;
+   return min_y;
+}
+
+typedef struct
+{
+   int x,y;
+   stbrp_node **prev_link;
+} stbrp__findresult;
+
+static stbrp__findresult stbrp__skyline_find_best_pos(stbrp_context *c, int width, int height)
+{
+   int best_waste = (1<<30), best_x, best_y = (1 << 30);
+   stbrp__findresult fr;
+   stbrp_node **prev, *node, *tail, **best = NULL;
+
+   // align to multiple of c->align
+   width = (width + c->align - 1);
+   width -= width % c->align;
+   STBRP_ASSERT(width % c->align == 0);
+
+   // if it can't possibly fit, bail immediately
+   if (width > c->width || height > c->height) {
+      fr.prev_link = NULL;
+      fr.x = fr.y = 0;
+      return fr;
+   }
+
+   node = c->active_head;
+   prev = &c->active_head;
+   while (node->x + width <= c->width) {
+      int y,waste;
+      y = stbrp__skyline_find_min_y(c, node, node->x, width, &waste);
+      if (c->heuristic == STBRP_HEURISTIC_Skyline_BL_sortHeight) { // actually just want to test BL
+         // bottom left
+         if (y < best_y) {
+            best_y = y;
+            best = prev;
+         }
+      } else {
+         // best-fit
+         if (y + height <= c->height) {
+            // can only use it if it fits vertically
+            if (y < best_y || (y == best_y && waste < best_waste)) {
+               best_y = y;
+               best_waste = waste;
+               best = prev;
+            }
+         }
+      }
+      prev = &node->next;
+      node = node->next;
+   }
+
+   best_x = (best == NULL) ? 0 : (*best)->x;
+
+   // if doing best-fit (BF), we also have to try aligning right edge to each node position
+   //
+   // e.g., if fitting
+   //
+   //     ____________________
+   //    |____________________|
+   //
+   // into
+   //
+   //   |                         |
+   //   |             ____________|
+   //   |____________|
+   //
+   // then right-aligned reduces waste, but bottom-left (BL) always chooses left-aligned
+   //
+   // This makes BF take about 2x the time
+
+   if (c->heuristic == STBRP_HEURISTIC_Skyline_BF_sortHeight) {
+      tail = c->active_head;
+      node = c->active_head;
+      prev = &c->active_head;
+      // find first node that's admissible
+      while (tail->x < width)
+         tail = tail->next;
+      while (tail) {
+         int xpos = tail->x - width;
+         int y,waste;
+         STBRP_ASSERT(xpos >= 0);
+         // find the left position that matches this
+         while (node->next->x <= xpos) {
+            prev = &node->next;
+            node = node->next;
+         }
+         STBRP_ASSERT(node->next->x > xpos && node->x <= xpos);
+         y = stbrp__skyline_find_min_y(c, node, xpos, width, &waste);
+         if (y + height <= c->height) {
+            if (y <= best_y) {
+               if (y < best_y || waste < best_waste || (waste==best_waste && xpos < best_x)) {
+                  best_x = xpos;
+                  STBRP_ASSERT(y <= best_y);
+                  best_y = y;
+                  best_waste = waste;
+                  best = prev;
+               }
+            }
+         }
+         tail = tail->next;
+      }
+   }
+
+   fr.prev_link = best;
+   fr.x = best_x;
+   fr.y = best_y;
+   return fr;
+}
+
+static stbrp__findresult stbrp__skyline_pack_rectangle(stbrp_context *context, int width, int height)
+{
+   // find best position according to heuristic
+   stbrp__findresult res = stbrp__skyline_find_best_pos(context, width, height);
+   stbrp_node *node, *cur;
+
+   // bail if:
+   //    1. it failed
+   //    2. the best node doesn't fit (we don't always check this)
+   //    3. we're out of memory
+   if (res.prev_link == NULL || res.y + height > context->height || context->free_head == NULL) {
+      res.prev_link = NULL;
+      return res;
+   }
+
+   // on success, create new node
+   node = context->free_head;
+   node->x = (stbrp_coord) res.x;
+   node->y = (stbrp_coord) (res.y + height);
+
+   context->free_head = node->next;
+
+   // insert the new node into the right starting point, and
+   // let 'cur' point to the remaining nodes needing to be
+   // stitched back in
+
+   cur = *res.prev_link;
+   if (cur->x < res.x) {
+      // preserve the existing one, so start testing with the next one
+      stbrp_node *next = cur->next;
+      cur->next = node;
+      cur = next;
+   } else {
+      *res.prev_link = node;
+   }
+
+   // from here, traverse cur and free the nodes, until we get to one
+   // that shouldn't be freed
+   while (cur->next && cur->next->x <= res.x + width) {
+      stbrp_node *next = cur->next;
+      // move the current node to the free list
+      cur->next = context->free_head;
+      context->free_head = cur;
+      cur = next;
+   }
+
+   // stitch the list back in
+   node->next = cur;
+
+   if (cur->x < res.x + width)
+      cur->x = (stbrp_coord) (res.x + width);
+
+#ifdef _DEBUG
+   cur = context->active_head;
+   while (cur->x < context->width) {
+      STBRP_ASSERT(cur->x < cur->next->x);
+      cur = cur->next;
+   }
+   STBRP_ASSERT(cur->next == NULL);
+
+   {
+      int count=0;
+      cur = context->active_head;
+      while (cur) {
+         cur = cur->next;
+         ++count;
+      }
+      cur = context->free_head;
+      while (cur) {
+         cur = cur->next;
+         ++count;
+      }
+      STBRP_ASSERT(count == context->num_nodes+2);
+   }
+#endif
+
+   return res;
+}
+
+static int rect_height_compare(const void *a, const void *b)
+{
+   const stbrp_rect *p = (const stbrp_rect *) a;
+   const stbrp_rect *q = (const stbrp_rect *) b;
+   if (p->h > q->h)
+      return -1;
+   if (p->h < q->h)
+      return  1;
+   return (p->w > q->w) ? -1 : (p->w < q->w);
+}
+
+static int rect_original_order(const void *a, const void *b)
+{
+   const stbrp_rect *p = (const stbrp_rect *) a;
+   const stbrp_rect *q = (const stbrp_rect *) b;
+   return (p->was_packed < q->was_packed) ?
-1 : (p->was_packed > q->was_packed); +} + +#ifdef STBRP_LARGE_RECTS +#define STBRP__MAXVAL 0xffffffff +#else +#define STBRP__MAXVAL 0xffff +#endif + +STBRP_DEF int stbrp_pack_rects(stbrp_context *context, stbrp_rect *rects, int num_rects) +{ + int i, all_rects_packed = 1; + + // we use the 'was_packed' field internally to allow sorting/unsorting + for (i=0; i < num_rects; ++i) { + rects[i].was_packed = i; + } + + // sort according to heuristic + STBRP_SORT(rects, num_rects, sizeof(rects[0]), rect_height_compare); + + for (i=0; i < num_rects; ++i) { + if (rects[i].w == 0 || rects[i].h == 0) { + rects[i].x = rects[i].y = 0; // empty rect needs no space + } else { + stbrp__findresult fr = stbrp__skyline_pack_rectangle(context, rects[i].w, rects[i].h); + if (fr.prev_link) { + rects[i].x = (stbrp_coord) fr.x; + rects[i].y = (stbrp_coord) fr.y; + } else { + rects[i].x = rects[i].y = STBRP__MAXVAL; + } + } + } + + // unsort + STBRP_SORT(rects, num_rects, sizeof(rects[0]), rect_original_order); + + // set was_packed flags and all_rects_packed status + for (i=0; i < num_rects; ++i) { + rects[i].was_packed = !(rects[i].x == STBRP__MAXVAL && rects[i].y == STBRP__MAXVAL); + if (!rects[i].was_packed) + all_rects_packed = 0; + } + + // return the all_rects_packed status + return all_rects_packed; +} +#endif + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. 
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ +
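For reference, a minimal standalone sketch of driving the stbrp_* API declared in the header above (not part of the patch or of the upstream file): allocate one stbrp_node per pixel of target width, initialize the target, fill in each rect's w/h, call stbrp_pack_rects, and read back x/y/was_packed. The 256x256 target size, the include path, and all variable names are illustrative assumptions.

/* Illustrative usage sketch only; sizes, names, and include path are assumptions. */
#define STB_RECT_PACK_IMPLEMENTATION
#include "stb_rect_pack.h"
#include <stdio.h>

int main(void)
{
   stbrp_context context;
   stbrp_node nodes[256];   /* one node per pixel of target width, per the header's guidance */
   stbrp_rect rects[2];
   int i, all_packed;

   stbrp_init_target(&context, 256, 256, nodes, 256);

   rects[0].w = 100; rects[0].h = 40;
   rects[1].w = 60;  rects[1].h = 200;

   /* returns nonzero only if every rect was packed */
   all_packed = stbrp_pack_rects(&context, rects, 2);

   for (i = 0; i < 2; ++i) {
      if (rects[i].was_packed)
         printf("rect %d packed at (%d, %d)\n", i, (int)rects[i].x, (int)rects[i].y);
      else
         printf("rect %d did not fit\n", i);
   }
   return all_packed ? 0 : 1;
}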